• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/codegen/SkSLSPIRVCodeGenerator.h"
9 
10 #include "include/core/SkSpan.h"
11 #include "include/core/SkTypes.h"
12 #include "include/private/base/SkTArray.h"
13 #include "include/private/base/SkTo.h"
14 #include "src/base/SkEnumBitMask.h"
15 #include "src/core/SkChecksum.h"
16 #include "src/core/SkTHash.h"
17 #include "src/core/SkTraceEvent.h"
18 #include "src/sksl/GLSL.std.450.h"
19 #include "src/sksl/SkSLAnalysis.h"
20 #include "src/sksl/SkSLBuiltinTypes.h"
21 #include "src/sksl/SkSLCompiler.h"
22 #include "src/sksl/SkSLConstantFolder.h"
23 #include "src/sksl/SkSLContext.h"
24 #include "src/sksl/SkSLDefines.h"
25 #include "src/sksl/SkSLErrorReporter.h"
26 #include "src/sksl/SkSLIntrinsicList.h"
27 #include "src/sksl/SkSLMemoryLayout.h"
28 #include "src/sksl/SkSLOperator.h"
29 #include "src/sksl/SkSLOutputStream.h"
30 #include "src/sksl/SkSLPool.h"
31 #include "src/sksl/SkSLPosition.h"
32 #include "src/sksl/SkSLProgramSettings.h"
33 #include "src/sksl/SkSLStringStream.h"
34 #include "src/sksl/SkSLUtil.h"
35 #include "src/sksl/analysis/SkSLSpecialization.h"
36 #include "src/sksl/codegen/SkSLCodeGenerator.h"
37 #include "src/sksl/ir/SkSLBinaryExpression.h"
38 #include "src/sksl/ir/SkSLBlock.h"
39 #include "src/sksl/ir/SkSLConstructor.h"
40 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
41 #include "src/sksl/ir/SkSLConstructorCompound.h"
42 #include "src/sksl/ir/SkSLConstructorCompoundCast.h"
43 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
44 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
45 #include "src/sksl/ir/SkSLConstructorScalarCast.h"
46 #include "src/sksl/ir/SkSLConstructorSplat.h"
47 #include "src/sksl/ir/SkSLDoStatement.h"
48 #include "src/sksl/ir/SkSLExpression.h"
49 #include "src/sksl/ir/SkSLExpressionStatement.h"
50 #include "src/sksl/ir/SkSLExtension.h"
51 #include "src/sksl/ir/SkSLFieldAccess.h"
52 #include "src/sksl/ir/SkSLFieldSymbol.h"
53 #include "src/sksl/ir/SkSLForStatement.h"
54 #include "src/sksl/ir/SkSLFunctionCall.h"
55 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
56 #include "src/sksl/ir/SkSLFunctionDefinition.h"
57 #include "src/sksl/ir/SkSLIRNode.h"
58 #include "src/sksl/ir/SkSLIfStatement.h"
59 #include "src/sksl/ir/SkSLIndexExpression.h"
60 #include "src/sksl/ir/SkSLInterfaceBlock.h"
61 #include "src/sksl/ir/SkSLLayout.h"
62 #include "src/sksl/ir/SkSLLiteral.h"
63 #include "src/sksl/ir/SkSLModifierFlags.h"
64 #include "src/sksl/ir/SkSLModifiersDeclaration.h"
65 #include "src/sksl/ir/SkSLPoison.h"
66 #include "src/sksl/ir/SkSLPostfixExpression.h"
67 #include "src/sksl/ir/SkSLPrefixExpression.h"
68 #include "src/sksl/ir/SkSLProgram.h"
69 #include "src/sksl/ir/SkSLProgramElement.h"
70 #include "src/sksl/ir/SkSLReturnStatement.h"
71 #include "src/sksl/ir/SkSLSetting.h"
72 #include "src/sksl/ir/SkSLStatement.h"
73 #include "src/sksl/ir/SkSLSwitchCase.h"
74 #include "src/sksl/ir/SkSLSwitchStatement.h"
75 #include "src/sksl/ir/SkSLSwizzle.h"
76 #include "src/sksl/ir/SkSLSymbol.h"
77 #include "src/sksl/ir/SkSLSymbolTable.h"
78 #include "src/sksl/ir/SkSLTernaryExpression.h"
79 #include "src/sksl/ir/SkSLType.h"
80 #include "src/sksl/ir/SkSLVarDeclarations.h"
81 #include "src/sksl/ir/SkSLVariable.h"
82 #include "src/sksl/ir/SkSLVariableReference.h"
83 #include "src/sksl/spirv.h"
84 #include "src/sksl/transform/SkSLTransform.h"
85 #include "src/utils/SkBitSet.h"
86 
87 #include <algorithm>
88 #include <cstdint>
89 #include <cstring>
90 #include <ctype.h>
91 #include <functional>
92 #include <memory>
93 #include <set>
94 #include <string>
95 #include <string_view>
96 #include <tuple>
97 #include <utility>
98 #include <vector>
99 #include <unordered_map>
100 #include <unordered_set>
101 
102 using namespace skia_private;
103 
104 #define kLast_Capability SpvCapabilityMultiViewport
105 
106 constexpr int DEVICE_FRAGCOORDS_BUILTIN = -1000;
107 constexpr int DEVICE_CLOCKWISE_BUILTIN  = -1001;
108 static constexpr SkSL::Layout kDefaultTypeLayout;
109 
110 namespace SkSL {
111 
enum class ProgramKind : int8_t;

// Storage classes a generated SPIR-V value may live in. Each enumerator is named after a SPIR-V
// storage class, and get_storage_class_spv_id() translates it to the matching SpvStorageClass.
// kStorageBuffer is the one entry whose translation depends on the build configuration.
// The underlying type is spelled out as `int`, which is the enum-class default. The enumerator
// order is significant: it fixes the integer values shown in the trailing comments.
enum class StorageClass : int {
    kUniformConstant,  // 0
    kInput,            // 1
    kUniform,          // 2
    kStorageBuffer,    // 3
    kOutput,           // 4
    kWorkgroup,        // 5
    kCrossWorkgroup,   // 6
    kPrivate,          // 7
    kFunction,         // 8
    kGeneric,          // 9
    kPushConstant,     // 10
    kAtomicCounter,    // 11
    kImage,            // 12
};
129 
get_storage_class_spv_id(StorageClass storageClass)130 static SpvStorageClass get_storage_class_spv_id(StorageClass storageClass) {
131     switch (storageClass) {
132         case StorageClass::kUniformConstant: return SpvStorageClassUniformConstant;
133         case StorageClass::kInput: return SpvStorageClassInput;
134         case StorageClass::kUniform: return SpvStorageClassUniform;
135         // Note: In SPIR-V 1.3, a storage buffer can be declared with the "StorageBuffer"
136         // storage class and the "Block" decoration and the <1.3 approach we use here ("Uniform"
137         // storage class and the "BufferBlock" decoration) is deprecated. Since we target SPIR-V
138         // 1.0, we have to use the deprecated approach which is well supported in Vulkan and
139         // addresses SkSL use cases (notably SkSL currently doesn't support pointer features that
140         // would benefit from SPV_KHR_variable_pointers capabilities).
141 #ifdef SKSL_EXT
142         case StorageClass::kStorageBuffer: return SpvStorageClassStorageBuffer;
143 #else
144         case StorageClass::kStorageBuffer: return SpvStorageClassUniform;
145 #endif
146         case StorageClass::kOutput: return SpvStorageClassOutput;
147         case StorageClass::kWorkgroup: return SpvStorageClassWorkgroup;
148         case StorageClass::kCrossWorkgroup: return SpvStorageClassCrossWorkgroup;
149         case StorageClass::kPrivate: return SpvStorageClassPrivate;
150         case StorageClass::kFunction: return SpvStorageClassFunction;
151         case StorageClass::kGeneric: return SpvStorageClassGeneric;
152         case StorageClass::kPushConstant: return SpvStorageClassPushConstant;
153         case StorageClass::kAtomicCounter: return SpvStorageClassAtomicCounter;
154         case StorageClass::kImage: return SpvStorageClassImage;
155     }
156 
157     SkUNREACHABLE;
158 }
159 
160 class SPIRVCodeGenerator : public CodeGenerator {
161 public:
162     // We reserve an impossible SpvId as a sentinel. (NA meaning none, n/a, etc.)
163     static constexpr SpvId NA = (SpvId)-1;
164 
165     class LValue {
166     public:
~LValue()167         virtual ~LValue() {}
168 
169         // returns a pointer to the lvalue, if possible. If the lvalue cannot be directly referenced
170         // by a pointer (e.g. vector swizzles), returns NA.
getPointer()171         virtual SpvId getPointer() { return NA; }
172 
173         // Returns true if a valid pointer returned by getPointer represents a memory object
174         // (see https://github.com/KhronosGroup/SPIRV-Tools/issues/2892). Has no meaning if
175         // getPointer() returns NA.
isMemoryObjectPointer() const176         virtual bool isMemoryObjectPointer() const { return true; }
177 
178         // Applies a swizzle to the components of the LValue, if possible. This is used to create
179         // LValues that are swizzes-of-swizzles. Non-swizzle LValues can just return false.
applySwizzle(const ComponentArray & components,const Type & newType)180         virtual bool applySwizzle(const ComponentArray& components, const Type& newType) {
181             return false;
182         }
183 
184         // Returns the storage class of the lvalue.
185         virtual StorageClass storageClass() const = 0;
186 
187         virtual SpvId load(OutputStream& out) = 0;
188 
189         virtual void store(SpvId value, OutputStream& out) = 0;
190     };
191 
SPIRVCodeGenerator(const Context * context,const ShaderCaps * caps,const Program * program,OutputStream * out)192     SPIRVCodeGenerator(const Context* context,
193                        const ShaderCaps* caps,
194                        const Program* program,
195                        OutputStream* out)
196             : CodeGenerator(context, caps, program, out) {}
197 
198     bool generateCode() override;
199 
200 private:
201     enum IntrinsicOpcodeKind {
202         kGLSL_STD_450_IntrinsicOpcodeKind,
203         kSPIRV_IntrinsicOpcodeKind,
204         kSpecial_IntrinsicOpcodeKind,
205         kInvalid_IntrinsicOpcodeKind,
206     };
207 
208     enum SpecialIntrinsic {
209         kAtan_SpecialIntrinsic,
210         kClamp_SpecialIntrinsic,
211         kMatrixCompMult_SpecialIntrinsic,
212         kMax_SpecialIntrinsic,
213         kMin_SpecialIntrinsic,
214         kMix_SpecialIntrinsic,
215         kMod_SpecialIntrinsic,
216         kDFdy_SpecialIntrinsic,
217         kSaturate_SpecialIntrinsic,
218         kSampledImage_SpecialIntrinsic,
219         kSmoothStep_SpecialIntrinsic,
220         kStep_SpecialIntrinsic,
221         kSubpassLoad_SpecialIntrinsic,
222         kTexture_SpecialIntrinsic,
223         kTextureGrad_SpecialIntrinsic,
224         kTextureLod_SpecialIntrinsic,
225         kTextureRead_SpecialIntrinsic,
226         kTextureWrite_SpecialIntrinsic,
227         kTextureWidth_SpecialIntrinsic,
228         kTextureHeight_SpecialIntrinsic,
229         kAtomicAdd_SpecialIntrinsic,
230         kAtomicLoad_SpecialIntrinsic,
231         kAtomicStore_SpecialIntrinsic,
232         kStorageBarrier_SpecialIntrinsic,
233         kWorkgroupBarrier_SpecialIntrinsic,
234 #ifdef SKSL_EXT
235         kTextureSize_SpecialIntrinsic,
236         kSampleGather_SpecialIntrinsic,
237         kNonuniformEXT_SpecialIntrinsic,
238 #endif
239     };
240 
241     enum class Precision {
242         kDefault,
243         kRelaxed,
244     };
245 
246     struct TempVar {
247         SpvId spvId;
248         const Type* type;
249         std::unique_ptr<SPIRVCodeGenerator::LValue> lvalue;
250     };
251 
252     /**
253      * Pass in the type to automatically add a RelaxedPrecision decoration for the id when
254      * appropriate, or null to never add one.
255      */
256     SpvId nextId(const Type* type);
257 
258     SpvId nextId(Precision precision);
259 
260     SpvId getType(const Type& type);
261 
262     SpvId getType(const Type& type, const Layout& typeLayout, const MemoryLayout& memoryLayout);
263 
264     SpvId getFunctionType(const FunctionDeclaration& function);
265 
266     SpvId getFunctionParameterType(const Type& parameterType, const Layout& parameterLayout);
267 
268     SpvId getPointerType(const Type& type, StorageClass storageClass);
269 
270     SpvId getPointerType(const Type& type,
271                          const Layout& typeLayout,
272                          const MemoryLayout& memoryLayout,
273                          StorageClass storageClass);
274 
275     StorageClass getStorageClass(const Expression& expr);
276 
277     TArray<SpvId> getAccessChain(const Expression& expr, OutputStream& out);
278 
279     void writeLayout(const Layout& layout, SpvId target, Position pos);
280 
281     void writeFieldLayout(const Layout& layout, SpvId target, int member);
282 
283     SpvId writeStruct(const Type& type, const MemoryLayout& memoryLayout);
284 
285     void writeProgramElement(const ProgramElement& pe, OutputStream& out);
286 
287     SpvId writeInterfaceBlock(const InterfaceBlock& intf, bool appendRTFlip = true);
288 
289     void writeFunctionStart(const FunctionDeclaration& f, OutputStream& out);
290 
291     SpvId writeFunctionDeclaration(const FunctionDeclaration& f, OutputStream& out);
292 
293     void writeFunction(const FunctionDefinition& f, OutputStream& out);
294 
295     // Writes the function with the defined specializationIndex, if the index is -1, then it is
296     // assumed that the function has no specializations.
297     void writeFunctionInstantiation(const FunctionDefinition& f,
298                                     Analysis::SpecializationIndex specializationIndex,
299                                     const Analysis::SpecializedParameters* specializedParams,
300                                     OutputStream& out);
301 
302     bool writeGlobalVarDeclaration(ProgramKind kind, const VarDeclaration& v);
303 
304     SpvId writeGlobalVar(ProgramKind kind, StorageClass, const Variable& v);
305 
306     void writeVarDeclaration(const VarDeclaration& var, OutputStream& out);
307 
308     SpvId writeVariableReference(const VariableReference& ref, OutputStream& out);
309 
310     int findUniformFieldIndex(const Variable& var) const;
311 
312     std::unique_ptr<LValue> getLValue(const Expression& value, OutputStream& out);
313 
314     SpvId writeExpression(const Expression& expr, OutputStream& out);
315 
316     SpvId writeIntrinsicCall(const FunctionCall& c, OutputStream& out);
317 
318     void writeFunctionCallArgument(TArray<SpvId>& argumentList,
319                                    const FunctionCall& call,
320                                    int argIndex,
321                                    std::vector<TempVar>* tempVars,
322                                    const SkBitSet* specializedParams,
323                                    OutputStream& out);
324 
325     void copyBackTempVars(const std::vector<TempVar>& tempVars, OutputStream& out);
326 
327     SpvId writeFunctionCall(const FunctionCall& c, OutputStream& out);
328 
329 
330     void writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
331                                       SpvId signedInst, SpvId unsignedInst,
332                                       const TArray<SpvId>& args, OutputStream& out);
333 
334     /**
335      * Promotes an expression to a vector. If the expression is already a vector with vectorSize
336      * columns, returns it unmodified. If the expression is a scalar, either promotes it to a
337      * vector (if vectorSize > 1) or returns it unmodified (if vectorSize == 1). Asserts if the
338      * expression is already a vector and it does not have vectorSize columns.
339      */
340     SpvId vectorize(const Expression& expr, int vectorSize, OutputStream& out);
341 
342     /**
343      * Given a list of potentially mixed scalars and vectors, promotes the scalars to match the
344      * size of the vectors and returns the ids of the written expressions. e.g. given (float, vec2),
345      * returns (vec2(float), vec2). It is an error to use mismatched vector sizes, e.g. (float,
346      * vec2, vec3).
347      */
348     TArray<SpvId> vectorize(const ExpressionArray& args, OutputStream& out);
349 
350     /**
351      * Given a SpvId of a scalar, splats it across the passed-in type (scalar, vector or matrix) and
352      * returns the SpvId of the new value.
353      */
354     SpvId splat(const Type& type, SpvId id, OutputStream& out);
355 
356     SpvId writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind, OutputStream& out);
357     SpvId writeAtomicIntrinsic(const FunctionCall& c,
358                                SpecialIntrinsic kind,
359                                SpvId resultId,
360                                OutputStream& out);
361 
362     SpvId castScalarToFloat(SpvId inputId, const Type& inputType, const Type& outputType,
363                             OutputStream& out);
364 
365     SpvId castScalarToSignedInt(SpvId inputId, const Type& inputType, const Type& outputType,
366                                 OutputStream& out);
367 
368     SpvId castScalarToUnsignedInt(SpvId inputId, const Type& inputType, const Type& outputType,
369                                   OutputStream& out);
370 
371     SpvId castScalarToBoolean(SpvId inputId, const Type& inputType, const Type& outputType,
372                               OutputStream& out);
373 
374     SpvId castScalarToType(SpvId inputExprId, const Type& inputType, const Type& outputType,
375                            OutputStream& out);
376 
377     /**
378      * Writes a potentially-different-sized copy of a matrix. Entries which do not exist in the
379      * source matrix are filled with zero; entries which do not exist in the destination matrix are
380      * ignored.
381      */
382     SpvId writeMatrixCopy(SpvId src, const Type& srcType, const Type& dstType, OutputStream& out);
383 
384     void addColumnEntry(const Type& columnType,
385                         TArray<SpvId>* currentColumn,
386                         TArray<SpvId>* columnIds,
387                         int rows,
388                         SpvId entry,
389                         OutputStream& out);
390 
391     SpvId writeConstructorCompound(const ConstructorCompound& c, OutputStream& out);
392 
393     SpvId writeMatrixConstructor(const ConstructorCompound& c, OutputStream& out);
394 
395     SpvId writeVectorConstructor(const ConstructorCompound& c, OutputStream& out);
396 
397     SpvId writeCompositeConstructor(const AnyConstructor& c, OutputStream& out);
398 
399     SpvId writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c, OutputStream& out);
400 
401     SpvId writeConstructorMatrixResize(const ConstructorMatrixResize& c, OutputStream& out);
402 
403     SpvId writeConstructorScalarCast(const ConstructorScalarCast& c, OutputStream& out);
404 
405     SpvId writeConstructorSplat(const ConstructorSplat& c, OutputStream& out);
406 
407     SpvId writeConstructorCompoundCast(const ConstructorCompoundCast& c, OutputStream& out);
408 
409     SpvId writeFieldAccess(const FieldAccess& f, OutputStream& out);
410 
411     SpvId writeSwizzle(const Expression& baseExpr,
412                        const ComponentArray& components,
413                        OutputStream& out);
414 
415     SpvId writeSwizzle(const Swizzle& swizzle, OutputStream& out);
416 
417     /**
418      * Folds the potentially-vector result of a logical operation down to a single bool. If
419      * operandType is a vector type, assumes that the intermediate result in id is a bvec of the
420      * same dimensions, and applies all() to it to fold it down to a single bool value. Otherwise,
421      * returns the original id value.
422      */
423     SpvId foldToBool(SpvId id, const Type& operandType, SpvOp op, OutputStream& out);
424 
425     SpvId writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs, SpvOp_ floatOperator,
426                                 SpvOp_ intOperator, SpvOp_ vectorMergeOperator,
427                                 SpvOp_ mergeOperator, OutputStream& out);
428 
429     SpvId writeStructComparison(const Type& structType, SpvId lhs, Operator op, SpvId rhs,
430                                 OutputStream& out);
431 
432     SpvId writeArrayComparison(const Type& structType, SpvId lhs, Operator op, SpvId rhs,
433                                OutputStream& out);
434 
435     // Used by writeStructComparison and writeArrayComparison to logically combine field-by-field
436     // comparisons into an overall comparison result.
437     // - `a.x == b.x` merged with `a.y == b.y` generates `(a.x == b.x) && (a.y == b.y)`
438     // - `a.x != b.x` merged with `a.y != b.y` generates `(a.x != b.x) || (a.y != b.y)`
439     SpvId mergeComparisons(SpvId comparison, SpvId allComparisons, Operator op, OutputStream& out);
440 
441     // When the RewriteMatrixVectorMultiply caps bit is set, we manually decompose the M*V
442     // multiplication into a sum of vector-scalar products.
443     SpvId writeDecomposedMatrixVectorMultiply(const Type& leftType,
444                                               SpvId lhs,
445                                               const Type& rightType,
446                                               SpvId rhs,
447                                               const Type& resultType,
448                                               OutputStream& out);
449 
450     SpvId writeComponentwiseMatrixUnary(const Type& operandType,
451                                         SpvId operand,
452                                         SpvOp_ op,
453                                         OutputStream& out);
454 
455     SpvId writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs, SpvId rhs,
456                                          SpvOp_ op, OutputStream& out);
457 
458     SpvId writeBinaryOperationComponentwiseIfMatrix(const Type& resultType, const Type& operandType,
459                                                     SpvId lhs, SpvId rhs,
460                                                     SpvOp_ ifFloat, SpvOp_ ifInt,
461                                                     SpvOp_ ifUInt, SpvOp_ ifBool,
462                                                     OutputStream& out);
463 
464     SpvId writeBinaryOperation(const Type& resultType, const Type& operandType, SpvId lhs,
465                                SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt, SpvOp_ ifUInt,
466                                SpvOp_ ifBool, OutputStream& out);
467 
468     SpvId writeBinaryOperation(const Type& resultType, const Type& operandType, SpvId lhs,
469                                SpvId rhs, bool writeComponentwiseIfMatrix, SpvOp_ ifFloat,
470                                SpvOp_ ifInt, SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out);
471 
472     SpvId writeReciprocal(const Type& type, SpvId value, OutputStream& out);
473 
474     SpvId writeBinaryExpression(const Type& leftType, SpvId lhs, Operator op,
475                                 const Type& rightType, SpvId rhs, const Type& resultType,
476                                 OutputStream& out);
477 
478     SpvId writeBinaryExpression(const BinaryExpression& b, OutputStream& out);
479 
480     SpvId writeTernaryExpression(const TernaryExpression& t, OutputStream& out);
481 
482     SpvId writeIndexExpression(const IndexExpression& expr, OutputStream& out);
483 
484     SpvId writeLogicalAnd(const Expression& left, const Expression& right, OutputStream& out);
485 
486     SpvId writeLogicalOr(const Expression& left, const Expression& right, OutputStream& out);
487 
488     SpvId writePrefixExpression(const PrefixExpression& p, OutputStream& out);
489 
490     SpvId writePostfixExpression(const PostfixExpression& p, OutputStream& out);
491 
492     SpvId writeLiteral(const Literal& f);
493 
494     SpvId writeLiteral(double value, const Type& type);
495 
496     void writeStatement(const Statement& s, OutputStream& out);
497 
498     void writeBlock(const Block& b, OutputStream& out);
499 
500     void writeIfStatement(const IfStatement& stmt, OutputStream& out);
501 
502     void writeForStatement(const ForStatement& f, OutputStream& out);
503 
504     void writeDoStatement(const DoStatement& d, OutputStream& out);
505 
506     void writeSwitchStatement(const SwitchStatement& s, OutputStream& out);
507 
508     void writeReturnStatement(const ReturnStatement& r, OutputStream& out);
509 
510     void writeCapabilities(OutputStream& out);
511 
512     void writeInstructions(const Program& program, OutputStream& out);
513 
514     void writeOpCode(SpvOp_ opCode, int length, OutputStream& out);
515 
516     void writeWord(int32_t word, OutputStream& out);
517 
518     void writeString(std::string_view s, OutputStream& out);
519 
520     void writeInstruction(SpvOp_ opCode, OutputStream& out);
521 
522     void writeInstruction(SpvOp_ opCode, std::string_view string, OutputStream& out);
523 
524     void writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out);
525 
526     void writeInstruction(SpvOp_ opCode, int32_t word1, std::string_view string,
527                           OutputStream& out);
528 
529     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, std::string_view string,
530                           OutputStream& out);
531 
532     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, OutputStream& out);
533 
534     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3,
535                           OutputStream& out);
536 
537     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
538                           OutputStream& out);
539 
540     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
541                           int32_t word5, OutputStream& out);
542 
543     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
544                           int32_t word5, int32_t word6, OutputStream& out);
545 
546     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
547                           int32_t word5, int32_t word6, int32_t word7, OutputStream& out);
548 
549     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
550                           int32_t word5, int32_t word6, int32_t word7, int32_t word8,
551                           OutputStream& out);
552 
553     // This form of writeInstruction can deduplicate redundant ops.
554     struct Word;
555     // 8 Words is enough for nearly all instructions (except variable-length instructions like
556     // OpAccessChain or OpConstantComposite).
557     using Words = STArray<8, Word, true>;
558     SpvId writeInstruction(SpvOp_ opCode, const TArray<Word, true>& words, OutputStream& out);
559 
560     struct Instruction {
561         SpvId fOp;
562         int32_t fResultKind;
563         STArray<8, int32_t>  fWords;
564 
565         bool operator==(const Instruction& that) const;
566         struct Hash;
567     };
568 
569     static Instruction BuildInstructionKey(SpvOp_ opCode, const TArray<Word, true>& words);
570 
571     // The writeOpXxxxx calls will simplify and deduplicate ops where possible.
572     SpvId writeOpConstantTrue(const Type& type);
573     SpvId writeOpConstantFalse(const Type& type);
574     SpvId writeOpConstant(const Type& type, int32_t valueBits);
575     SpvId writeOpConstantComposite(const Type& type, const TArray<SpvId>& values);
576     SpvId writeOpCompositeConstruct(const Type& type, const TArray<SpvId>&, OutputStream& out);
577     SpvId writeOpCompositeExtract(const Type& type, SpvId base, int component, OutputStream& out);
578     SpvId writeOpCompositeExtract(const Type& type, SpvId base, int componentA, int componentB,
579                                   OutputStream& out);
580     SpvId writeOpLoad(SpvId type, Precision precision, SpvId pointer, OutputStream& out);
581     void writeOpStore(StorageClass storageClass, SpvId pointer, SpvId value, OutputStream& out);
582 
583     // Converts the provided SpvId(s) into an array of scalar OpConstants, if it can be done.
584     bool toConstants(SpvId value, TArray<SpvId>* constants);
585     bool toConstants(SkSpan<const SpvId> values, TArray<SpvId>* constants);
586 
587     // Extracts the requested component SpvId from a composite instruction, if it can be done.
588     Instruction* resultTypeForInstruction(const Instruction& instr);
589     int numComponentsForVecInstruction(const Instruction& instr);
590     SpvId toComponent(SpvId id, int component);
591 
592     struct ConditionalOpCounts {
593         int numReachableOps;
594         int numStoreOps;
595     };
596     ConditionalOpCounts getConditionalOpCounts();
597     void pruneConditionalOps(ConditionalOpCounts ops);
598 
599     enum StraightLineLabelType {
600         // Use "BranchlessBlock" for blocks which are never explicitly branched-to at all. This
601         // happens at the start of a function, or when we find unreachable code.
602         kBranchlessBlock,
603 
604         // Use "BranchIsOnPreviousLine" when writing a label that comes immediately after its
605         // associated branch. Example usage:
606         // - SPIR-V does not implicitly fall through from one block to the next, so you may need to
607         //   use an OpBranch to explicitly jump to the next block, even when they are adjacent in
608         //   the code.
609         // - The block immediately following an OpBranchConditional or OpSwitch.
610         kBranchIsOnPreviousLine,
611     };
612 
613     enum BranchingLabelType {
614         // Use "BranchIsAbove" for labels which are referenced by OpBranch or OpBranchConditional
615         // ops that are above the label in the code--i.e., the branch skips forward in the code.
616         kBranchIsAbove,
617 
618         // Use "BranchIsBelow" for labels which are referenced by OpBranch or OpBranchConditional
619         // ops below the label in the code--i.e., the branch jumps backward in the code.
620         kBranchIsBelow,
621 
622         // Use "BranchesOnBothSides" for labels which have branches coming from both directions.
623         kBranchesOnBothSides,
624     };
625     void writeLabel(SpvId label, StraightLineLabelType type, OutputStream& out);
626     void writeLabel(SpvId label, BranchingLabelType type, ConditionalOpCounts ops,
627                     OutputStream& out);
628 
629     MemoryLayout memoryLayoutForStorageClass(StorageClass storageClass);
630     MemoryLayout memoryLayoutForVariable(const Variable&) const;
631 
632     struct EntrypointAdapter {
633         std::unique_ptr<FunctionDefinition> entrypointDef;
634         std::unique_ptr<FunctionDeclaration> entrypointDecl;
635     };
636 
637     EntrypointAdapter writeEntrypointAdapter(const FunctionDeclaration& main);
638 
639     struct UniformBuffer {
640         std::unique_ptr<InterfaceBlock> fInterfaceBlock;
641         std::unique_ptr<Variable> fInnerVariable;
642         std::unique_ptr<Type> fStruct;
643     };
644 
645     void writeUniformBuffer(SymbolTable* topLevelSymbolTable);
646 
647     void addRTFlipUniform(Position pos);
648 
649 #ifdef SKSL_EXT
650     SpvId writeSpecConstBinaryExpression(const BinaryExpression& b, const Operator& op,
651                                          SpvId lhs, SpvId rhs);
652     void writeExtensions(OutputStream& out);
653 
654     std::unordered_set<std::string> fExtensions;
655     std::unordered_set<uint32_t> fCapabilitiesExt;
656     std::unordered_set<SpvId> fNonUniformSpvId;
657     std::unordered_map<const Variable*, SpvId> fGlobalConstVariableValueMap;
658     bool fEmittingGlobalConstConstructor = false;
659 #endif
660 
661     std::unique_ptr<Expression> identifier(std::string_view name);
662 
663     std::tuple<const Variable*, const Variable*> synthesizeTextureAndSampler(
664             const Variable& combinedSampler);
665 
666 #ifdef SKSL_EXT
667     const MemoryLayout fDefaultMemoryLayout{MemoryLayout::Standard::k430};
668 #else
669     const MemoryLayout fDefaultMemoryLayout{MemoryLayout::Standard::k140};
670 #endif
671     uint64_t fCapabilities = 0;
672     SpvId fIdCount = 1;
673     SpvId fGLSLExtendedInstructions;
674     struct Intrinsic {
675         IntrinsicOpcodeKind opKind;
676         int32_t floatOp;
677         int32_t signedOp;
678         int32_t unsignedOp;
679         int32_t boolOp;
680     };
681     Intrinsic getIntrinsic(IntrinsicKind) const;
682 
683     THashMap<Analysis::SpecializedFunctionKey, SpvId, Analysis::SpecializedFunctionKey::Hash>
684             fFunctionMap;
685 
686     Analysis::SpecializationInfo fSpecializationInfo;
687     Analysis::SpecializationIndex fActiveSpecializationIndex = Analysis::kUnspecialized;
688     const Analysis::SpecializedParameters* fActiveSpecialization = nullptr;
689 
690     THashMap<const Variable*, SpvId> fVariableMap;
691     THashMap<const Type*, SpvId> fStructMap;
692     StringStream fGlobalInitializersBuffer;
693     StringStream fConstantBuffer;
694     StringStream fVariableBuffer;
695     StringStream fNameBuffer;
696     StringStream fDecorationBuffer;
697 
698     // Mapping from combined sampler declarations to synthesized texture/sampler variables.
699     // This is used when the sampler is declared as `layout(webgpu)` or `layout(direct3d)`.
700     bool fUseTextureSamplerPairs = false;
701     struct SynthesizedTextureSamplerPair {
702         // The names of the synthesized variables. The Variable objects themselves store string
703         // views referencing these strings. It is important for the std::string instances to have a
704         // fixed memory location after the string views get created, which is why
705         // `fSynthesizedSamplerMap` stores unique_ptr instead of values.
706         std::string fTextureName;
707         std::string fSamplerName;
708         std::unique_ptr<Variable> fTexture;
709         std::unique_ptr<Variable> fSampler;
710     };
711     THashMap<const Variable*, std::unique_ptr<SynthesizedTextureSamplerPair>>
712             fSynthesizedSamplerMap;
713 
714     // These caches map SpvIds to Instructions, and vice-versa. This enables us to deduplicate code
715     // (by detecting an Instruction we've already issued and reusing the SpvId), and to introspect
716     // and simplify code we've already emitted  (by taking a SpvId from an Instruction and following
717     // it back to its source).
718 
719     // A map of instruction -> SpvId:
720     THashMap<Instruction, SpvId, Instruction::Hash> fOpCache;
721     // A map of SpvId -> instruction:
722     THashMap<SpvId, Instruction> fSpvIdCache;
723     // A map of SpvId -> value SpvId:
724     THashMap<SpvId, SpvId> fStoreCache;
725 
726     // "Reachable" ops are instructions which can safely be accessed from the current block.
727     // For instance, if our SPIR-V contains `%3 = OpFAdd %1 %2`, we would be able to access and
728     // reuse that computation on following lines. However, if that Add operation occurred inside an
729     // `if` block, then its SpvId becomes inaccessible once we complete the if statement (since
730     // depending on the if condition, we may or may not have actually done that computation). The
731     // same logic applies to other control-flow blocks as well. Once an instruction becomes
732     // unreachable, we remove it from both op-caches.
733     TArray<SpvId> fReachableOps;
734 
735     // The "store-ops" list contains a running list of all the pointers in the store cache. If a
736     // store occurs inside of a conditional block, once that block exits, we no longer know what is
737     // stored in that particular SpvId. At that point, we must remove any associated entry from the
738     // store cache.
739     TArray<SpvId> fStoreOps;
740 
741     // label of the current block, or 0 if we are not in a block
742     SpvId fCurrentBlock = 0;
743     TArray<SpvId> fBreakTarget;
744     TArray<SpvId> fContinueTarget;
745     bool fWroteRTFlip = false;
746     // holds variables synthesized during output, for lifetime purposes
747     SymbolTable fSynthetics{/*builtin=*/true};
748     // Holds a list of uniforms that were declared as globals at the top-level instead of in an
749     // interface block.
750     UniformBuffer fUniformBuffer;
751     std::vector<const VarDeclaration*> fTopLevelUniforms;
752     THashMap<const Variable*, int> fTopLevelUniformMap;  // <var, UniformBuffer field index>
753     SpvId fUniformBufferId = NA;
754 
755     friend class PointerLValue;
756     friend class SwizzleLValue;
757 };
758 
759 // Equality and hash operators for Instructions.
operator ==(const SPIRVCodeGenerator::Instruction & that) const760 bool SPIRVCodeGenerator::Instruction::operator==(const SPIRVCodeGenerator::Instruction& that) const {
761     return fOp         == that.fOp &&
762            fResultKind == that.fResultKind &&
763            fWords      == that.fWords;
764 }
765 
766 struct SPIRVCodeGenerator::Instruction::Hash {
operator ()SkSL::SPIRVCodeGenerator::Instruction::Hash767     uint32_t operator()(const SPIRVCodeGenerator::Instruction& key) const {
768         uint32_t hash = key.fResultKind;
769         hash = SkChecksum::Hash32(&key.fOp, sizeof(key.fOp), hash);
770         hash = SkChecksum::Hash32(key.fWords.data(), key.fWords.size() * sizeof(int32_t), hash);
771         return hash;
772     }
773 };
774 
775 // This class is used to pass values and result placeholder slots to writeInstruction.
776 struct SPIRVCodeGenerator::Word {
777     enum Kind {
778         kNone,  // intended for use as a sentinel, not part of any Instruction
779         kSpvId,
780         kNumber,
781         kDefaultPrecisionResult,
782         kRelaxedPrecisionResult,
783         kUniqueResult,
784         kKeyedResult,
785     };
786 
WordSkSL::SPIRVCodeGenerator::Word787     Word(SpvId id) : fValue(id), fKind(Kind::kSpvId) {}
WordSkSL::SPIRVCodeGenerator::Word788     Word(int32_t val, Kind kind) : fValue(val), fKind(kind) {}
789 
NumberSkSL::SPIRVCodeGenerator::Word790     static Word Number(int32_t val) {
791         return Word{val, Kind::kNumber};
792     }
793 
ResultSkSL::SPIRVCodeGenerator::Word794     static Word Result(const Type& type) {
795         return (type.hasPrecision() && !type.highPrecision()) ? RelaxedResult() : Result();
796     }
797 
RelaxedResultSkSL::SPIRVCodeGenerator::Word798     static Word RelaxedResult() {
799         return Word{(int32_t)NA, kRelaxedPrecisionResult};
800     }
801 
UniqueResultSkSL::SPIRVCodeGenerator::Word802     static Word UniqueResult() {
803         return Word{(int32_t)NA, kUniqueResult};
804     }
805 
ResultSkSL::SPIRVCodeGenerator::Word806     static Word Result() {
807         return Word{(int32_t)NA, kDefaultPrecisionResult};
808     }
809 
810     // Unlike a Result (where the result ID is always deduplicated to its first instruction) or a
811     // UniqueResult (which always produces a new instruction), a KeyedResult allows an instruction
812     // to be deduplicated among those that share the same `key`.
KeyedResultSkSL::SPIRVCodeGenerator::Word813     static Word KeyedResult(int32_t key) { return Word{key, Kind::kKeyedResult}; }
814 
isResultSkSL::SPIRVCodeGenerator::Word815     bool isResult() const { return fKind >= Kind::kDefaultPrecisionResult; }
816 
817     int32_t fValue;
818     Kind fKind;
819 };
820 
// SPIR-V generator magic number. Skia's generator ID (31, per the registry linked below) lives in
// the upper 16 bits; the lower 16 bits are currently zero and could be used to version the SkSL
// generator if we want.
// https://github.com/KhronosGroup/SPIRV-Headers/blob/master/include/spirv/spir-v.xml#L84
static constexpr int32_t SKSL_MAGIC = 31 << 16;
825 
// Maps an SkSL intrinsic to the opcodes used to implement it. The returned Intrinsic holds an
// opcode kind plus one opcode per operand category: float, signed int, unsigned int and bool (in
// that order). SpvOpUndef in a slot means there is no opcode for that operand category.
// Intrinsics not listed in the switch map to kInvalid_IntrinsicOpcodeKind.
SPIRVCodeGenerator::Intrinsic SPIRVCodeGenerator::getIntrinsic(IntrinsicKind ik) const {

// Table-entry helpers:
//   ALL_GLSL(x)       - GLSL.std.450 instruction `x` for every operand category.
//   BY_TYPE_GLSL(...) - per-category GLSL.std.450 instructions; the bool slot is SpvOpUndef.
//   ALL_SPIRV(x)      - core SPIR-V opcode `x` for every operand category.
//   BOOL_SPIRV(x)     - core SPIR-V opcode `x` for bool operands only.
//   FLOAT_SPIRV(x)    - core SPIR-V opcode `x` for float operands only.
//   SPECIAL(x)        - no direct opcode; every slot holds k<x>_SpecialIntrinsic (presumably
//                       dispatched to dedicated codegen elsewhere -- confirm).
// NOTE(review): unlike PACK below, these macros are not #undef'd in this function.
#define ALL_GLSL(x) Intrinsic{kGLSL_STD_450_IntrinsicOpcodeKind, GLSLstd450 ## x, \
                              GLSLstd450 ## x, GLSLstd450 ## x, GLSLstd450 ## x}
#define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) Intrinsic{kGLSL_STD_450_IntrinsicOpcodeKind, \
                                                       GLSLstd450 ## ifFloat,             \
                                                       GLSLstd450 ## ifInt,               \
                                                       GLSLstd450 ## ifUInt,              \
                                                       SpvOpUndef}
#define ALL_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
                               SpvOp ## x, SpvOp ## x, SpvOp ## x, SpvOp ## x}
#define BOOL_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
                                SpvOpUndef, SpvOpUndef, SpvOpUndef, SpvOp ## x}
#define FLOAT_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
                                 SpvOp ## x, SpvOpUndef, SpvOpUndef, SpvOpUndef}
#define SPECIAL(x) Intrinsic{kSpecial_IntrinsicOpcodeKind, k ## x ## _SpecialIntrinsic, \
                             k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic,  \
                             k ## x ## _SpecialIntrinsic}

    switch (ik) {
        case k_round_IntrinsicKind:          return ALL_GLSL(Round);
        case k_roundEven_IntrinsicKind:      return ALL_GLSL(RoundEven);
        case k_trunc_IntrinsicKind:          return ALL_GLSL(Trunc);
        case k_abs_IntrinsicKind:            return BY_TYPE_GLSL(FAbs, SAbs, SAbs);
        case k_sign_IntrinsicKind:           return BY_TYPE_GLSL(FSign, SSign, SSign);
        case k_floor_IntrinsicKind:          return ALL_GLSL(Floor);
        case k_ceil_IntrinsicKind:           return ALL_GLSL(Ceil);
        case k_fract_IntrinsicKind:          return ALL_GLSL(Fract);
        case k_radians_IntrinsicKind:        return ALL_GLSL(Radians);
        case k_degrees_IntrinsicKind:        return ALL_GLSL(Degrees);
        case k_sin_IntrinsicKind:            return ALL_GLSL(Sin);
        case k_cos_IntrinsicKind:            return ALL_GLSL(Cos);
        case k_tan_IntrinsicKind:            return ALL_GLSL(Tan);
        case k_asin_IntrinsicKind:           return ALL_GLSL(Asin);
        case k_acos_IntrinsicKind:           return ALL_GLSL(Acos);
        case k_atan_IntrinsicKind:           return SPECIAL(Atan);
        case k_sinh_IntrinsicKind:           return ALL_GLSL(Sinh);
        case k_cosh_IntrinsicKind:           return ALL_GLSL(Cosh);
        case k_tanh_IntrinsicKind:           return ALL_GLSL(Tanh);
        case k_asinh_IntrinsicKind:          return ALL_GLSL(Asinh);
        case k_acosh_IntrinsicKind:          return ALL_GLSL(Acosh);
        case k_atanh_IntrinsicKind:          return ALL_GLSL(Atanh);
        case k_pow_IntrinsicKind:            return ALL_GLSL(Pow);
        case k_exp_IntrinsicKind:            return ALL_GLSL(Exp);
        case k_log_IntrinsicKind:            return ALL_GLSL(Log);
        case k_exp2_IntrinsicKind:           return ALL_GLSL(Exp2);
        case k_log2_IntrinsicKind:           return ALL_GLSL(Log2);
        case k_sqrt_IntrinsicKind:           return ALL_GLSL(Sqrt);
        // `inverse` and `matrixInverse` both lower to GLSL.std.450 MatrixInverse.
        case k_inverse_IntrinsicKind:        return ALL_GLSL(MatrixInverse);
        case k_outerProduct_IntrinsicKind:   return ALL_SPIRV(OuterProduct);
        case k_transpose_IntrinsicKind:      return ALL_SPIRV(Transpose);
        case k_isinf_IntrinsicKind:          return ALL_SPIRV(IsInf);
        case k_isnan_IntrinsicKind:          return ALL_SPIRV(IsNan);
        case k_inversesqrt_IntrinsicKind:    return ALL_GLSL(InverseSqrt);
        case k_determinant_IntrinsicKind:    return ALL_GLSL(Determinant);
        case k_matrixCompMult_IntrinsicKind: return SPECIAL(MatrixCompMult);
        case k_matrixInverse_IntrinsicKind:  return ALL_GLSL(MatrixInverse);
        case k_mod_IntrinsicKind:            return SPECIAL(Mod);
        case k_modf_IntrinsicKind:           return ALL_GLSL(Modf);
        case k_min_IntrinsicKind:            return SPECIAL(Min);
        case k_max_IntrinsicKind:            return SPECIAL(Max);
        case k_clamp_IntrinsicKind:          return SPECIAL(Clamp);
        case k_saturate_IntrinsicKind:       return SPECIAL(Saturate);
        case k_dot_IntrinsicKind:            return FLOAT_SPIRV(Dot);
        case k_mix_IntrinsicKind:            return SPECIAL(Mix);
        case k_step_IntrinsicKind:           return SPECIAL(Step);
        case k_smoothstep_IntrinsicKind:     return SPECIAL(SmoothStep);
        case k_fma_IntrinsicKind:            return ALL_GLSL(Fma);
        case k_frexp_IntrinsicKind:          return ALL_GLSL(Frexp);
        case k_ldexp_IntrinsicKind:          return ALL_GLSL(Ldexp);

// Emits the matching pack/unpack case pair for one packing format.
#define PACK(type) case k_pack##type##_IntrinsicKind:   return ALL_GLSL(Pack##type); \
                   case k_unpack##type##_IntrinsicKind: return ALL_GLSL(Unpack##type)
        PACK(Snorm4x8);
        PACK(Unorm4x8);
        PACK(Snorm2x16);
        PACK(Unorm2x16);
        PACK(Half2x16);
#undef PACK

        case k_length_IntrinsicKind:        return ALL_GLSL(Length);
        case k_distance_IntrinsicKind:      return ALL_GLSL(Distance);
        case k_cross_IntrinsicKind:         return ALL_GLSL(Cross);
        case k_normalize_IntrinsicKind:     return ALL_GLSL(Normalize);
        case k_faceforward_IntrinsicKind:   return ALL_GLSL(FaceForward);
        case k_reflect_IntrinsicKind:       return ALL_GLSL(Reflect);
        case k_refract_IntrinsicKind:       return ALL_GLSL(Refract);
        case k_bitCount_IntrinsicKind:      return ALL_SPIRV(BitCount);
        case k_findLSB_IntrinsicKind:       return ALL_GLSL(FindILsb);
        case k_findMSB_IntrinsicKind:       return BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
        case k_dFdx_IntrinsicKind:          return FLOAT_SPIRV(DPdx);
        // dFdy is SPECIAL rather than FLOAT_SPIRV(DPdy); presumably it needs the render-target
        // flip adjustment (see addRTFlipUniform) -- confirm.
        case k_dFdy_IntrinsicKind:          return SPECIAL(DFdy);
        case k_fwidth_IntrinsicKind:        return FLOAT_SPIRV(Fwidth);

        case k_sample_IntrinsicKind:      return SPECIAL(Texture);
        case k_sampleGrad_IntrinsicKind:  return SPECIAL(TextureGrad);
        case k_sampleLod_IntrinsicKind:   return SPECIAL(TextureLod);
        case k_subpassLoad_IntrinsicKind: return SPECIAL(SubpassLoad);

        case k_textureRead_IntrinsicKind:  return SPECIAL(TextureRead);
        case k_textureWrite_IntrinsicKind:  return SPECIAL(TextureWrite);
        case k_textureWidth_IntrinsicKind:  return SPECIAL(TextureWidth);
        case k_textureHeight_IntrinsicKind:  return SPECIAL(TextureHeight);

        // All bit-reinterpretation intrinsics lower to a plain OpBitcast.
        case k_floatBitsToInt_IntrinsicKind:  return ALL_SPIRV(Bitcast);
        case k_floatBitsToUint_IntrinsicKind: return ALL_SPIRV(Bitcast);
        case k_intBitsToFloat_IntrinsicKind:  return ALL_SPIRV(Bitcast);
        case k_uintBitsToFloat_IntrinsicKind: return ALL_SPIRV(Bitcast);

        case k_any_IntrinsicKind:   return BOOL_SPIRV(Any);
        case k_all_IntrinsicKind:   return BOOL_SPIRV(All);
        case k_not_IntrinsicKind:   return BOOL_SPIRV(LogicalNot);

        // Component-wise comparisons. Float `notEqual` uses the unordered form
        // (OpFUnordNotEqual); the other float comparisons use ordered forms. Relational
        // comparisons have no bool opcode.
        case k_equal_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdEqual,
                             SpvOpIEqual,
                             SpvOpIEqual,
                             SpvOpLogicalEqual};
        case k_notEqual_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFUnordNotEqual,
                             SpvOpINotEqual,
                             SpvOpINotEqual,
                             SpvOpLogicalNotEqual};
        case k_lessThan_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdLessThan,
                             SpvOpSLessThan,
                             SpvOpULessThan,
                             SpvOpUndef};
        case k_lessThanEqual_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdLessThanEqual,
                             SpvOpSLessThanEqual,
                             SpvOpULessThanEqual,
                             SpvOpUndef};
        case k_greaterThan_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdGreaterThan,
                             SpvOpSGreaterThan,
                             SpvOpUGreaterThan,
                             SpvOpUndef};
        case k_greaterThanEqual_IntrinsicKind:
            return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
                             SpvOpFOrdGreaterThanEqual,
                             SpvOpSGreaterThanEqual,
                             SpvOpUGreaterThanEqual,
                             SpvOpUndef};

        case k_atomicAdd_IntrinsicKind:   return SPECIAL(AtomicAdd);
        case k_atomicLoad_IntrinsicKind:  return SPECIAL(AtomicLoad);
        case k_atomicStore_IntrinsicKind: return SPECIAL(AtomicStore);

        case k_storageBarrier_IntrinsicKind:   return SPECIAL(StorageBarrier);
        case k_workgroupBarrier_IntrinsicKind: return SPECIAL(WorkgroupBarrier);
#ifdef SKSL_EXT
        case k_textureSize_IntrinsicKind:     return SPECIAL(TextureSize);
        case k_sampleGather_IntrinsicKind:    return SPECIAL(SampleGather);
        case k_nonuniformEXT_IntrinsicKind:   return SPECIAL(NonuniformEXT);
#endif
        default:
            // No SPIR-V mapping for this intrinsic.
            return Intrinsic{kInvalid_IntrinsicOpcodeKind, 0, 0, 0, 0};
    }
}
991 
writeWord(int32_t word,OutputStream & out)992 void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
993     out.write((const char*) &word, sizeof(word));
994 }
995 
is_float(const Type & type)996 static bool is_float(const Type& type) {
997     return (type.isScalar() || type.isVector() || type.isMatrix()) &&
998            type.componentType().isFloat();
999 }
1000 
is_signed(const Type & type)1001 static bool is_signed(const Type& type) {
1002     return (type.isScalar() || type.isVector()) && type.componentType().isSigned();
1003 }
1004 
is_unsigned(const Type & type)1005 static bool is_unsigned(const Type& type) {
1006     return (type.isScalar() || type.isVector()) && type.componentType().isUnsigned();
1007 }
1008 
is_bool(const Type & type)1009 static bool is_bool(const Type& type) {
1010     return (type.isScalar() || type.isVector()) && type.componentType().isBoolean();
1011 }
1012 
1013 template <typename T>
pick_by_type(const Type & type,T ifFloat,T ifInt,T ifUInt,T ifBool)1014 static T pick_by_type(const Type& type, T ifFloat, T ifInt, T ifUInt, T ifBool) {
1015     if (is_float(type)) {
1016         return ifFloat;
1017     }
1018     if (is_signed(type)) {
1019         return ifInt;
1020     }
1021     if (is_unsigned(type)) {
1022         return ifUInt;
1023     }
1024     if (is_bool(type)) {
1025         return ifBool;
1026     }
1027     SkDEBUGFAIL("unrecognized type");
1028     return ifFloat;
1029 }
1030 
is_out(ModifierFlags f)1031 static bool is_out(ModifierFlags f) {
1032     return SkToBool(f & ModifierFlag::kOut);
1033 }
1034 
is_in(ModifierFlags f)1035 static bool is_in(ModifierFlags f) {
1036     if (f & ModifierFlag::kIn) {
1037         return true;  // `in` and `inout` both count
1038     }
1039     // If neither in/out flag is set, the type is implicitly `in`.
1040     return !SkToBool(f & ModifierFlag::kOut);
1041 }
1042 
is_control_flow_op(SpvOp_ op)1043 static bool is_control_flow_op(SpvOp_ op) {
1044     switch (op) {
1045         case SpvOpReturn:
1046         case SpvOpReturnValue:
1047         case SpvOpKill:
1048         case SpvOpSwitch:
1049         case SpvOpBranch:
1050         case SpvOpBranchConditional:
1051             return true;
1052         default:
1053             return false;
1054     }
1055 }
1056 
is_globally_reachable_op(SpvOp_ op)1057 static bool is_globally_reachable_op(SpvOp_ op) {
1058     switch (op) {
1059         case SpvOpConstant:
1060         case SpvOpConstantTrue:
1061         case SpvOpConstantFalse:
1062         case SpvOpConstantComposite:
1063         case SpvOpTypeVoid:
1064         case SpvOpTypeInt:
1065         case SpvOpTypeFloat:
1066         case SpvOpTypeBool:
1067         case SpvOpTypeVector:
1068         case SpvOpTypeMatrix:
1069         case SpvOpTypeArray:
1070         case SpvOpTypePointer:
1071         case SpvOpTypeFunction:
1072         case SpvOpTypeRuntimeArray:
1073         case SpvOpTypeStruct:
1074         case SpvOpTypeImage:
1075         case SpvOpTypeSampledImage:
1076         case SpvOpTypeSampler:
1077         case SpvOpVariable:
1078         case SpvOpFunction:
1079         case SpvOpFunctionParameter:
1080         case SpvOpFunctionEnd:
1081         case SpvOpExecutionMode:
1082         case SpvOpMemoryModel:
1083         case SpvOpCapability:
1084         case SpvOpExtInstImport:
1085         case SpvOpEntryPoint:
1086         case SpvOpSource:
1087         case SpvOpSourceExtension:
1088         case SpvOpName:
1089         case SpvOpMemberName:
1090         case SpvOpDecorate:
1091         case SpvOpMemberDecorate:
1092 #ifdef SKSL_EXT
1093         case SpvOpExtension:
1094         case SpvOpSpecConstant:
1095         case SpvOpSpecConstantOp:
1096 #endif
1097             return true;
1098         default:
1099             return false;
1100     }
1101 }
1102 
writeOpCode(SpvOp_ opCode,int length,OutputStream & out)1103 void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) {
1104     SkASSERT(opCode != SpvOpLoad || &out != &fConstantBuffer);
1105     SkASSERT(opCode != SpvOpUndef);
1106     bool foundDeadCode = false;
1107     if (is_control_flow_op(opCode)) {
1108         // This instruction causes us to leave the current block.
1109         foundDeadCode = (fCurrentBlock == 0);
1110         fCurrentBlock = 0;
1111     } else if (!is_globally_reachable_op(opCode)) {
1112         foundDeadCode = (fCurrentBlock == 0);
1113     }
1114 
1115     if (foundDeadCode) {
1116         // We just encountered dead code--an instruction that don't have an associated block.
1117         // Synthesize a label if this happens; this is necessary to satisfy the validator.
1118         this->writeLabel(this->nextId(nullptr), kBranchlessBlock, out);
1119     }
1120 
1121     this->writeWord((length << 16) | opCode, out);
1122 }
1123 
writeLabel(SpvId label,StraightLineLabelType,OutputStream & out)1124 void SPIRVCodeGenerator::writeLabel(SpvId label, StraightLineLabelType, OutputStream& out) {
1125     // The straight-line label type is not important; in any case, no caches are invalidated.
1126     SkASSERT(!fCurrentBlock);
1127     fCurrentBlock = label;
1128     this->writeInstruction(SpvOpLabel, label, out);
1129 }
1130 
writeLabel(SpvId label,BranchingLabelType type,ConditionalOpCounts ops,OutputStream & out)1131 void SPIRVCodeGenerator::writeLabel(SpvId label, BranchingLabelType type,
1132                                     ConditionalOpCounts ops, OutputStream& out) {
1133     switch (type) {
1134         case kBranchIsBelow:
1135         case kBranchesOnBothSides:
1136             // With a backward or bidirectional branch, we haven't seen the code between the label
1137             // and the branch yet, so any stored value is potentially suspect. Without scanning
1138             // ahead to check, the only safe option is to ditch the store cache entirely.
1139             fStoreCache.reset();
1140             [[fallthrough]];
1141 
1142         case kBranchIsAbove:
1143             // With a forward branch, we can rely on stores that we had cached at the start of the
1144             // statement/expression, if they haven't been touched yet. Anything newer than that is
1145             // pruned.
1146             this->pruneConditionalOps(ops);
1147             break;
1148     }
1149 
1150     // Emit the label.
1151     this->writeLabel(label, kBranchlessBlock, out);
1152 }
1153 
writeInstruction(SpvOp_ opCode,OutputStream & out)1154 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) {
1155     this->writeOpCode(opCode, 1, out);
1156 }
1157 
writeInstruction(SpvOp_ opCode,int32_t word1,OutputStream & out)1158 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) {
1159     this->writeOpCode(opCode, 2, out);
1160     this->writeWord(word1, out);
1161 }
1162 
writeString(std::string_view s,OutputStream & out)1163 void SPIRVCodeGenerator::writeString(std::string_view s, OutputStream& out) {
1164     out.write(s.data(), s.length());
1165     switch (s.length() % 4) {
1166         case 1:
1167             out.write8(0);
1168             [[fallthrough]];
1169         case 2:
1170             out.write8(0);
1171             [[fallthrough]];
1172         case 3:
1173             out.write8(0);
1174             break;
1175         default:
1176             this->writeWord(0, out);
1177             break;
1178     }
1179 }
1180 
writeInstruction(SpvOp_ opCode,std::string_view string,OutputStream & out)1181 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, std::string_view string,
1182                                           OutputStream& out) {
1183     this->writeOpCode(opCode, 1 + (string.length() + 4) / 4, out);
1184     this->writeString(string, out);
1185 }
1186 
writeInstruction(SpvOp_ opCode,int32_t word1,std::string_view string,OutputStream & out)1187 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, std::string_view string,
1188                                           OutputStream& out) {
1189     this->writeOpCode(opCode, 2 + (string.length() + 4) / 4, out);
1190     this->writeWord(word1, out);
1191     this->writeString(string, out);
1192 }
1193 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,std::string_view string,OutputStream & out)1194 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1195                                           std::string_view string, OutputStream& out) {
1196     this->writeOpCode(opCode, 3 + (string.length() + 4) / 4, out);
1197     this->writeWord(word1, out);
1198     this->writeWord(word2, out);
1199     this->writeString(string, out);
1200 }
1201 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,OutputStream & out)1202 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1203                                           OutputStream& out) {
1204     this->writeOpCode(opCode, 3, out);
1205     this->writeWord(word1, out);
1206     this->writeWord(word2, out);
1207 }
1208 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,OutputStream & out)1209 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1210                                           int32_t word3, OutputStream& out) {
1211     this->writeOpCode(opCode, 4, out);
1212     this->writeWord(word1, out);
1213     this->writeWord(word2, out);
1214     this->writeWord(word3, out);
1215 }
1216 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,OutputStream & out)1217 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1218                                           int32_t word3, int32_t word4, OutputStream& out) {
1219     this->writeOpCode(opCode, 5, out);
1220     this->writeWord(word1, out);
1221     this->writeWord(word2, out);
1222     this->writeWord(word3, out);
1223     this->writeWord(word4, out);
1224 }
1225 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,OutputStream & out)1226 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1227                                           int32_t word3, int32_t word4, int32_t word5,
1228                                           OutputStream& out) {
1229     this->writeOpCode(opCode, 6, out);
1230     this->writeWord(word1, out);
1231     this->writeWord(word2, out);
1232     this->writeWord(word3, out);
1233     this->writeWord(word4, out);
1234     this->writeWord(word5, out);
1235 }
1236 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,OutputStream & out)1237 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1238                                           int32_t word3, int32_t word4, int32_t word5,
1239                                           int32_t word6, OutputStream& out) {
1240     this->writeOpCode(opCode, 7, out);
1241     this->writeWord(word1, out);
1242     this->writeWord(word2, out);
1243     this->writeWord(word3, out);
1244     this->writeWord(word4, out);
1245     this->writeWord(word5, out);
1246     this->writeWord(word6, out);
1247 }
1248 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,OutputStream & out)1249 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1250                                           int32_t word3, int32_t word4, int32_t word5,
1251                                           int32_t word6, int32_t word7, OutputStream& out) {
1252     this->writeOpCode(opCode, 8, out);
1253     this->writeWord(word1, out);
1254     this->writeWord(word2, out);
1255     this->writeWord(word3, out);
1256     this->writeWord(word4, out);
1257     this->writeWord(word5, out);
1258     this->writeWord(word6, out);
1259     this->writeWord(word7, out);
1260 }
1261 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,int32_t word8,OutputStream & out)1262 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1263                                           int32_t word3, int32_t word4, int32_t word5,
1264                                           int32_t word6, int32_t word7, int32_t word8,
1265                                           OutputStream& out) {
1266     this->writeOpCode(opCode, 9, out);
1267     this->writeWord(word1, out);
1268     this->writeWord(word2, out);
1269     this->writeWord(word3, out);
1270     this->writeWord(word4, out);
1271     this->writeWord(word5, out);
1272     this->writeWord(word6, out);
1273     this->writeWord(word7, out);
1274     this->writeWord(word8, out);
1275 }
1276 
BuildInstructionKey(SpvOp_ opCode,const TArray<Word> & words)1277 SPIRVCodeGenerator::Instruction SPIRVCodeGenerator::BuildInstructionKey(SpvOp_ opCode,
1278                                                                         const TArray<Word>& words) {
1279     // Assemble a cache key for this instruction.
1280     Instruction key;
1281     key.fOp = opCode;
1282     key.fWords.resize(words.size());
1283     key.fResultKind = Word::Kind::kNone;
1284 
1285     for (int index = 0; index < words.size(); ++index) {
1286         const Word& word = words[index];
1287         key.fWords[index] = word.fValue;
1288         if (word.isResult()) {
1289             SkASSERT(key.fResultKind == Word::Kind::kNone);
1290             key.fResultKind = word.fKind;
1291         }
1292     }
1293 
1294     return key;
1295 }
1296 
// Writes an instruction whose operands are described by `words`, deduplicating it through the
// op cache where possible.
//
// Returns the instruction's result ID, or NA when it has no result slot. If an identical
// instruction (same opcode, operand values and result kind) is already in fOpCache, nothing is
// written and the cached ID is returned. Otherwise the result slot's kind decides caching:
//   - kNone:             no result. The instruction is cached (mapped to NA), so an identical
//                        later instruction is skipped entirely.
//   - kUniqueResult:     a fresh ID every time; recorded only in fSpvIdCache, never in fOpCache.
//   - kDefaultPrecisionResult / kRelaxedPrecisionResult / kKeyedResult:
//                        a fresh ID (relaxed precision for kRelaxedPrecisionResult), recorded
//                        in both caches. Unless the op is globally reachable, the ID is also
//                        appended to fReachableOps so it can be pruned when control flow leaves
//                        the current block.
// OpLoad and OpStore must not come through here.
SpvId SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode,
                                           const TArray<Word>& words,
                                           OutputStream& out) {
    // writeOpLoad and writeOpStore have dedicated code.
    SkASSERT(opCode != SpvOpLoad);
    SkASSERT(opCode != SpvOpStore);

    // If this instruction exists in our op cache, return the cached SpvId.
    Instruction key = BuildInstructionKey(opCode, words);
    if (SpvId* cachedOp = fOpCache.find(key)) {
        return *cachedOp;
    }

    SpvId result = NA;
    Precision precision = Precision::kDefault;

    switch (key.fResultKind) {
        case Word::Kind::kUniqueResult:
            // The instruction returns a SpvId, but we do not want deduplication.
            result = this->nextId(Precision::kDefault);
            fSpvIdCache.set(result, key);
            break;

        case Word::Kind::kNone:
            // The instruction doesn't return a SpvId, but we can still cache and deduplicate it.
            fOpCache.set(key, result);
            break;

        case Word::Kind::kRelaxedPrecisionResult:
            precision = Precision::kRelaxed;
            [[fallthrough]];

        case Word::Kind::kKeyedResult:
            // A keyed result's key value is part of `key`, so only instructions sharing the
            // same key can be deduplicated against each other.
            [[fallthrough]];

        case Word::Kind::kDefaultPrecisionResult:
            // Consume a new SpvId.
            result = this->nextId(precision);
            fOpCache.set(key, result);
            fSpvIdCache.set(result, key);

            // Globally-reachable ops are not subject to the whims of flow control.
            if (!is_globally_reachable_op(opCode)) {
                fReachableOps.push_back(result);
            }
            break;

        default:
            SkDEBUGFAIL("unexpected result kind");
            break;
    }

    // Write the requested instruction. Note that the result ID was allocated above, before
    // writeOpCode runs; writeOpCode may then synthesize a label (consuming another ID) if no
    // block is open.
    this->writeOpCode(opCode, words.size() + 1, out);
    for (const Word& word : words) {
        if (word.isResult()) {
            // Substitute the real result ID for the placeholder slot.
            SkASSERT(result != NA);
            this->writeWord(result, out);
        } else {
            this->writeWord(word.fValue, out);
        }
    }

    // Return the result.
    return result;
}
1363 
writeOpLoad(SpvId type,Precision precision,SpvId pointer,OutputStream & out)1364 SpvId SPIRVCodeGenerator::writeOpLoad(SpvId type,
1365                                       Precision precision,
1366                                       SpvId pointer,
1367                                       OutputStream& out) {
1368     // Look for this pointer in our load-cache.
1369     if (SpvId* cachedOp = fStoreCache.find(pointer)) {
1370         return *cachedOp;
1371     }
1372 
1373     // Write the requested OpLoad instruction.
1374 #ifdef SKSL_EXT
1375     SpvId result = -1;
1376     if (fNonUniformSpvId.find(pointer) != fNonUniformSpvId.end()) {
1377         result = this->nextId(nullptr);
1378         this->writeInstruction(SpvOpDecorate, result, SpvDecorationNonUniform, fDecorationBuffer);
1379     } else {
1380         result = this->nextId(precision);
1381     }
1382 #else
1383     SpvId result = this->nextId(precision);
1384 #endif
1385     this->writeInstruction(SpvOpLoad, type, result, pointer, out);
1386     return result;
1387 }
1388 
#ifdef SKSL_EXT
// Emits one OpExtension instruction per entry in fExtensions.
void SPIRVCodeGenerator::writeExtensions(OutputStream& out) {
    for (const auto& extensionName : fExtensions) {
        this->writeInstruction(SpvOpExtension, extensionName, out);
    }
}
#endif
1396 
writeOpStore(StorageClass storageClass,SpvId pointer,SpvId value,OutputStream & out)1397 void SPIRVCodeGenerator::writeOpStore(StorageClass storageClass,
1398                                       SpvId pointer,
1399                                       SpvId value,
1400                                       OutputStream& out) {
1401     // Write the uncached SpvOpStore directly.
1402     this->writeInstruction(SpvOpStore, pointer, value, out);
1403 
1404     if (storageClass == StorageClass::kFunction) {
1405         // Insert a pointer-to-SpvId mapping into the load cache. A writeOpLoad to this pointer will
1406         // return the cached value as-is.
1407         fStoreCache.set(pointer, value);
1408         fStoreOps.push_back(pointer);
1409     }
1410 }
1411 
writeOpConstantTrue(const Type & type)1412 SpvId SPIRVCodeGenerator::writeOpConstantTrue(const Type& type) {
1413     return this->writeInstruction(SpvOpConstantTrue,
1414                                   Words{this->getType(type), Word::Result()},
1415                                   fConstantBuffer);
1416 }
1417 
writeOpConstantFalse(const Type & type)1418 SpvId SPIRVCodeGenerator::writeOpConstantFalse(const Type& type) {
1419     return this->writeInstruction(SpvOpConstantFalse,
1420                                   Words{this->getType(type), Word::Result()},
1421                                   fConstantBuffer);
1422 }
1423 
writeOpConstant(const Type & type,int32_t valueBits)1424 SpvId SPIRVCodeGenerator::writeOpConstant(const Type& type, int32_t valueBits) {
1425     return this->writeInstruction(
1426             SpvOpConstant,
1427             Words{this->getType(type), Word::Result(), Word::Number(valueBits)},
1428             fConstantBuffer);
1429 }
1430 
writeOpConstantComposite(const Type & type,const TArray<SpvId> & values)1431 SpvId SPIRVCodeGenerator::writeOpConstantComposite(const Type& type,
1432                                                    const TArray<SpvId>& values) {
1433     SkASSERT(values.size() == (type.isStruct() ? SkToInt(type.fields().size()) : type.columns()));
1434 
1435     Words words;
1436     words.push_back(this->getType(type));
1437     words.push_back(Word::Result());
1438     for (SpvId value : values) {
1439         words.push_back(value);
1440     }
1441     return this->writeInstruction(SpvOpConstantComposite, words, fConstantBuffer);
1442 }
1443 
toConstants(SpvId value,TArray<SpvId> * constants)1444 bool SPIRVCodeGenerator::toConstants(SpvId value, TArray<SpvId>* constants) {
1445     Instruction* instr = fSpvIdCache.find(value);
1446     if (!instr) {
1447         return false;
1448     }
1449     switch (instr->fOp) {
1450         case SpvOpConstant:
1451         case SpvOpConstantTrue:
1452         case SpvOpConstantFalse:
1453             constants->push_back(value);
1454             return true;
1455 
1456         case SpvOpConstantComposite: // OpConstantComposite ResultType ResultID Constituents...
1457             // Start at word 2 to skip past ResultType and ResultID.
1458             for (int i = 2; i < instr->fWords.size(); ++i) {
1459                 if (!this->toConstants(instr->fWords[i], constants)) {
1460                     return false;
1461                 }
1462             }
1463             return true;
1464 
1465         default:
1466             return false;
1467     }
1468 }
1469 
toConstants(SkSpan<const SpvId> values,TArray<SpvId> * constants)1470 bool SPIRVCodeGenerator::toConstants(SkSpan<const SpvId> values, TArray<SpvId>* constants) {
1471     for (SpvId value : values) {
1472         if (!this->toConstants(value, constants)) {
1473             return false;
1474         }
1475     }
1476     return true;
1477 }
1478 
writeOpCompositeConstruct(const Type & type,const TArray<SpvId> & values,OutputStream & out)1479 SpvId SPIRVCodeGenerator::writeOpCompositeConstruct(const Type& type,
1480                                                     const TArray<SpvId>& values,
1481                                                     OutputStream& out) {
1482     // If this is a vector composed entirely of literals, write a constant-composite instead.
1483     if (type.isVector()) {
1484         STArray<4, SpvId> constants;
1485         if (this->toConstants(SkSpan(values), &constants)) {
1486             // Create a vector from literals.
1487             return this->writeOpConstantComposite(type, constants);
1488         }
1489     }
1490 
1491     // If this is a matrix composed entirely of literals, constant-composite them instead.
1492     if (type.isMatrix()) {
1493         STArray<16, SpvId> constants;
1494         if (this->toConstants(SkSpan(values), &constants)) {
1495             // Create each matrix column.
1496             SkASSERT(type.isMatrix());
1497             const Type& vecType = type.columnType(fContext);
1498             STArray<4, SpvId> columnIDs;
1499             for (int index=0; index < type.columns(); ++index) {
1500                 STArray<4, SpvId> columnConstants(&constants[index * type.rows()],
1501                                                     type.rows());
1502                 columnIDs.push_back(this->writeOpConstantComposite(vecType, columnConstants));
1503             }
1504             // Compose the matrix from its columns.
1505             return this->writeOpConstantComposite(type, columnIDs);
1506         }
1507     }
1508 
1509     Words words;
1510     words.push_back(this->getType(type));
1511     words.push_back(Word::Result(type));
1512     for (SpvId value : values) {
1513         words.push_back(value);
1514     }
1515 #ifdef SKSL_EXT
1516     return this->writeInstruction(fEmittingGlobalConstConstructor ?
1517         SpvOpConstantComposite : SpvOpCompositeConstruct, words, out);
1518 #else
1519     return this->writeInstruction(SpvOpCompositeConstruct, words, out);
1520 #endif
1521 }
1522 
resultTypeForInstruction(const Instruction & instr)1523 SPIRVCodeGenerator::Instruction* SPIRVCodeGenerator::resultTypeForInstruction(
1524         const Instruction& instr) {
1525     // This list should contain every op that we cache that has a result and result-type.
1526     // (If one is missing, we will not find some optimization opportunities.)
1527     // Generally, the result type of an op is in the 0th word, but I'm not sure if this is
1528     // universally true, so it's configurable on a per-op basis.
1529     int resultTypeWord;
1530     switch (instr.fOp) {
1531         case SpvOpConstant:
1532         case SpvOpConstantTrue:
1533         case SpvOpConstantFalse:
1534         case SpvOpConstantComposite:
1535         case SpvOpCompositeConstruct:
1536         case SpvOpCompositeExtract:
1537         case SpvOpLoad:
1538             resultTypeWord = 0;
1539             break;
1540 
1541         default:
1542             return nullptr;
1543     }
1544 
1545     Instruction* typeInstr = fSpvIdCache.find(instr.fWords[resultTypeWord]);
1546     SkASSERT(typeInstr);
1547     return typeInstr;
1548 }
1549 
numComponentsForVecInstruction(const Instruction & instr)1550 int SPIRVCodeGenerator::numComponentsForVecInstruction(const Instruction& instr) {
1551     // If an instruction is in the op cache, its type should be as well.
1552     Instruction* typeInstr = this->resultTypeForInstruction(instr);
1553     SkASSERT(typeInstr);
1554     SkASSERT(typeInstr->fOp == SpvOpTypeVector || typeInstr->fOp == SpvOpTypeFloat ||
1555              typeInstr->fOp == SpvOpTypeInt || typeInstr->fOp == SpvOpTypeBool);
1556 
1557     // For vectors, extract their column count. Scalars have one component by definition.
1558     //   SpvOpTypeVector ResultID ComponentType NumComponents
1559     return (typeInstr->fOp == SpvOpTypeVector) ? typeInstr->fWords[2]
1560                                                : 1;
1561 }
1562 
toComponent(SpvId id,int component)1563 SpvId SPIRVCodeGenerator::toComponent(SpvId id, int component) {
1564     Instruction* instr = fSpvIdCache.find(id);
1565     if (!instr) {
1566         return NA;
1567     }
1568     if (instr->fOp == SpvOpConstantComposite) {
1569         // SpvOpConstantComposite ResultType ResultID [components...]
1570         // Add 2 to the component index to skip past ResultType and ResultID.
1571         return instr->fWords[2 + component];
1572     }
1573     if (instr->fOp == SpvOpCompositeConstruct) {
1574         // SpvOpCompositeConstruct ResultType ResultID [components...]
1575         // Vectors have special rules; check to see if we are composing a vector.
1576         Instruction* composedType = fSpvIdCache.find(instr->fWords[0]);
1577         SkASSERT(composedType);
1578 
1579         // When composing a non-vector, each instruction word maps 1:1 to the component index.
1580         // We can just extract out the associated component directly.
1581         if (composedType->fOp != SpvOpTypeVector) {
1582             return instr->fWords[2 + component];
1583         }
1584 
1585         // When composing a vector, components can be either scalars or vectors.
1586         // This means we need to check the op type on each component. (+2 to skip ResultType/Result)
1587         for (int index = 2; index < instr->fWords.size(); ++index) {
1588             int32_t currentWord = instr->fWords[index];
1589 
1590             // Retrieve the sub-instruction pointed to by OpCompositeConstruct.
1591             Instruction* subinstr = fSpvIdCache.find(currentWord);
1592             if (!subinstr) {
1593                 return NA;
1594             }
1595             // If this subinstruction contains the component we're looking for...
1596             int numComponents = this->numComponentsForVecInstruction(*subinstr);
1597             if (component < numComponents) {
1598                 if (numComponents == 1) {
1599                     // ... it's a scalar. Return it.
1600                     SkASSERT(component == 0);
1601                     return currentWord;
1602                 } else {
1603                     // ... it's a vector. Recurse into it.
1604                     return this->toComponent(currentWord, component);
1605                 }
1606             }
1607             // This sub-instruction doesn't contain our component. Keep walking forward.
1608             component -= numComponents;
1609         }
1610         SkDEBUGFAIL("component index goes past the end of this composite value");
1611         return NA;
1612     }
1613     return NA;
1614 }
1615 
writeOpCompositeExtract(const Type & type,SpvId base,int component,OutputStream & out)1616 SpvId SPIRVCodeGenerator::writeOpCompositeExtract(const Type& type,
1617                                                   SpvId base,
1618                                                   int component,
1619                                                   OutputStream& out) {
1620     // If the base op is a composite, we can extract from it directly.
1621     SpvId result = this->toComponent(base, component);
1622     if (result != NA) {
1623         return result;
1624     }
1625     return this->writeInstruction(
1626             SpvOpCompositeExtract,
1627             {this->getType(type), Word::Result(type), base, Word::Number(component)},
1628             out);
1629 }
1630 
writeOpCompositeExtract(const Type & type,SpvId base,int componentA,int componentB,OutputStream & out)1631 SpvId SPIRVCodeGenerator::writeOpCompositeExtract(const Type& type,
1632                                                   SpvId base,
1633                                                   int componentA,
1634                                                   int componentB,
1635                                                   OutputStream& out) {
1636     // If the base op is a composite, we can extract from it directly.
1637     SpvId result = this->toComponent(base, componentA);
1638     if (result != NA) {
1639         return this->writeOpCompositeExtract(type, result, componentB, out);
1640     }
1641     return this->writeInstruction(SpvOpCompositeExtract,
1642                                   {this->getType(type),
1643                                    Word::Result(type),
1644                                    base,
1645                                    Word::Number(componentA),
1646                                    Word::Number(componentB)},
1647                                   out);
1648 }
1649 
writeCapabilities(OutputStream & out)1650 void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) {
1651     for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
1652         if (fCapabilities & bit) {
1653             this->writeInstruction(SpvOpCapability, (SpvId) i, out);
1654         }
1655     }
1656     this->writeInstruction(SpvOpCapability, SpvCapabilityShader, out);
1657 #ifdef SKSL_EXT
1658     for (auto i : fCapabilitiesExt) {
1659         this->writeInstruction(SpvOpCapability, (SpvId) i, out);
1660     }
1661 #endif
1662 }
1663 
nextId(const Type * type)1664 SpvId SPIRVCodeGenerator::nextId(const Type* type) {
1665     return this->nextId(type && type->hasPrecision() && !type->highPrecision()
1666                 ? Precision::kRelaxed
1667                 : Precision::kDefault);
1668 }
1669 
nextId(Precision precision)1670 SpvId SPIRVCodeGenerator::nextId(Precision precision) {
1671     if (precision == Precision::kRelaxed && !fProgram.fConfig->fSettings.fForceHighPrecision) {
1672         this->writeInstruction(SpvOpDecorate, fIdCount, SpvDecorationRelaxedPrecision,
1673                                fDecorationBuffer);
1674     }
1675     return fIdCount++;
1676 }
1677 
writeStruct(const Type & type,const MemoryLayout & memoryLayout)1678 SpvId SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout) {
1679     // If we've already written out this struct, return its existing SpvId.
1680     if (SpvId* cachedStructId = fStructMap.find(&type)) {
1681         return *cachedStructId;
1682     }
1683 
1684     // Write all of the field types first, so we don't inadvertently write them while we're in the
1685     // middle of writing the struct instruction.
1686     Words words;
1687     words.push_back(Word::UniqueResult());
1688     for (const auto& f : type.fields()) {
1689         words.push_back(this->getType(*f.fType, f.fLayout, memoryLayout));
1690     }
1691     SpvId resultId = this->writeInstruction(SpvOpTypeStruct, words, fConstantBuffer);
1692     this->writeInstruction(SpvOpName, resultId, type.name(), fNameBuffer);
1693     fStructMap.set(&type, resultId);
1694 
1695     size_t offset = 0;
1696     for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
1697         const Field& field = type.fields()[i];
1698         if (!memoryLayout.isSupported(*field.fType)) {
1699             fContext.fErrors->error(type.fPosition, "type '" + field.fType->displayName() +
1700                                                     "' is not permitted here");
1701             return resultId;
1702         }
1703         size_t size = memoryLayout.size(*field.fType);
1704         size_t alignment = memoryLayout.alignment(*field.fType);
1705         const Layout& fieldLayout = field.fLayout;
1706         if (fieldLayout.fOffset >= 0) {
1707             if (fieldLayout.fOffset < (int) offset) {
1708                 fContext.fErrors->error(field.fPosition, "offset of field '" +
1709                         std::string(field.fName) + "' must be at least " + std::to_string(offset));
1710             }
1711             if (fieldLayout.fOffset % alignment) {
1712                 fContext.fErrors->error(field.fPosition,
1713                                         "offset of field '" + std::string(field.fName) +
1714                                         "' must be a multiple of " + std::to_string(alignment));
1715             }
1716             offset = fieldLayout.fOffset;
1717         } else {
1718             size_t mod = offset % alignment;
1719             if (mod) {
1720                 offset += alignment - mod;
1721             }
1722         }
1723         this->writeInstruction(SpvOpMemberName, resultId, i, field.fName, fNameBuffer);
1724         this->writeFieldLayout(fieldLayout, resultId, i);
1725         if (field.fLayout.fBuiltin < 0) {
1726             this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
1727                                    (SpvId) offset, fDecorationBuffer);
1728         }
1729         if (field.fType->isMatrix()) {
1730             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
1731                                    fDecorationBuffer);
1732             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
1733                                    (SpvId) memoryLayout.stride(*field.fType),
1734                                    fDecorationBuffer);
1735         }
1736 #ifdef SKSL_EXT
1737         if (field.fType->isArray()) {
1738             if (field.fType->componentType().isMatrix()) {
1739                 this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
1740                                        fDecorationBuffer);
1741                 this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
1742                                        (SpvId) memoryLayout.stride(field.fType->componentType()),
1743                                        fDecorationBuffer);
1744             }
1745         }
1746 #endif
1747         if (!field.fType->highPrecision()) {
1748             this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i,
1749                                    SpvDecorationRelaxedPrecision, fDecorationBuffer);
1750         }
1751         offset += size;
1752         if ((field.fType->isArray() || field.fType->isStruct()) && offset % alignment != 0) {
1753             offset += alignment - offset % alignment;
1754         }
1755     }
1756 
1757     return resultId;
1758 }
1759 
getType(const Type & type)1760 SpvId SPIRVCodeGenerator::getType(const Type& type) {
1761     return this->getType(type, kDefaultTypeLayout, fDefaultMemoryLayout);
1762 }
1763 
layout_flags_to_image_format(LayoutFlags flags)1764 static SpvImageFormat layout_flags_to_image_format(LayoutFlags flags) {
1765     flags &= LayoutFlag::kAllPixelFormats;
1766     switch (flags.value()) {
1767         case (int)LayoutFlag::kRGBA8:
1768             return SpvImageFormatRgba8;
1769 
1770         case (int)LayoutFlag::kRGBA32F:
1771             return SpvImageFormatRgba32f;
1772 
1773         case (int)LayoutFlag::kR32F:
1774             return SpvImageFormatR32f;
1775 
1776         default:
1777             return SpvImageFormatUnknown;
1778     }
1779 
1780     SkUNREACHABLE;
1781 }
1782 
getType(const Type & rawType,const Layout & typeLayout,const MemoryLayout & memoryLayout)1783 SpvId SPIRVCodeGenerator::getType(const Type& rawType,
1784                                   const Layout& typeLayout,
1785                                   const MemoryLayout& memoryLayout) {
1786     const Type* type = &rawType;
1787 
1788     switch (type->typeKind()) {
1789         case Type::TypeKind::kVoid: {
1790             return this->writeInstruction(SpvOpTypeVoid, Words{Word::Result()}, fConstantBuffer);
1791         }
1792         case Type::TypeKind::kScalar:
1793         case Type::TypeKind::kLiteral: {
1794             if (type->isBoolean()) {
1795                 return this->writeInstruction(SpvOpTypeBool, {Word::Result()}, fConstantBuffer);
1796             }
1797             if (type->isSigned()) {
1798                 return this->writeInstruction(
1799                         SpvOpTypeInt,
1800                         Words{Word::Result(), Word::Number(32), Word::Number(1)},
1801                         fConstantBuffer);
1802             }
1803             if (type->isUnsigned()) {
1804                 return this->writeInstruction(
1805                         SpvOpTypeInt,
1806                         Words{Word::Result(), Word::Number(32), Word::Number(0)},
1807                         fConstantBuffer);
1808             }
1809             if (type->isFloat()) {
1810                 return this->writeInstruction(
1811                         SpvOpTypeFloat,
1812                         Words{Word::Result(), Word::Number(32)},
1813                         fConstantBuffer);
1814             }
1815             SkDEBUGFAILF("unrecognized scalar type '%s'", type->description().c_str());
1816             return NA;
1817         }
1818         case Type::TypeKind::kVector: {
1819             SpvId scalarTypeId = this->getType(type->componentType(), typeLayout, memoryLayout);
1820             return this->writeInstruction(
1821                     SpvOpTypeVector,
1822                     Words{Word::Result(), scalarTypeId, Word::Number(type->columns())},
1823                     fConstantBuffer);
1824         }
1825         case Type::TypeKind::kMatrix: {
1826             SpvId vectorTypeId = this->getType(IndexExpression::IndexType(fContext, *type),
1827                                                typeLayout,
1828                                                memoryLayout);
1829             return this->writeInstruction(
1830                     SpvOpTypeMatrix,
1831                     Words{Word::Result(), vectorTypeId, Word::Number(type->columns())},
1832                     fConstantBuffer);
1833         }
1834         case Type::TypeKind::kArray: {
1835             const MemoryLayout arrayMemoryLayout =
1836                                     fCaps.fForceStd430ArrayLayout
1837                                         ? MemoryLayout(MemoryLayout::Standard::k430)
1838                                         : memoryLayout;
1839 
1840             if (!arrayMemoryLayout.isSupported(*type)) {
1841                 fContext.fErrors->error(type->fPosition, "type '" + type->displayName() +
1842                                                          "' is not permitted here");
1843                 return NA;
1844             }
1845             size_t stride = arrayMemoryLayout.stride(*type);
1846             SpvId typeId = this->getType(type->componentType(), typeLayout, arrayMemoryLayout);
1847             SpvId result = NA;
1848             if (type->isUnsizedArray()) {
1849 #ifdef SKSL_EXT
1850             if (type->componentType().isSampler()) {
1851                 fCapabilitiesExt.insert(SpvCapabilityRuntimeDescriptorArray);
1852                 fExtensions.insert("SPV_EXT_descriptor_indexing");
1853             }
1854 #endif
1855                 result = this->writeInstruction(SpvOpTypeRuntimeArray,
1856                                                 Words{Word::KeyedResult(stride), typeId},
1857                                                 fConstantBuffer);
1858             } else {
1859                 SpvId countId = this->writeLiteral(type->columns(), *fContext.fTypes.fInt);
1860                 result = this->writeInstruction(SpvOpTypeArray,
1861                                                 Words{Word::KeyedResult(stride), typeId, countId},
1862                                                 fConstantBuffer);
1863             }
1864             this->writeInstruction(SpvOpDecorate,
1865                                    {result, SpvDecorationArrayStride, Word::Number(stride)},
1866                                    fDecorationBuffer);
1867             return result;
1868         }
1869         case Type::TypeKind::kStruct: {
1870             return this->writeStruct(*type, memoryLayout);
1871         }
1872         case Type::TypeKind::kSeparateSampler: {
1873             return this->writeInstruction(SpvOpTypeSampler, Words{Word::Result()}, fConstantBuffer);
1874         }
1875         case Type::TypeKind::kSampler: {
1876             if (SpvDimBuffer == type->dimensions()) {
1877                 fCapabilities |= 1ULL << SpvCapabilitySampledBuffer;
1878             }
1879             SpvId imageTypeId = this->getType(type->textureType(), typeLayout, memoryLayout);
1880             return this->writeInstruction(SpvOpTypeSampledImage,
1881                                           Words{Word::Result(), imageTypeId},
1882                                           fConstantBuffer);
1883         }
1884         case Type::TypeKind::kTexture: {
1885             SpvId floatTypeId = this->getType(*fContext.fTypes.fFloat,
1886                                               kDefaultTypeLayout,
1887                                               memoryLayout);
1888 
1889             bool sampled = (type->textureAccess() == Type::TextureAccess::kSample);
1890             SpvImageFormat format = (!sampled && type->dimensions() != SpvDimSubpassData)
1891                                             ? layout_flags_to_image_format(typeLayout.fFlags)
1892                                             : SpvImageFormatUnknown;
1893 
1894             return this->writeInstruction(SpvOpTypeImage,
1895                                           Words{Word::Result(),
1896                                                 floatTypeId,
1897                                                 Word::Number(type->dimensions()),
1898                                                 Word::Number(type->isDepth()),
1899                                                 Word::Number(type->isArrayedTexture()),
1900                                                 Word::Number(type->isMultisampled()),
1901                                                 Word::Number(sampled ? 1 : 2),
1902                                                 format},
1903                                           fConstantBuffer);
1904         }
1905         case Type::TypeKind::kAtomic: {
1906             // SkSL currently only supports the atomicUint type.
1907             SkASSERT(type->matches(*fContext.fTypes.fAtomicUInt));
1908             // SPIR-V doesn't have atomic types. Rather, it allows atomic operations on primitive
1909             // types. The SPIR-V type of an SkSL atomic is simply the underlying type.
1910             return this->writeInstruction(SpvOpTypeInt,
1911                                           Words{Word::Result(), Word::Number(32), Word::Number(0)},
1912                                           fConstantBuffer);
1913         }
1914         default: {
1915             SkDEBUGFAILF("invalid type: %s", type->description().c_str());
1916             return NA;
1917         }
1918     }
1919 }
1920 
getFunctionType(const FunctionDeclaration & function)1921 SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
1922     Words words;
1923     words.push_back(Word::Result());
1924     words.push_back(this->getType(function.returnType()));
1925     for (const Variable* parameter : function.parameters()) {
1926         bool paramIsSpecialized = fActiveSpecialization && fActiveSpecialization->find(parameter);
1927         if (fUseTextureSamplerPairs && parameter->type().isSampler()) {
1928             words.push_back(this->getFunctionParameterType(parameter->type().textureType(),
1929                                                            parameter->layout()));
1930             if (!paramIsSpecialized) {
1931                 words.push_back(this->getFunctionParameterType(*fContext.fTypes.fSampler,
1932                                                                kDefaultTypeLayout));
1933             }
1934         } else if (!paramIsSpecialized) {
1935             words.push_back(this->getFunctionParameterType(parameter->type(), parameter->layout()));
1936         }
1937     }
1938     return this->writeInstruction(SpvOpTypeFunction, words, fConstantBuffer);
1939 }
1940 
getFunctionParameterType(const Type & parameterType,const Layout & parameterLayout)1941 SpvId SPIRVCodeGenerator::getFunctionParameterType(const Type& parameterType,
1942                                                    const Layout& parameterLayout) {
1943     // glslang treats all function arguments as pointers whether they need to be or
1944     // not. I was initially puzzled by this until I ran bizarre failures with certain
1945     // patterns of function calls and control constructs, as exemplified by this minimal
1946     // failure case:
1947     //
1948     // void sphere(float x) {
1949     // }
1950     //
1951     // void map() {
1952     //     sphere(1.0);
1953     // }
1954     //
1955     // void main() {
1956     //     for (int i = 0; i < 1; i++) {
1957     //         map();
1958     //     }
1959     // }
1960     //
1961     // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
1962     // crashes. Making it take a float* and storing the argument in a temporary variable,
1963     // as glslang does, fixes it.
1964     //
1965     // The consensus among shader compiler authors seems to be that GPU driver generally don't
1966     // handle value-based parameters consistently. It is highly likely that they fit their
1967     // implementations to conform to glslang. We take care to do so ourselves.
1968     //
1969     // Our implementation first stores every parameter value into a function storage-class pointer
1970     // before calling a function. The exception is for opaque handle types (samplers and textures)
1971     // which must be stored in a pointer with UniformConstant storage-class. This prevents
1972     // unnecessary temporaries (becuase opaque handles are always rooted in a pointer variable),
1973     // matches glslang's behavior, and translates into WGSL more easily when targeting Dawn.
1974     StorageClass storageClass;
1975     if (parameterType.typeKind() == Type::TypeKind::kSampler ||
1976         parameterType.typeKind() == Type::TypeKind::kSeparateSampler ||
1977         parameterType.typeKind() == Type::TypeKind::kTexture) {
1978         storageClass = StorageClass::kUniformConstant;
1979     } else {
1980         storageClass = StorageClass::kFunction;
1981     }
1982     return this->getPointerType(parameterType,
1983                                 parameterLayout,
1984                                 this->memoryLayoutForStorageClass(storageClass),
1985                                 storageClass);
1986 }
1987 
getPointerType(const Type & type,StorageClass storageClass)1988 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, StorageClass storageClass) {
1989     return this->getPointerType(type,
1990                                 kDefaultTypeLayout,
1991                                 this->memoryLayoutForStorageClass(storageClass),
1992                                 storageClass);
1993 }
1994 
getPointerType(const Type & type,const Layout & typeLayout,const MemoryLayout & memoryLayout,StorageClass storageClass)1995 SpvId SPIRVCodeGenerator::getPointerType(const Type& type,
1996                                          const Layout& typeLayout,
1997                                          const MemoryLayout& memoryLayout,
1998                                          StorageClass storageClass) {
1999     return this->writeInstruction(SpvOpTypePointer,
2000                                   Words{Word::Result(),
2001                                         Word::Number(get_storage_class_spv_id(storageClass)),
2002                                         this->getType(type, typeLayout, memoryLayout)},
2003                                   fConstantBuffer);
2004 }
2005 
writeExpression(const Expression & expr,OutputStream & out)2006 SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) {
2007     switch (expr.kind()) {
2008         case Expression::Kind::kBinary:
2009             return this->writeBinaryExpression(expr.as<BinaryExpression>(), out);
2010         case Expression::Kind::kConstructorArrayCast:
2011             return this->writeExpression(*expr.as<ConstructorArrayCast>().argument(), out);
2012         case Expression::Kind::kConstructorArray:
2013         case Expression::Kind::kConstructorStruct:
2014             return this->writeCompositeConstructor(expr.asAnyConstructor(), out);
2015         case Expression::Kind::kConstructorDiagonalMatrix:
2016             return this->writeConstructorDiagonalMatrix(expr.as<ConstructorDiagonalMatrix>(), out);
2017         case Expression::Kind::kConstructorMatrixResize:
2018             return this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(), out);
2019         case Expression::Kind::kConstructorScalarCast:
2020             return this->writeConstructorScalarCast(expr.as<ConstructorScalarCast>(), out);
2021         case Expression::Kind::kConstructorSplat:
2022             return this->writeConstructorSplat(expr.as<ConstructorSplat>(), out);
2023         case Expression::Kind::kConstructorCompound:
2024             return this->writeConstructorCompound(expr.as<ConstructorCompound>(), out);
2025         case Expression::Kind::kConstructorCompoundCast:
2026             return this->writeConstructorCompoundCast(expr.as<ConstructorCompoundCast>(), out);
2027         case Expression::Kind::kEmpty:
2028             return NA;
2029         case Expression::Kind::kFieldAccess:
2030             return this->writeFieldAccess(expr.as<FieldAccess>(), out);
2031         case Expression::Kind::kFunctionCall:
2032             return this->writeFunctionCall(expr.as<FunctionCall>(), out);
2033         case Expression::Kind::kLiteral:
2034             return this->writeLiteral(expr.as<Literal>());
2035         case Expression::Kind::kPrefix:
2036             return this->writePrefixExpression(expr.as<PrefixExpression>(), out);
2037         case Expression::Kind::kPostfix:
2038             return this->writePostfixExpression(expr.as<PostfixExpression>(), out);
2039         case Expression::Kind::kSwizzle:
2040             return this->writeSwizzle(expr.as<Swizzle>(), out);
2041         case Expression::Kind::kVariableReference:
2042             return this->writeVariableReference(expr.as<VariableReference>(), out);
2043         case Expression::Kind::kTernary:
2044             return this->writeTernaryExpression(expr.as<TernaryExpression>(), out);
2045         case Expression::Kind::kIndex:
2046             return this->writeIndexExpression(expr.as<IndexExpression>(), out);
2047         case Expression::Kind::kSetting:
2048             return this->writeExpression(*expr.as<Setting>().toLiteral(fCaps), out);
2049         default:
2050             SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
2051             break;
2052     }
2053     return NA;
2054 }
2055 
writeIntrinsicCall(const FunctionCall & c,OutputStream & out)2056 SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) {
2057     const FunctionDeclaration& function = c.function();
2058     Intrinsic intrinsic = this->getIntrinsic(function.intrinsicKind());
2059     if (intrinsic.opKind == kInvalid_IntrinsicOpcodeKind) {
2060         fContext.fErrors->error(c.fPosition, "unsupported intrinsic '" + function.description() +
2061                 "'");
2062         return NA;
2063     }
2064     const ExpressionArray& arguments = c.arguments();
2065     int32_t intrinsicId = intrinsic.floatOp;
2066     if (!arguments.empty()) {
2067         const Type& type = arguments[0]->type();
2068         if (intrinsic.opKind == kSpecial_IntrinsicOpcodeKind) {
2069             // Keep the default float op.
2070         } else {
2071             intrinsicId = pick_by_type(type, intrinsic.floatOp, intrinsic.signedOp,
2072                                        intrinsic.unsignedOp, intrinsic.boolOp);
2073         }
2074     }
2075     switch (intrinsic.opKind) {
2076         case kGLSL_STD_450_IntrinsicOpcodeKind: {
2077             SpvId result = this->nextId(&c.type());
2078             TArray<SpvId> argumentIds;
2079             argumentIds.reserve_exact(arguments.size());
2080             std::vector<TempVar> tempVars;
2081             for (int i = 0; i < arguments.size(); i++) {
2082                 this->writeFunctionCallArgument(argumentIds, c, i, &tempVars,
2083                                                 /*specializedParams=*/nullptr, out);
2084             }
2085             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) argumentIds.size(), out);
2086             this->writeWord(this->getType(c.type()), out);
2087             this->writeWord(result, out);
2088             this->writeWord(fGLSLExtendedInstructions, out);
2089             this->writeWord(intrinsicId, out);
2090             for (SpvId id : argumentIds) {
2091                 this->writeWord(id, out);
2092             }
2093             this->copyBackTempVars(tempVars, out);
2094             return result;
2095         }
2096         case kSPIRV_IntrinsicOpcodeKind: {
2097             // GLSL supports dot(float, float), but SPIR-V does not. Convert it to FMul
2098             if (intrinsicId == SpvOpDot && arguments[0]->type().isScalar()) {
2099                 intrinsicId = SpvOpFMul;
2100             }
2101             SpvId result = this->nextId(&c.type());
2102             TArray<SpvId> argumentIds;
2103             argumentIds.reserve_exact(arguments.size());
2104             std::vector<TempVar> tempVars;
2105             for (int i = 0; i < arguments.size(); i++) {
2106                 this->writeFunctionCallArgument(argumentIds, c, i, &tempVars,
2107                                                 /*specializedParams=*/nullptr, out);
2108             }
2109             if (!c.type().isVoid()) {
2110                 this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
2111                 this->writeWord(this->getType(c.type()), out);
2112                 this->writeWord(result, out);
2113             } else {
2114                 this->writeOpCode((SpvOp_) intrinsicId, 1 + (int32_t) arguments.size(), out);
2115             }
2116             for (SpvId id : argumentIds) {
2117                 this->writeWord(id, out);
2118             }
2119             this->copyBackTempVars(tempVars, out);
2120             return result;
2121         }
2122         case kSpecial_IntrinsicOpcodeKind:
2123             return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
2124         default:
2125             fContext.fErrors->error(c.fPosition, "unsupported intrinsic '" +
2126                     function.description() + "'");
2127             return NA;
2128     }
2129 }
2130 
vectorize(const Expression & arg,int vectorSize,OutputStream & out)2131 SpvId SPIRVCodeGenerator::vectorize(const Expression& arg, int vectorSize, OutputStream& out) {
2132     SkASSERT(vectorSize >= 1 && vectorSize <= 4);
2133     const Type& argType = arg.type();
2134     if (argType.isScalar() && vectorSize > 1) {
2135         SpvId argID = this->writeExpression(arg, out);
2136         return this->splat(argType.toCompound(fContext, vectorSize, /*rows=*/1), argID, out);
2137     }
2138 
2139     SkASSERT(vectorSize == argType.columns());
2140     return this->writeExpression(arg, out);
2141 }
2142 
vectorize(const ExpressionArray & args,OutputStream & out)2143 TArray<SpvId> SPIRVCodeGenerator::vectorize(const ExpressionArray& args, OutputStream& out) {
2144     int vectorSize = 1;
2145     for (const auto& a : args) {
2146         if (a->type().isVector()) {
2147             if (vectorSize > 1) {
2148                 SkASSERT(a->type().columns() == vectorSize);
2149             } else {
2150                 vectorSize = a->type().columns();
2151             }
2152         }
2153     }
2154     TArray<SpvId> result;
2155     result.reserve_exact(args.size());
2156     for (const auto& arg : args) {
2157         result.push_back(this->vectorize(*arg, vectorSize, out));
2158     }
2159     return result;
2160 }
2161 
writeGLSLExtendedInstruction(const Type & type,SpvId id,SpvId floatInst,SpvId signedInst,SpvId unsignedInst,const TArray<SpvId> & args,OutputStream & out)2162 void SPIRVCodeGenerator::writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
2163                                                       SpvId signedInst, SpvId unsignedInst,
2164                                                       const TArray<SpvId>& args,
2165                                                       OutputStream& out) {
2166     this->writeOpCode(SpvOpExtInst, 5 + args.size(), out);
2167     this->writeWord(this->getType(type), out);
2168     this->writeWord(id, out);
2169     this->writeWord(fGLSLExtendedInstructions, out);
2170     this->writeWord(pick_by_type(type, floatInst, signedInst, unsignedInst, NA), out);
2171     for (SpvId a : args) {
2172         this->writeWord(a, out);
2173     }
2174 }
2175 
writeSpecialIntrinsic(const FunctionCall & c,SpecialIntrinsic kind,OutputStream & out)2176 SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
2177                                                 OutputStream& out) {
2178     const ExpressionArray& arguments = c.arguments();
2179     const Type& callType = c.type();
2180 #ifdef SKSL_EXT
2181     SpvId result = this->nextId(kind == kMix_SpecialIntrinsic ? nullptr : &callType);
2182 #else
2183     SpvId result = this->nextId(nullptr);
2184 #endif
2185     switch (kind) {
2186         case kAtan_SpecialIntrinsic: {
2187             STArray<2, SpvId> argumentIds;
2188             for (const std::unique_ptr<Expression>& arg : arguments) {
2189                 argumentIds.push_back(this->writeExpression(*arg, out));
2190             }
2191             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) argumentIds.size(), out);
2192             this->writeWord(this->getType(callType), out);
2193             this->writeWord(result, out);
2194             this->writeWord(fGLSLExtendedInstructions, out);
2195             this->writeWord(argumentIds.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
2196             for (SpvId id : argumentIds) {
2197                 this->writeWord(id, out);
2198             }
2199             break;
2200         }
2201         case kSampledImage_SpecialIntrinsic: {
2202             SkASSERT(arguments.size() == 2);
2203             SpvId img = this->writeExpression(*arguments[0], out);
2204             SpvId sampler = this->writeExpression(*arguments[1], out);
2205             this->writeInstruction(SpvOpSampledImage,
2206                                    this->getType(callType),
2207                                    result,
2208                                    img,
2209                                    sampler,
2210                                    out);
2211             break;
2212         }
2213         case kSubpassLoad_SpecialIntrinsic: {
2214             SpvId img = this->writeExpression(*arguments[0], out);
2215             ExpressionArray args;
2216             args.reserve_exact(2);
2217             args.push_back(Literal::MakeInt(fContext, Position(), /*value=*/0));
2218             args.push_back(Literal::MakeInt(fContext, Position(), /*value=*/0));
2219             ConstructorCompound ctor(Position(), *fContext.fTypes.fInt2, std::move(args));
2220             SpvId coords = this->writeExpression(ctor, out);
2221             if (arguments.size() == 1) {
2222                 this->writeInstruction(SpvOpImageRead,
2223                                        this->getType(callType),
2224                                        result,
2225                                        img,
2226                                        coords,
2227                                        out);
2228             } else {
2229                 SkASSERT(arguments.size() == 2);
2230                 SpvId sample = this->writeExpression(*arguments[1], out);
2231                 this->writeInstruction(SpvOpImageRead,
2232                                        this->getType(callType),
2233                                        result,
2234                                        img,
2235                                        coords,
2236                                        SpvImageOperandsSampleMask,
2237                                        sample,
2238                                        out);
2239             }
2240             break;
2241         }
2242         case kTexture_SpecialIntrinsic: {
2243             SpvOp_ op = SpvOpImageSampleImplicitLod;
2244             const Type& arg1Type = arguments[1]->type();
2245             switch (arguments[0]->type().dimensions()) {
2246                 case SpvDim1D:
2247                     if (arg1Type.matches(*fContext.fTypes.fFloat2)) {
2248                         op = SpvOpImageSampleProjImplicitLod;
2249                     } else {
2250                         SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat));
2251                     }
2252                     break;
2253                 case SpvDim2D:
2254                     if (arg1Type.matches(*fContext.fTypes.fFloat3)) {
2255                         op = SpvOpImageSampleProjImplicitLod;
2256                     } else {
2257                         SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat2));
2258                     }
2259                     break;
2260                 case SpvDim3D:
2261                     if (arg1Type.matches(*fContext.fTypes.fFloat4)) {
2262                         op = SpvOpImageSampleProjImplicitLod;
2263                     } else {
2264                         SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat3));
2265                     }
2266                     break;
2267                 case SpvDimCube:   // fall through
2268                 case SpvDimRect:   // fall through
2269                 case SpvDimBuffer: // fall through
2270                 case SpvDimSubpassData:
2271                     break;
2272             }
2273             SpvId type = this->getType(callType);
2274             SpvId sampler = this->writeExpression(*arguments[0], out);
2275             SpvId uv = this->writeExpression(*arguments[1], out);
2276             if (arguments.size() == 3) {
2277                 this->writeInstruction(op, type, result, sampler, uv,
2278                                        SpvImageOperandsBiasMask,
2279                                        this->writeExpression(*arguments[2], out),
2280                                        out);
2281             } else {
2282                 SkASSERT(arguments.size() == 2);
2283                 if (fProgram.fConfig->fSettings.fSharpenTextures) {
2284                     SpvId lodBias = this->writeLiteral(kSharpenTexturesBias,
2285                                                        *fContext.fTypes.fFloat);
2286                     this->writeInstruction(op, type, result, sampler, uv,
2287                                            SpvImageOperandsBiasMask, lodBias, out);
2288                 } else {
2289                     this->writeInstruction(op, type, result, sampler, uv,
2290                                            out);
2291                 }
2292             }
2293             break;
2294         }
2295         case kTextureGrad_SpecialIntrinsic: {
2296             SpvOp_ op = SpvOpImageSampleExplicitLod;
2297             SkASSERT(arguments.size() == 4);
2298             SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2299             SkASSERT(arguments[1]->type().matches(*fContext.fTypes.fFloat2));
2300             SkASSERT(arguments[2]->type().matches(*fContext.fTypes.fFloat2));
2301             SkASSERT(arguments[3]->type().matches(*fContext.fTypes.fFloat2));
2302             SpvId type = this->getType(callType);
2303             SpvId sampler = this->writeExpression(*arguments[0], out);
2304             SpvId uv = this->writeExpression(*arguments[1], out);
2305             SpvId dPdx = this->writeExpression(*arguments[2], out);
2306             SpvId dPdy = this->writeExpression(*arguments[3], out);
2307             this->writeInstruction(op, type, result, sampler, uv, SpvImageOperandsGradMask,
2308                                    dPdx, dPdy, out);
2309             break;
2310         }
2311         case kTextureLod_SpecialIntrinsic: {
2312             SpvOp_ op = SpvOpImageSampleExplicitLod;
2313             SkASSERT(arguments.size() == 3);
2314             SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2315             SkASSERT(arguments[2]->type().matches(*fContext.fTypes.fFloat));
2316             const Type& arg1Type = arguments[1]->type();
2317             if (arg1Type.matches(*fContext.fTypes.fFloat3)) {
2318                 op = SpvOpImageSampleProjExplicitLod;
2319             } else {
2320                 SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat2));
2321             }
2322             SpvId type = this->getType(callType);
2323             SpvId sampler = this->writeExpression(*arguments[0], out);
2324             SpvId uv = this->writeExpression(*arguments[1], out);
2325             this->writeInstruction(op, type, result, sampler, uv,
2326                                    SpvImageOperandsLodMask,
2327                                    this->writeExpression(*arguments[2], out),
2328                                    out);
2329             break;
2330         }
2331         case kTextureRead_SpecialIntrinsic: {
2332             SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2333             SkASSERT(arguments[1]->type().matches(*fContext.fTypes.fUInt2));
2334 
2335             SpvId type = this->getType(callType);
2336             SpvId image = this->writeExpression(*arguments[0], out);
2337             SpvId coord = this->writeExpression(*arguments[1], out);
2338 
2339             const Type& arg0Type = arguments[0]->type();
2340             SkASSERT(arg0Type.typeKind() == Type::TypeKind::kTexture);
2341 
2342             switch (arg0Type.textureAccess()) {
2343                 case Type::TextureAccess::kSample:
2344                     this->writeInstruction(SpvOpImageFetch, type, result, image, coord,
2345                                            SpvImageOperandsLodMask,
2346                                            this->writeOpConstant(*fContext.fTypes.fInt, 0),
2347                                            out);
2348                     break;
2349                 case Type::TextureAccess::kRead:
2350                 case Type::TextureAccess::kReadWrite:
2351                     this->writeInstruction(SpvOpImageRead, type, result, image, coord, out);
2352                     break;
2353                 case Type::TextureAccess::kWrite:
2354                 default:
2355                     SkDEBUGFAIL("'textureRead' called on writeonly texture type");
2356                     break;
2357             }
2358 
2359             break;
2360         }
2361         case kTextureWrite_SpecialIntrinsic: {
2362             SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2363             SkASSERT(arguments[1]->type().matches(*fContext.fTypes.fUInt2));
2364             SkASSERT(arguments[2]->type().matches(*fContext.fTypes.fHalf4));
2365 
2366             SpvId image = this->writeExpression(*arguments[0], out);
2367             SpvId coord = this->writeExpression(*arguments[1], out);
2368             SpvId texel = this->writeExpression(*arguments[2], out);
2369 
2370             this->writeInstruction(SpvOpImageWrite, image, coord, texel, out);
2371             break;
2372         }
2373         case kTextureWidth_SpecialIntrinsic:
2374         case kTextureHeight_SpecialIntrinsic: {
2375             SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2376             fCapabilities |= 1ULL << SpvCapabilityImageQuery;
2377 
2378             SpvId dimsType = this->getType(*fContext.fTypes.fUInt2);
2379             SpvId dims = this->nextId(nullptr);
2380             SpvId image = this->writeExpression(*arguments[0], out);
2381             this->writeInstruction(SpvOpImageQuerySize, dimsType, dims, image, out);
2382 
2383             SpvId type = this->getType(callType);
2384             int32_t index = (kind == kTextureWidth_SpecialIntrinsic) ? 0 : 1;
2385             this->writeInstruction(SpvOpCompositeExtract, type, result, dims, index, out);
2386             break;
2387         }
2388         case kMod_SpecialIntrinsic: {
2389             TArray<SpvId> args = this->vectorize(arguments, out);
2390             SkASSERT(args.size() == 2);
2391             const Type& operandType = arguments[0]->type();
2392             SpvOp_ op = pick_by_type(operandType, SpvOpFMod, SpvOpSMod, SpvOpUMod, SpvOpUndef);
2393             SkASSERT(op != SpvOpUndef);
2394             this->writeOpCode(op, 5, out);
2395             this->writeWord(this->getType(operandType), out);
2396             this->writeWord(result, out);
2397             this->writeWord(args[0], out);
2398             this->writeWord(args[1], out);
2399             break;
2400         }
2401         case kDFdy_SpecialIntrinsic: {
2402             SpvId fn = this->writeExpression(*arguments[0], out);
2403             this->writeOpCode(SpvOpDPdy, 4, out);
2404             this->writeWord(this->getType(callType), out);
2405             this->writeWord(result, out);
2406             this->writeWord(fn, out);
2407             if (!fProgram.fConfig->fSettings.fForceNoRTFlip) {
2408                 this->addRTFlipUniform(c.fPosition);
2409                 ComponentArray componentArray;
2410                 for (int index = 0; index < callType.columns(); ++index) {
2411                     componentArray.push_back(SwizzleComponent::Y);
2412                 }
2413                 SpvId rtFlipY = this->writeSwizzle(*this->identifier(SKSL_RTFLIP_NAME),
2414                                                    componentArray, out);
2415                 SpvId flipped = this->nextId(&callType);
2416                 this->writeInstruction(SpvOpFMul, this->getType(callType), flipped, result,
2417                                        rtFlipY, out);
2418                 result = flipped;
2419             }
2420             break;
2421         }
2422         case kClamp_SpecialIntrinsic: {
2423             TArray<SpvId> args = this->vectorize(arguments, out);
2424             SkASSERT(args.size() == 3);
2425             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FClamp, GLSLstd450SClamp,
2426                                                GLSLstd450UClamp, args, out);
2427             break;
2428         }
2429         case kMax_SpecialIntrinsic: {
2430             TArray<SpvId> args = this->vectorize(arguments, out);
2431             SkASSERT(args.size() == 2);
2432             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMax, GLSLstd450SMax,
2433                                                GLSLstd450UMax, args, out);
2434             break;
2435         }
2436         case kMin_SpecialIntrinsic: {
2437             TArray<SpvId> args = this->vectorize(arguments, out);
2438             SkASSERT(args.size() == 2);
2439             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMin, GLSLstd450SMin,
2440                                                GLSLstd450UMin, args, out);
2441             break;
2442         }
2443         case kMix_SpecialIntrinsic: {
2444             TArray<SpvId> args = this->vectorize(arguments, out);
2445             SkASSERT(args.size() == 3);
2446             if (arguments[2]->type().componentType().isBoolean()) {
2447                 // Use OpSelect to implement Boolean mix().
2448                 SpvId falseId     = this->writeExpression(*arguments[0], out);
2449                 SpvId trueId      = this->writeExpression(*arguments[1], out);
2450                 SpvId conditionId = this->writeExpression(*arguments[2], out);
2451                 this->writeInstruction(SpvOpSelect, this->getType(arguments[0]->type()), result,
2452                                        conditionId, trueId, falseId, out);
2453             } else {
2454                 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMix, SpvOpUndef,
2455                                                    SpvOpUndef, args, out);
2456             }
2457             break;
2458         }
2459         case kSaturate_SpecialIntrinsic: {
2460             SkASSERT(arguments.size() == 1);
2461             int width = arguments[0]->type().columns();
2462             STArray<3, SpvId> spvArgs{
2463                 this->vectorize(*arguments[0], width, out),
2464                 this->vectorize(*Literal::MakeFloat(fContext, Position(), /*value=*/0), width, out),
2465                 this->vectorize(*Literal::MakeFloat(fContext, Position(), /*value=*/1), width, out),
2466             };
2467             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FClamp, GLSLstd450SClamp,
2468                                                GLSLstd450UClamp, spvArgs, out);
2469             break;
2470         }
2471         case kSmoothStep_SpecialIntrinsic: {
2472             TArray<SpvId> args = this->vectorize(arguments, out);
2473             SkASSERT(args.size() == 3);
2474             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450SmoothStep, SpvOpUndef,
2475                                                SpvOpUndef, args, out);
2476             break;
2477         }
2478         case kStep_SpecialIntrinsic: {
2479             TArray<SpvId> args = this->vectorize(arguments, out);
2480             SkASSERT(args.size() == 2);
2481             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450Step, SpvOpUndef,
2482                                                SpvOpUndef, args, out);
2483             break;
2484         }
2485         case kMatrixCompMult_SpecialIntrinsic: {
2486             SkASSERT(arguments.size() == 2);
2487             SpvId lhs = this->writeExpression(*arguments[0], out);
2488             SpvId rhs = this->writeExpression(*arguments[1], out);
2489             result = this->writeComponentwiseMatrixBinary(callType, lhs, rhs, SpvOpFMul, out);
2490             break;
2491         }
2492 #ifdef SKSL_EXT
2493         case kTextureSize_SpecialIntrinsic: {
2494             SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2495 
2496             fCapabilities |= 1ULL << SpvCapabilityImageQuery;
2497 
2498             SpvId dimsType = this->getType(*fContext.fTypes.fInt2);
2499             SpvId sampledImage = this->writeExpression(*arguments[0], out);
2500             SpvId image = this->nextId(nullptr);
2501             SpvId imageType = this->getType(arguments[0]->type().textureType());
2502             SpvId lod = this->writeExpression(*arguments[1], out);
2503             this->writeInstruction(SpvOpImage, imageType, image, sampledImage, out);
2504             this->writeInstruction(SpvOpImageQuerySizeLod, dimsType, result, image, lod, out);
2505             break;
2506         }
2507         case kSampleGather_SpecialIntrinsic: {
2508             SpvOp_ op = SpvOpImageGather;
2509             SpvId type = this->getType(callType);
2510             SpvId sampler = this->writeExpression(*arguments[0], out);
2511             SpvId uv = this->writeExpression(*arguments[1], out);
2512             SpvId comp;
2513             // 3 is the number of input parameters for this function.
2514             if (arguments.size() == 3) {
2515                 // 2 is the position of the comp parameter, representing the sequence number of the color channel.
2516                 comp = this->writeExpression(*arguments[2], out);
2517             } else {
2518                 // 2 is the number of input parameter for function.
2519                 SkASSERT(arguments.size() == 2);
2520                 comp = this->writeLiteral(0, *fContext.fTypes.fInt);
2521             }
2522             this->writeInstruction(op, type, result, sampler, uv, comp, out);
2523             break;
2524         }
2525         case kNonuniformEXT_SpecialIntrinsic: {
2526             fCapabilitiesExt.insert(SpvCapabilityShaderNonUniform);
2527             SpvId dimsType = this->getType(*fContext.fTypes.fUInt);
2528             this->writeInstruction(SpvOpDecorate, result, SpvDecorationNonUniform, fDecorationBuffer);
2529             SpvId lod = this->writeExpression(*arguments[0], out);
2530             fNonUniformSpvId.insert(result);
2531             this->writeInstruction(SpvOpCopyObject, dimsType, result, lod, out);
2532             break;
2533         }
2534 #endif
2535         case kAtomicAdd_SpecialIntrinsic:
2536         case kAtomicLoad_SpecialIntrinsic:
2537         case kAtomicStore_SpecialIntrinsic:
2538             result = this->writeAtomicIntrinsic(c, kind, result, out);
2539             break;
2540         case kStorageBarrier_SpecialIntrinsic:
2541         case kWorkgroupBarrier_SpecialIntrinsic: {
2542             // Both barrier types operate in the workgroup execution and memory scope and differ
2543             // only in memory semantics. storageBarrier() is not a device-scope barrier.
2544             SpvId scopeId =
2545                     this->writeOpConstant(*fContext.fTypes.fUInt, (int32_t)SpvScopeWorkgroup);
2546             int32_t memSemMask = (kind == kStorageBarrier_SpecialIntrinsic)
2547                                          ? SpvMemorySemanticsAcquireReleaseMask |
2548                                                    SpvMemorySemanticsUniformMemoryMask
2549                                          : SpvMemorySemanticsAcquireReleaseMask |
2550                                                    SpvMemorySemanticsWorkgroupMemoryMask;
2551             SpvId memorySemanticsId = this->writeOpConstant(*fContext.fTypes.fUInt, memSemMask);
2552             this->writeInstruction(SpvOpControlBarrier,
2553                                    scopeId,  // execution scope
2554                                    scopeId,  // memory scope
2555                                    memorySemanticsId,
2556                                    out);
2557             break;
2558         }
2559     }
2560     return result;
2561 }
2562 
writeAtomicIntrinsic(const FunctionCall & c,SpecialIntrinsic kind,SpvId resultId,OutputStream & out)2563 SpvId SPIRVCodeGenerator::writeAtomicIntrinsic(const FunctionCall& c,
2564                                                SpecialIntrinsic kind,
2565                                                SpvId resultId,
2566                                                OutputStream& out) {
2567     const ExpressionArray& arguments = c.arguments();
2568     SkASSERT(!arguments.empty());
2569 
2570     std::unique_ptr<LValue> atomicPtr = this->getLValue(*arguments[0], out);
2571     SpvId atomicPtrId = atomicPtr->getPointer();
2572     if (atomicPtrId == NA) {
2573         SkDEBUGFAILF("atomic intrinsic expected a pointer argument: %s",
2574                      arguments[0]->description().c_str());
2575         return NA;
2576     }
2577 
2578     SpvId memoryScopeId = NA;
2579     {
2580         // In SkSL, the atomicUint type can only be declared as a workgroup variable or SSBO block
2581         // member. The two memory scopes that these map to are "workgroup" and "device",
2582         // respectively.
2583         SpvScope memoryScope;
2584         switch (atomicPtr->storageClass()) {
2585             case StorageClass::kUniform:
2586             case StorageClass::kStorageBuffer:
2587                 // We encode storage buffers in the uniform address space (with the BufferBlock
2588                 // decorator).
2589                 memoryScope = SpvScopeDevice;
2590                 break;
2591             case StorageClass::kWorkgroup:
2592                 memoryScope = SpvScopeWorkgroup;
2593                 break;
2594             default:
2595                 SkDEBUGFAILF("atomic argument has invalid storage class: %d",
2596                              get_storage_class_spv_id(atomicPtr->storageClass()));
2597                 return NA;
2598         }
2599         memoryScopeId = this->writeOpConstant(*fContext.fTypes.fUInt, (int32_t)memoryScope);
2600     }
2601 
2602     SpvId relaxedMemoryOrderId =
2603             this->writeOpConstant(*fContext.fTypes.fUInt, (int32_t)SpvMemorySemanticsMaskNone);
2604 
2605     switch (kind) {
2606         case kAtomicAdd_SpecialIntrinsic:
2607             SkASSERT(arguments.size() == 2);
2608             this->writeInstruction(SpvOpAtomicIAdd,
2609                                    this->getType(c.type()),
2610                                    resultId,
2611                                    atomicPtrId,
2612                                    memoryScopeId,
2613                                    relaxedMemoryOrderId,
2614                                    this->writeExpression(*arguments[1], out),
2615                                    out);
2616             break;
2617         case kAtomicLoad_SpecialIntrinsic:
2618             SkASSERT(arguments.size() == 1);
2619             this->writeInstruction(SpvOpAtomicLoad,
2620                                    this->getType(c.type()),
2621                                    resultId,
2622                                    atomicPtrId,
2623                                    memoryScopeId,
2624                                    relaxedMemoryOrderId,
2625                                    out);
2626             break;
2627         case kAtomicStore_SpecialIntrinsic:
2628             SkASSERT(arguments.size() == 2);
2629             this->writeInstruction(SpvOpAtomicStore,
2630                                    atomicPtrId,
2631                                    memoryScopeId,
2632                                    relaxedMemoryOrderId,
2633                                    this->writeExpression(*arguments[1], out),
2634                                    out);
2635             break;
2636         default:
2637             SkUNREACHABLE;
2638     }
2639 
2640     return resultId;
2641 }
2642 
writeFunctionCallArgument(TArray<SpvId> & argumentList,const FunctionCall & call,int argIndex,std::vector<TempVar> * tempVars,const SkBitSet * specializedParams,OutputStream & out)2643 void SPIRVCodeGenerator::writeFunctionCallArgument(TArray<SpvId>& argumentList,
2644                                                    const FunctionCall& call,
2645                                                    int argIndex,
2646                                                    std::vector<TempVar>* tempVars,
2647                                                    const SkBitSet* specializedParams,
2648                                                    OutputStream& out) {
2649     const FunctionDeclaration& funcDecl = call.function();
2650     const Expression& arg = *call.arguments()[argIndex];
2651     const Variable* param = funcDecl.parameters()[argIndex];
2652     bool paramIsSpecialized = specializedParams && specializedParams->test(argIndex);
2653     ModifierFlags paramFlags = param->modifierFlags();
2654 
2655     // Ignore the argument since it is specialized, if fUseTextureSamplerPairs is true and this
2656     // argument is a sampler, handle ignoring the sampler below when generating the texture and
2657     // sampler pair arguments.
2658     if (paramIsSpecialized && !(param->type().isSampler() && fUseTextureSamplerPairs)) {
2659         return;
2660     }
2661 
2662     if (arg.is<VariableReference>() && (arg.type().typeKind() == Type::TypeKind::kSampler ||
2663                                         arg.type().typeKind() == Type::TypeKind::kSeparateSampler ||
2664                                         arg.type().typeKind() == Type::TypeKind::kTexture)) {
2665         // Opaque handle (sampler/texture) arguments are always declared as pointers but never
2666         // stored in intermediates when calling user-defined functions.
2667         //
2668         // The case for intrinsics (which take opaque arguments by value) is handled above just like
2669         // regular pointers.
2670         //
2671         // See getFunctionParameterType for further explanation.
2672         const Variable* var = arg.as<VariableReference>().variable();
2673 
2674         // In Dawn-mode the texture and sampler arguments are forwarded to the helper function.
2675         if (fUseTextureSamplerPairs && var->type().isSampler()) {
2676             if (const auto* p = fSynthesizedSamplerMap.find(var)) {
2677                 SpvId* img = fVariableMap.find((*p)->fTexture.get());
2678                 SkASSERT(img);
2679 
2680                 argumentList.push_back(*img);
2681 
2682                 if (!paramIsSpecialized) {
2683                     SpvId* sampler = fVariableMap.find((*p)->fSampler.get());
2684                     SkASSERT(sampler);
2685                     argumentList.push_back(*sampler);
2686                 }
2687                 return;
2688             }
2689             SkDEBUGFAIL("sampler missing from fSynthesizedSamplerMap");
2690         }
2691 
2692         SpvId* entry = fVariableMap.find(var);
2693         SkASSERTF(entry, "%s", arg.description().c_str());
2694         argumentList.push_back(*entry);
2695         return;
2696     }
2697     SkASSERT(!paramIsSpecialized);
2698 
2699     // ID of temporary variable that we will use to hold this argument, or 0 if it is being
2700     // passed directly
2701     SpvId tmpVar = NA;
2702     // if we need a temporary var to store this argument, this is the value to store in the var
2703     SpvId tmpValueId = NA;
2704 
2705     if (is_out(paramFlags)) {
2706         std::unique_ptr<LValue> lv = this->getLValue(arg, out);
2707         // We handle out params with a temp var that we copy back to the original variable at the
2708         // end of the call. GLSL guarantees that the original variable will be unchanged until the
2709         // end of the call, and also that out params are written back to their original variables in
2710         // a specific order (left-to-right), so it's unsafe to pass a pointer to the original value.
2711         if (is_in(paramFlags)) {
2712             tmpValueId = lv->load(out);
2713         }
2714         tmpVar = this->nextId(&arg.type());
2715         tempVars->push_back(TempVar{tmpVar, &arg.type(), std::move(lv)});
2716     } else if (funcDecl.isIntrinsic()) {
2717         // Unlike user function calls, non-out intrinsic arguments don't need pointer parameters.
2718         argumentList.push_back(this->writeExpression(arg, out));
2719         return;
2720     } else {
2721         // We always use pointer parameters when calling user functions.
2722         // See getFunctionParameterType for further explanation.
2723         tmpValueId = this->writeExpression(arg, out);
2724         tmpVar = this->nextId(nullptr);
2725     }
2726     this->writeInstruction(SpvOpVariable,
2727                            this->getPointerType(arg.type(), StorageClass::kFunction),
2728                            tmpVar,
2729                            SpvStorageClassFunction,
2730                            fVariableBuffer);
2731     if (tmpValueId != NA) {
2732         this->writeOpStore(StorageClass::kFunction, tmpVar, tmpValueId, out);
2733     }
2734     argumentList.push_back(tmpVar);
2735 }
2736 
copyBackTempVars(const std::vector<TempVar> & tempVars,OutputStream & out)2737 void SPIRVCodeGenerator::copyBackTempVars(const std::vector<TempVar>& tempVars, OutputStream& out) {
2738     for (const TempVar& tempVar : tempVars) {
2739         SpvId load = this->nextId(tempVar.type);
2740         this->writeInstruction(SpvOpLoad, this->getType(*tempVar.type), load, tempVar.spvId, out);
2741         tempVar.lvalue->store(load, out);
2742     }
2743 }
2744 
writeFunctionCall(const FunctionCall & c,OutputStream & out)2745 SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) {
2746     // Handle intrinsics.
2747     const FunctionDeclaration& function = c.function();
2748     if (function.isIntrinsic() && !function.definition()) {
2749         return this->writeIntrinsicCall(c, out);
2750     }
2751 
2752     // Look up this function (or its specialization, if any) in our map of function SpvIds.
2753     Analysis::SpecializationIndex specializationIndex = Analysis::FindSpecializationIndexForCall(
2754             c, fSpecializationInfo, fActiveSpecializationIndex);
2755     SpvId* entry = fFunctionMap.find({&function, specializationIndex});
2756     if (!entry) {
2757         fContext.fErrors->error(c.fPosition, "function '" + function.description() +
2758                                              "' is not defined");
2759         return NA;
2760     }
2761 
2762     // If we are calling a specialized function, we need to gather the specialized parameters
2763     // so we can remove them from the argument list.
2764     SkBitSet specializedParams =
2765             Analysis::FindSpecializedParametersForFunction(c.function(), fSpecializationInfo);
2766 
2767     // Temp variables are used to write back out-parameters after the function call is complete.
2768     const ExpressionArray& arguments = c.arguments();
2769     std::vector<TempVar> tempVars;
2770     TArray<SpvId> argumentIds;
2771     argumentIds.reserve_exact(arguments.size());
2772     for (int i = 0; i < arguments.size(); i++) {
2773         this->writeFunctionCallArgument(argumentIds, c, i, &tempVars, &specializedParams, out);
2774     }
2775     SpvId result = this->nextId(nullptr);
2776     this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t)argumentIds.size(), out);
2777     this->writeWord(this->getType(c.type()), out);
2778     this->writeWord(result, out);
2779     this->writeWord(*entry, out);
2780     for (SpvId id : argumentIds) {
2781         this->writeWord(id, out);
2782     }
2783     // Now that the call is complete, we copy temp out-variables back to their real lvalues.
2784     this->copyBackTempVars(tempVars, out);
2785     return result;
2786 }
2787 
castScalarToType(SpvId inputExprId,const Type & inputType,const Type & outputType,OutputStream & out)2788 SpvId SPIRVCodeGenerator::castScalarToType(SpvId inputExprId,
2789                                            const Type& inputType,
2790                                            const Type& outputType,
2791                                            OutputStream& out) {
2792     if (outputType.isFloat()) {
2793         return this->castScalarToFloat(inputExprId, inputType, outputType, out);
2794     }
2795     if (outputType.isSigned()) {
2796         return this->castScalarToSignedInt(inputExprId, inputType, outputType, out);
2797     }
2798     if (outputType.isUnsigned()) {
2799         return this->castScalarToUnsignedInt(inputExprId, inputType, outputType, out);
2800     }
2801     if (outputType.isBoolean()) {
2802         return this->castScalarToBoolean(inputExprId, inputType, outputType, out);
2803     }
2804 
2805     fContext.fErrors->error(Position(), "unsupported cast: " + inputType.description() + " to " +
2806             outputType.description());
2807     return inputExprId;
2808 }
2809 
castScalarToFloat(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)2810 SpvId SPIRVCodeGenerator::castScalarToFloat(SpvId inputId, const Type& inputType,
2811                                             const Type& outputType, OutputStream& out) {
2812     // Casting a float to float is a no-op.
2813     if (inputType.isFloat()) {
2814         return inputId;
2815     }
2816 
2817     // Given the input type, generate the appropriate instruction to cast to float.
2818     SpvId result = this->nextId(&outputType);
2819     if (inputType.isBoolean()) {
2820         // Use OpSelect to convert the boolean argument to a literal 1.0 or 0.0.
2821         const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fFloat);
2822         const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
2823         this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
2824                                inputId, oneID, zeroID, out);
2825     } else if (inputType.isSigned()) {
2826         this->writeInstruction(SpvOpConvertSToF, this->getType(outputType), result, inputId, out);
2827     } else if (inputType.isUnsigned()) {
2828         this->writeInstruction(SpvOpConvertUToF, this->getType(outputType), result, inputId, out);
2829     } else {
2830         SkDEBUGFAILF("unsupported type for float typecast: %s", inputType.description().c_str());
2831         return NA;
2832     }
2833     return result;
2834 }
2835 
castScalarToSignedInt(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)2836 SpvId SPIRVCodeGenerator::castScalarToSignedInt(SpvId inputId, const Type& inputType,
2837                                                 const Type& outputType, OutputStream& out) {
2838     // Casting a signed int to signed int is a no-op.
2839     if (inputType.isSigned()) {
2840         return inputId;
2841     }
2842 
2843     // Given the input type, generate the appropriate instruction to cast to signed int.
2844     SpvId result = this->nextId(&outputType);
2845     if (inputType.isBoolean()) {
2846         // Use OpSelect to convert the boolean argument to a literal 1 or 0.
2847         const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fInt);
2848         const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt);
2849         this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
2850                                inputId, oneID, zeroID, out);
2851     } else if (inputType.isFloat()) {
2852         this->writeInstruction(SpvOpConvertFToS, this->getType(outputType), result, inputId, out);
2853     } else if (inputType.isUnsigned()) {
2854         this->writeInstruction(SpvOpBitcast, this->getType(outputType), result, inputId, out);
2855     } else {
2856         SkDEBUGFAILF("unsupported type for signed int typecast: %s",
2857                      inputType.description().c_str());
2858         return NA;
2859     }
2860 #ifdef SKSL_EXT
2861     if (fNonUniformSpvId.find(inputId) != fNonUniformSpvId.end()) {
2862         fNonUniformSpvId.insert(result);
2863         this->writeInstruction(SpvOpDecorate, result, SpvDecorationNonUniform, fDecorationBuffer);
2864     }
2865 #endif
2866     return result;
2867 }
2868 
castScalarToUnsignedInt(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)2869 SpvId SPIRVCodeGenerator::castScalarToUnsignedInt(SpvId inputId, const Type& inputType,
2870                                                   const Type& outputType, OutputStream& out) {
2871     // Casting an unsigned int to unsigned int is a no-op.
2872     if (inputType.isUnsigned()) {
2873         return inputId;
2874     }
2875 
2876     // Given the input type, generate the appropriate instruction to cast to unsigned int.
2877     SpvId result = this->nextId(&outputType);
2878     if (inputType.isBoolean()) {
2879         // Use OpSelect to convert the boolean argument to a literal 1u or 0u.
2880         const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fUInt);
2881         const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt);
2882         this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
2883                                inputId, oneID, zeroID, out);
2884     } else if (inputType.isFloat()) {
2885         this->writeInstruction(SpvOpConvertFToU, this->getType(outputType), result, inputId, out);
2886     } else if (inputType.isSigned()) {
2887         this->writeInstruction(SpvOpBitcast, this->getType(outputType), result, inputId, out);
2888     } else {
2889         SkDEBUGFAILF("unsupported type for unsigned int typecast: %s",
2890                      inputType.description().c_str());
2891         return NA;
2892     }
2893 #ifdef SKSL_EXT
2894     if (fNonUniformSpvId.find(inputId) != fNonUniformSpvId.end()) {
2895         fNonUniformSpvId.insert(result);
2896         this->writeInstruction(SpvOpDecorate, result, SpvDecorationNonUniform, fDecorationBuffer);
2897     }
2898 #endif
2899     return result;
2900 }
2901 
castScalarToBoolean(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)2902 SpvId SPIRVCodeGenerator::castScalarToBoolean(SpvId inputId, const Type& inputType,
2903                                               const Type& outputType, OutputStream& out) {
2904     // Casting a bool to bool is a no-op.
2905     if (inputType.isBoolean()) {
2906         return inputId;
2907     }
2908 
2909     // Given the input type, generate the appropriate instruction to cast to bool.
2910     SpvId result = this->nextId(nullptr);
2911     if (inputType.isSigned()) {
2912         // Synthesize a boolean result by comparing the input against a signed zero literal.
2913         const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt);
2914         this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result,
2915                                inputId, zeroID, out);
2916     } else if (inputType.isUnsigned()) {
2917         // Synthesize a boolean result by comparing the input against an unsigned zero literal.
2918         const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt);
2919         this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result,
2920                                inputId, zeroID, out);
2921     } else if (inputType.isFloat()) {
2922         // Synthesize a boolean result by comparing the input against a floating-point zero literal.
2923         const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
2924         this->writeInstruction(SpvOpFUnordNotEqual, this->getType(outputType), result,
2925                                inputId, zeroID, out);
2926     } else {
2927         SkDEBUGFAILF("unsupported type for boolean typecast: %s", inputType.description().c_str());
2928         return NA;
2929     }
2930     return result;
2931 }
2932 
writeMatrixCopy(SpvId src,const Type & srcType,const Type & dstType,OutputStream & out)2933 SpvId SPIRVCodeGenerator::writeMatrixCopy(SpvId src, const Type& srcType, const Type& dstType,
2934                                           OutputStream& out) {
2935     SkASSERT(srcType.isMatrix());
2936     SkASSERT(dstType.isMatrix());
2937     SkASSERT(srcType.componentType().matches(dstType.componentType()));
2938     const Type& srcColumnType = srcType.componentType().toCompound(fContext, srcType.rows(), 1);
2939     const Type& dstColumnType = dstType.componentType().toCompound(fContext, dstType.rows(), 1);
2940     SkASSERT(dstType.componentType().isFloat());
2941     SpvId dstColumnTypeId = this->getType(dstColumnType);
2942     const SpvId zeroId = this->writeLiteral(0.0, dstType.componentType());
2943     const SpvId oneId = this->writeLiteral(1.0, dstType.componentType());
2944 
2945     STArray<4, SpvId> columns;
2946     for (int i = 0; i < dstType.columns(); i++) {
2947         if (i < srcType.columns()) {
2948             // we're still inside the src matrix, copy the column
2949             SpvId srcColumn = this->writeOpCompositeExtract(srcColumnType, src, i, out);
2950             SpvId dstColumn;
2951             if (srcType.rows() == dstType.rows()) {
2952                 // columns are equal size, don't need to do anything
2953                 dstColumn = srcColumn;
2954             }
2955             else if (dstType.rows() > srcType.rows()) {
2956                 // dst column is bigger, need to zero-pad it
2957                 STArray<4, SpvId> values;
2958                 values.push_back(srcColumn);
2959                 for (int j = srcType.rows(); j < dstType.rows(); ++j) {
2960                     values.push_back((i == j) ? oneId : zeroId);
2961                 }
2962                 dstColumn = this->writeOpCompositeConstruct(dstColumnType, values, out);
2963             }
2964             else {
2965                 // dst column is smaller, need to swizzle the src column
2966                 dstColumn = this->nextId(&dstType);
2967                 this->writeOpCode(SpvOpVectorShuffle, 5 + dstType.rows(), out);
2968                 this->writeWord(dstColumnTypeId, out);
2969                 this->writeWord(dstColumn, out);
2970                 this->writeWord(srcColumn, out);
2971                 this->writeWord(srcColumn, out);
2972                 for (int j = 0; j < dstType.rows(); j++) {
2973                     this->writeWord(j, out);
2974                 }
2975             }
2976             columns.push_back(dstColumn);
2977         } else {
2978             // we're past the end of the src matrix, need to synthesize an identity-matrix column
2979             STArray<4, SpvId> values;
2980             for (int j = 0; j < dstType.rows(); ++j) {
2981                 values.push_back((i == j) ? oneId : zeroId);
2982             }
2983             columns.push_back(this->writeOpCompositeConstruct(dstColumnType, values, out));
2984         }
2985     }
2986 
2987     return this->writeOpCompositeConstruct(dstType, columns, out);
2988 }
2989 
addColumnEntry(const Type & columnType,TArray<SpvId> * currentColumn,TArray<SpvId> * columnIds,int rows,SpvId entry,OutputStream & out)2990 void SPIRVCodeGenerator::addColumnEntry(const Type& columnType,
2991                                         TArray<SpvId>* currentColumn,
2992                                         TArray<SpvId>* columnIds,
2993                                         int rows,
2994                                         SpvId entry,
2995                                         OutputStream& out) {
2996     SkASSERT(currentColumn->size() < rows);
2997     currentColumn->push_back(entry);
2998     if (currentColumn->size() == rows) {
2999         // Synthesize this column into a vector.
3000         SpvId columnId = this->writeOpCompositeConstruct(columnType, *currentColumn, out);
3001         columnIds->push_back(columnId);
3002         currentColumn->clear();
3003     }
3004 }
3005 
writeMatrixConstructor(const ConstructorCompound & c,OutputStream & out)3006 SpvId SPIRVCodeGenerator::writeMatrixConstructor(const ConstructorCompound& c, OutputStream& out) {
3007     const Type& type = c.type();
3008     SkASSERT(type.isMatrix());
3009     SkASSERT(!c.arguments().empty());
3010     const Type& arg0Type = c.arguments()[0]->type();
3011     // go ahead and write the arguments so we don't try to write new instructions in the middle of
3012     // an instruction
3013     STArray<16, SpvId> arguments;
3014     for (const std::unique_ptr<Expression>& arg : c.arguments()) {
3015         arguments.push_back(this->writeExpression(*arg, out));
3016     }
3017 
3018     if (arguments.size() == 1 && arg0Type.isVector()) {
3019         // Special-case handling of float4 -> mat2x2.
3020         SkASSERT(type.rows() == 2 && type.columns() == 2);
3021         SkASSERT(arg0Type.columns() == 4);
3022         SpvId v[4];
3023         for (int i = 0; i < 4; ++i) {
3024             v[i] = this->writeOpCompositeExtract(type.componentType(), arguments[0], i, out);
3025         }
3026         const Type& vecType = type.columnType(fContext);
3027         SpvId v0v1 = this->writeOpCompositeConstruct(vecType, {v[0], v[1]}, out);
3028         SpvId v2v3 = this->writeOpCompositeConstruct(vecType, {v[2], v[3]}, out);
3029         return this->writeOpCompositeConstruct(type, {v0v1, v2v3}, out);
3030     }
3031 
3032     int rows = type.rows();
3033     const Type& columnType = type.columnType(fContext);
3034     // SpvIds of completed columns of the matrix.
3035     STArray<4, SpvId> columnIds;
3036     // SpvIds of scalars we have written to the current column so far.
3037     STArray<4, SpvId> currentColumn;
3038     for (int i = 0; i < arguments.size(); i++) {
3039         const Type& argType = c.arguments()[i]->type();
3040         if (currentColumn.empty() && argType.isVector() && argType.columns() == rows) {
3041             // This vector is a complete matrix column by itself and can be used as-is.
3042             columnIds.push_back(arguments[i]);
3043         } else if (argType.columns() == 1) {
3044             // This argument is a lone scalar and can be added to the current column as-is.
3045             this->addColumnEntry(columnType, &currentColumn, &columnIds, rows, arguments[i], out);
3046         } else {
3047             // This argument needs to be decomposed into its constituent scalars.
3048             for (int j = 0; j < argType.columns(); ++j) {
3049                 SpvId swizzle = this->writeOpCompositeExtract(argType.componentType(),
3050                                                               arguments[i], j, out);
3051                 this->addColumnEntry(columnType, &currentColumn, &columnIds, rows, swizzle, out);
3052             }
3053         }
3054     }
3055     SkASSERT(columnIds.size() == type.columns());
3056     return this->writeOpCompositeConstruct(type, columnIds, out);
3057 }
3058 
writeConstructorCompound(const ConstructorCompound & c,OutputStream & out)3059 SpvId SPIRVCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
3060                                                    OutputStream& out) {
3061     return c.type().isMatrix() ? this->writeMatrixConstructor(c, out)
3062                                : this->writeVectorConstructor(c, out);
3063 }
3064 
writeVectorConstructor(const ConstructorCompound & c,OutputStream & out)3065 SpvId SPIRVCodeGenerator::writeVectorConstructor(const ConstructorCompound& c, OutputStream& out) {
3066     const Type& type = c.type();
3067     const Type& componentType = type.componentType();
3068     SkASSERT(type.isVector());
3069 
3070     STArray<4, SpvId> arguments;
3071     for (int i = 0; i < c.arguments().size(); i++) {
3072         const Type& argType = c.arguments()[i]->type();
3073         SkASSERT(componentType.numberKind() == argType.componentType().numberKind());
3074 
3075         SpvId arg = this->writeExpression(*c.arguments()[i], out);
3076         if (argType.isMatrix()) {
3077             // CompositeConstruct cannot take a 2x2 matrix as an input, so we need to extract out
3078             // each scalar separately.
3079             SkASSERT(argType.rows() == 2);
3080             SkASSERT(argType.columns() == 2);
3081             for (int j = 0; j < 4; ++j) {
3082                 arguments.push_back(this->writeOpCompositeExtract(componentType, arg,
3083                                                                   j / 2, j % 2, out));
3084             }
3085         } else if (argType.isVector()) {
3086             // There's a bug in the Intel Vulkan driver where OpCompositeConstruct doesn't handle
3087             // vector arguments at all, so we always extract each vector component and pass them
3088             // into OpCompositeConstruct individually.
3089             for (int j = 0; j < argType.columns(); j++) {
3090                 arguments.push_back(this->writeOpCompositeExtract(componentType, arg, j, out));
3091             }
3092         } else {
3093             arguments.push_back(arg);
3094         }
3095     }
3096 
3097     return this->writeOpCompositeConstruct(type, arguments, out);
3098 }
3099 
writeConstructorSplat(const ConstructorSplat & c,OutputStream & out)3100 SpvId SPIRVCodeGenerator::writeConstructorSplat(const ConstructorSplat& c, OutputStream& out) {
3101     // Write the splat argument as a scalar, then splat it.
3102     SpvId argument = this->writeExpression(*c.argument(), out);
3103     return this->splat(c.type(), argument, out);
3104 }
3105 
writeCompositeConstructor(const AnyConstructor & c,OutputStream & out)3106 SpvId SPIRVCodeGenerator::writeCompositeConstructor(const AnyConstructor& c, OutputStream& out) {
3107     SkASSERT(c.type().isArray() || c.type().isStruct());
3108     auto ctorArgs = c.argumentSpan();
3109 
3110     STArray<4, SpvId> arguments;
3111     for (const std::unique_ptr<Expression>& arg : ctorArgs) {
3112         arguments.push_back(this->writeExpression(*arg, out));
3113     }
3114 
3115     return this->writeOpCompositeConstruct(c.type(), arguments, out);
3116 }
3117 
writeConstructorScalarCast(const ConstructorScalarCast & c,OutputStream & out)3118 SpvId SPIRVCodeGenerator::writeConstructorScalarCast(const ConstructorScalarCast& c,
3119                                                      OutputStream& out) {
3120     const Type& type = c.type();
3121     if (type.componentType().numberKind() == c.argument()->type().componentType().numberKind()) {
3122         return this->writeExpression(*c.argument(), out);
3123     }
3124 
3125     const Expression& ctorExpr = *c.argument();
3126     SpvId expressionId = this->writeExpression(ctorExpr, out);
3127     return this->castScalarToType(expressionId, ctorExpr.type(), type, out);
3128 }
3129 
writeConstructorCompoundCast(const ConstructorCompoundCast & c,OutputStream & out)3130 SpvId SPIRVCodeGenerator::writeConstructorCompoundCast(const ConstructorCompoundCast& c,
3131                                                        OutputStream& out) {
3132     const Type& ctorType = c.type();
3133     const Type& argType = c.argument()->type();
3134     SkASSERT(ctorType.isVector() || ctorType.isMatrix());
3135 
3136     // Write the composite that we are casting. If the actual type matches, we are done.
3137     SpvId compositeId = this->writeExpression(*c.argument(), out);
3138     if (ctorType.componentType().numberKind() == argType.componentType().numberKind()) {
3139         return compositeId;
3140     }
3141 
3142     // writeMatrixCopy can cast matrices to a different type.
3143     if (ctorType.isMatrix()) {
3144         return this->writeMatrixCopy(compositeId, argType, ctorType, out);
3145     }
3146 
3147     // SPIR-V doesn't support vector(vector-of-different-type) directly, so we need to extract the
3148     // components and convert each one manually.
3149     const Type& srcType = argType.componentType();
3150     const Type& dstType = ctorType.componentType();
3151 
3152     STArray<4, SpvId> arguments;
3153     for (int index = 0; index < argType.columns(); ++index) {
3154         SpvId componentId = this->writeOpCompositeExtract(srcType, compositeId, index, out);
3155         arguments.push_back(this->castScalarToType(componentId, srcType, dstType, out));
3156     }
3157 
3158     return this->writeOpCompositeConstruct(ctorType, arguments, out);
3159 }
3160 
writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix & c,OutputStream & out)3161 SpvId SPIRVCodeGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
3162                                                          OutputStream& out) {
3163     const Type& type = c.type();
3164     SkASSERT(type.isMatrix());
3165     SkASSERT(c.argument()->type().isScalar());
3166 
3167     // Write out the scalar argument.
3168     SpvId diagonal = this->writeExpression(*c.argument(), out);
3169 
3170     // Build the diagonal matrix.
3171     SpvId zeroId = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
3172 
3173     const Type& vecType = type.columnType(fContext);
3174     STArray<4, SpvId> columnIds;
3175     STArray<4, SpvId> arguments;
3176     arguments.resize(type.rows());
3177     for (int column = 0; column < type.columns(); column++) {
3178         for (int row = 0; row < type.rows(); row++) {
3179             arguments[row] = (row == column) ? diagonal : zeroId;
3180         }
3181         columnIds.push_back(this->writeOpCompositeConstruct(vecType, arguments, out));
3182     }
3183     return this->writeOpCompositeConstruct(type, columnIds, out);
3184 }
3185 
writeConstructorMatrixResize(const ConstructorMatrixResize & c,OutputStream & out)3186 SpvId SPIRVCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c,
3187                                                        OutputStream& out) {
3188     // Write the input matrix.
3189     SpvId argument = this->writeExpression(*c.argument(), out);
3190 
3191     // Use matrix-copy to resize the input matrix to its new size.
3192     return this->writeMatrixCopy(argument, c.argument()->type(), c.type(), out);
3193 }
3194 
get_storage_class_for_global_variable(const Variable & var,StorageClass fallbackStorageClass)3195 static StorageClass get_storage_class_for_global_variable(
3196         const Variable& var, StorageClass fallbackStorageClass) {
3197     SkASSERT(var.storage() == Variable::Storage::kGlobal);
3198 
3199     if (var.type().typeKind() == Type::TypeKind::kSampler ||
3200 #ifdef SKSL_EXT
3201         (var.type().typeKind() == Type::TypeKind::kArray &&
3202          var.type().componentType().typeKind() == Type::TypeKind::kSampler) ||
3203 #endif
3204         var.type().typeKind() == Type::TypeKind::kSeparateSampler ||
3205         var.type().typeKind() == Type::TypeKind::kTexture) {
3206         return StorageClass::kUniformConstant;
3207     }
3208 
3209     const Layout& layout = var.layout();
3210     ModifierFlags flags = var.modifierFlags();
3211 #ifdef SKSL_EXT
3212     if (!(flags & ModifierFlag::kUniform) &&
3213         var.storage() == Variable::Storage::kGlobal &&
3214         (flags & ModifierFlag::kConst)) {
3215         return StorageClass::kFunction;
3216     }
3217 #endif
3218     if (flags & ModifierFlag::kIn) {
3219         SkASSERT(!(layout.fFlags & LayoutFlag::kPushConstant));
3220         return StorageClass::kInput;
3221     }
3222     if (flags & ModifierFlag::kOut) {
3223         SkASSERT(!(layout.fFlags & LayoutFlag::kPushConstant));
3224         return StorageClass::kOutput;
3225     }
3226     if (flags.isUniform()) {
3227         if (layout.fFlags & LayoutFlag::kPushConstant) {
3228             return StorageClass::kPushConstant;
3229         }
3230         return StorageClass::kUniform;
3231     }
3232     if (flags.isBuffer()) {
3233         return StorageClass::kStorageBuffer;
3234     }
3235     if (flags.isWorkgroup()) {
3236         return StorageClass::kWorkgroup;
3237     }
3238     return fallbackStorageClass;
3239 }
3240 
getStorageClass(const Expression & expr)3241 StorageClass SPIRVCodeGenerator::getStorageClass(const Expression& expr) {
3242     switch (expr.kind()) {
3243         case Expression::Kind::kVariableReference: {
3244             const Variable& var = *expr.as<VariableReference>().variable();
3245             if (fActiveSpecialization) {
3246                 const Expression** specializedExpr = fActiveSpecialization->find(&var);
3247                 if (specializedExpr && (*specializedExpr)->is<FieldAccess>()) {
3248                     return this->getStorageClass(**specializedExpr);
3249                 }
3250             }
3251             if (var.storage() != Variable::Storage::kGlobal) {
3252                 return StorageClass::kFunction;
3253             }
3254             return get_storage_class_for_global_variable(var, StorageClass::kPrivate);
3255         }
3256         case Expression::Kind::kFieldAccess:
3257             return this->getStorageClass(*expr.as<FieldAccess>().base());
3258         case Expression::Kind::kIndex:
3259             return this->getStorageClass(*expr.as<IndexExpression>().base());
3260         default:
3261             return StorageClass::kFunction;
3262     }
3263 }
3264 
getAccessChain(const Expression & expr,OutputStream & out)3265 TArray<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) {
3266     switch (expr.kind()) {
3267         case Expression::Kind::kIndex: {
3268             const IndexExpression& indexExpr = expr.as<IndexExpression>();
3269             if (indexExpr.base()->is<Swizzle>()) {
3270                 // Access chains don't directly support dynamically indexing into a swizzle, but we
3271                 // can rewrite them into a supported form.
3272                 return this->getAccessChain(*Transform::RewriteIndexedSwizzle(fContext, indexExpr),
3273                                             out);
3274             }
3275             // All other index-expressions can be represented as typical access chains.
3276             TArray<SpvId> chain = this->getAccessChain(*indexExpr.base(), out);
3277             chain.push_back(this->writeExpression(*indexExpr.index(), out));
3278             return chain;
3279         }
3280         case Expression::Kind::kFieldAccess: {
3281             const FieldAccess& fieldExpr = expr.as<FieldAccess>();
3282             TArray<SpvId> chain = this->getAccessChain(*fieldExpr.base(), out);
3283             chain.push_back(this->writeLiteral(fieldExpr.fieldIndex(), *fContext.fTypes.fInt));
3284             return chain;
3285         }
3286         case Expression::Kind::kVariableReference: {
3287             if (fActiveSpecialization) {
3288                 const Expression** specializedFieldIndex =
3289                         fActiveSpecialization->find(expr.as<VariableReference>().variable());
3290                 if (specializedFieldIndex && (*specializedFieldIndex)->is<FieldAccess>()) {
3291                     return this->getAccessChain(**specializedFieldIndex, out);
3292                 }
3293             }
3294             [[fallthrough]];
3295         }
3296         default: {
3297             SpvId id = this->getLValue(expr, out)->getPointer();
3298             SkASSERT(id != NA);
3299             return TArray<SpvId>{id};
3300         }
3301     }
3302     SkUNREACHABLE;
3303 }
3304 
3305 class PointerLValue : public SPIRVCodeGenerator::LValue {
3306 public:
PointerLValue(SPIRVCodeGenerator & gen,SpvId pointer,bool isMemoryObject,SpvId type,SPIRVCodeGenerator::Precision precision,StorageClass storageClass)3307     PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, bool isMemoryObject, SpvId type,
3308                   SPIRVCodeGenerator::Precision precision, StorageClass storageClass)
3309     : fGen(gen)
3310     , fPointer(pointer)
3311     , fIsMemoryObject(isMemoryObject)
3312     , fType(type)
3313     , fPrecision(precision)
3314     , fStorageClass(storageClass) {}
3315 
getPointer()3316     SpvId getPointer() override {
3317         return fPointer;
3318     }
3319 
isMemoryObjectPointer() const3320     bool isMemoryObjectPointer() const override {
3321         return fIsMemoryObject;
3322     }
3323 
storageClass() const3324     StorageClass storageClass() const override {
3325         return fStorageClass;
3326     }
3327 
load(OutputStream & out)3328     SpvId load(OutputStream& out) override {
3329         return fGen.writeOpLoad(fType, fPrecision, fPointer, out);
3330     }
3331 
store(SpvId value,OutputStream & out)3332     void store(SpvId value, OutputStream& out) override {
3333         if (!fIsMemoryObject) {
3334             // We are going to write into an access chain; this could represent one component of a
3335             // vector, or one element of an array. This has the potential to invalidate other,
3336             // *unknown* elements of our store cache. (e.g. if the store cache holds `%50 = myVec4`,
3337             // and we store `%60 = myVec4.z`, this invalidates the cached value for %50.) To avoid
3338             // relying on stale data, reset the store cache entirely when this happens.
3339             fGen.fStoreCache.reset();
3340         }
3341 
3342         fGen.writeOpStore(fStorageClass, fPointer, value, out);
3343     }
3344 
3345 private:
3346     SPIRVCodeGenerator& fGen;
3347     const SpvId fPointer;
3348     const bool fIsMemoryObject;
3349     const SpvId fType;
3350     const SPIRVCodeGenerator::Precision fPrecision;
3351     const StorageClass fStorageClass;
3352 };
3353 
3354 class SwizzleLValue : public SPIRVCodeGenerator::LValue {
3355 public:
SwizzleLValue(SPIRVCodeGenerator & gen,SpvId vecPointer,const ComponentArray & components,const Type & baseType,const Type & swizzleType,StorageClass storageClass)3356     SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const ComponentArray& components,
3357                   const Type& baseType, const Type& swizzleType, StorageClass storageClass)
3358     : fGen(gen)
3359     , fVecPointer(vecPointer)
3360     , fComponents(components)
3361     , fBaseType(&baseType)
3362     , fSwizzleType(&swizzleType)
3363     , fStorageClass(storageClass) {}
3364 
applySwizzle(const ComponentArray & components,const Type & newType)3365     bool applySwizzle(const ComponentArray& components, const Type& newType) override {
3366         ComponentArray updatedSwizzle;
3367         for (int8_t component : components) {
3368             if (component < 0 || component >= fComponents.size()) {
3369                 SkDEBUGFAILF("swizzle accessed nonexistent component %d", (int)component);
3370                 return false;
3371             }
3372             updatedSwizzle.push_back(fComponents[component]);
3373         }
3374         fComponents = updatedSwizzle;
3375         fSwizzleType = &newType;
3376         return true;
3377     }
3378 
storageClass() const3379     StorageClass storageClass() const override {
3380         return fStorageClass;
3381     }
3382 
load(OutputStream & out)3383     SpvId load(OutputStream& out) override {
3384         SpvId base = fGen.nextId(fBaseType);
3385         fGen.writeInstruction(SpvOpLoad, fGen.getType(*fBaseType), base, fVecPointer, out);
3386         SpvId result = fGen.nextId(fBaseType);
3387         fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
3388         fGen.writeWord(fGen.getType(*fSwizzleType), out);
3389         fGen.writeWord(result, out);
3390         fGen.writeWord(base, out);
3391         fGen.writeWord(base, out);
3392         for (int component : fComponents) {
3393             fGen.writeWord(component, out);
3394         }
3395         return result;
3396     }
3397 
store(SpvId value,OutputStream & out)3398     void store(SpvId value, OutputStream& out) override {
3399         // use OpVectorShuffle to mix and match the vector components. We effectively create
3400         // a virtual vector out of the concatenation of the left and right vectors, and then
3401         // select components from this virtual vector to make the result vector. For
3402         // instance, given:
3403         // float3L = ...;
3404         // float3R = ...;
3405         // L.xz = R.xy;
3406         // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
3407         // our result vector to look like (R.x, L.y, R.y), so we need to select indices
3408         // (3, 1, 4).
3409         SpvId base = fGen.nextId(fBaseType);
3410         fGen.writeInstruction(SpvOpLoad, fGen.getType(*fBaseType), base, fVecPointer, out);
3411         SpvId shuffle = fGen.nextId(fBaseType);
3412         fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType->columns(), out);
3413         fGen.writeWord(fGen.getType(*fBaseType), out);
3414         fGen.writeWord(shuffle, out);
3415         fGen.writeWord(base, out);
3416         fGen.writeWord(value, out);
3417         for (int i = 0; i < fBaseType->columns(); i++) {
3418             // current offset into the virtual vector, defaults to pulling the unmodified
3419             // value from the left side
3420             int offset = i;
3421             // check to see if we are writing this component
3422             for (int j = 0; j < fComponents.size(); j++) {
3423                 if (fComponents[j] == i) {
3424                     // we're writing to this component, so adjust the offset to pull from
3425                     // the correct component of the right side instead of preserving the
3426                     // value from the left
3427                     offset = (int) (j + fBaseType->columns());
3428                     break;
3429                 }
3430             }
3431             fGen.writeWord(offset, out);
3432         }
3433         fGen.writeOpStore(fStorageClass, fVecPointer, shuffle, out);
3434     }
3435 
3436 private:
3437     SPIRVCodeGenerator& fGen;
3438     const SpvId fVecPointer;
3439     ComponentArray fComponents;
3440     const Type* fBaseType;
3441     const Type* fSwizzleType;
3442     const StorageClass fStorageClass;
3443 };
3444 
findUniformFieldIndex(const Variable & var) const3445 int SPIRVCodeGenerator::findUniformFieldIndex(const Variable& var) const {
3446     int* fieldIndex = fTopLevelUniformMap.find(&var);
3447     return fieldIndex ? *fieldIndex : -1;
3448 }
3449 
getLValue(const Expression & expr,OutputStream & out)3450 std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
3451                                                                           OutputStream& out) {
3452     const Type& type = expr.type();
3453     Precision precision = type.highPrecision() ? Precision::kDefault : Precision::kRelaxed;
3454     switch (expr.kind()) {
3455         case Expression::Kind::kVariableReference: {
3456             const Variable& var = *expr.as<VariableReference>().variable();
3457             int uniformIdx = this->findUniformFieldIndex(var);
3458             if (uniformIdx >= 0) {
3459                 // Access uniforms via an AccessChain into the uniform-buffer struct.
3460                 SpvId memberId = this->nextId(nullptr);
3461                 SpvId typeId = this->getPointerType(type, StorageClass::kUniform);
3462                 SpvId uniformIdxId = this->writeLiteral((double)uniformIdx, *fContext.fTypes.fInt);
3463                 this->writeInstruction(SpvOpAccessChain, typeId, memberId, fUniformBufferId,
3464                                        uniformIdxId, out);
3465                 return std::make_unique<PointerLValue>(
3466                         *this,
3467                         memberId,
3468                         /*isMemoryObjectPointer=*/true,
3469                         this->getType(type, kDefaultTypeLayout, this->memoryLayoutForVariable(var)),
3470                         precision,
3471                         StorageClass::kUniform);
3472             }
3473 #ifdef SKSL_EXT
3474             if (fGlobalConstVariableValueMap.find(&var) != fGlobalConstVariableValueMap.end()) {
3475                 SpvId id = this->nextId(&type);
3476                 fVariableMap[&var] = id;
3477                 SpvId typeId = this->getPointerType(type, StorageClass::kFunction);
3478                 this->writeInstruction(SpvOpVariable, typeId, id, SpvStorageClassFunction, fVariableBuffer);
3479                 this->writeInstruction(SpvOpName, id, var.name(), fNameBuffer);
3480                 this->writeInstruction(SpvOpStore, id, fGlobalConstVariableValueMap[&var], out);
3481             }
3482 #endif
3483             SpvId* entry = fVariableMap.find(&var);
3484             SkASSERTF(entry, "%s", expr.description().c_str());
3485 
3486             if (var.layout().fBuiltin == SK_SAMPLEMASKIN_BUILTIN ||
3487                 var.layout().fBuiltin == SK_SAMPLEMASK_BUILTIN) {
3488                 // Access sk_SampleMask and sk_SampleMaskIn via an array access, since Vulkan
3489                 // represents sample masks as an array of uints.
3490                 StorageClass storageClass =
3491                         get_storage_class_for_global_variable(var, StorageClass::kPrivate);
3492                 SkASSERT(storageClass != StorageClass::kPrivate);
3493                 SkASSERT(type.matches(*fContext.fTypes.fUInt));
3494 
3495                 SpvId accessId = this->nextId(nullptr);
3496                 SpvId typeId = this->getPointerType(type, storageClass);
3497                 SpvId indexId = this->writeLiteral(0.0, *fContext.fTypes.fInt);
3498                 this->writeInstruction(SpvOpAccessChain, typeId, accessId, *entry, indexId, out);
3499                 return std::make_unique<PointerLValue>(*this,
3500                                                        accessId,
3501                                                        /*isMemoryObjectPointer=*/true,
3502                                                        this->getType(type),
3503                                                        precision,
3504                                                        storageClass);
3505             }
3506             SpvId typeId = this->getType(type, var.layout(), this->memoryLayoutForVariable(var));
3507             return std::make_unique<PointerLValue>(*this, *entry,
3508                                                    /*isMemoryObjectPointer=*/true,
3509                                                    typeId, precision, this->getStorageClass(expr));
3510         }
3511         case Expression::Kind::kIndex: // fall through
3512         case Expression::Kind::kFieldAccess: {
3513             TArray<SpvId> chain = this->getAccessChain(expr, out);
3514             SpvId member = this->nextId(nullptr);
3515             StorageClass storageClass = this->getStorageClass(expr);
3516             this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
3517             this->writeWord(this->getPointerType(type, storageClass), out);
3518             this->writeWord(member, out);
3519 #ifdef SKSL_EXT
3520             bool needDecorate = false;
3521             for (SpvId idx : chain) {
3522                 this->writeWord(idx, out);
3523                 needDecorate |= fNonUniformSpvId.find(idx) != fNonUniformSpvId.end();
3524             }
3525             if (needDecorate) {
3526                 fNonUniformSpvId.insert(member);
3527                 this->writeInstruction(SpvOpDecorate, member, SpvDecorationNonUniform, fDecorationBuffer);
3528             }
3529 #else
3530             for (SpvId idx : chain) {
3531                 this->writeWord(idx, out);
3532             }
3533 #endif
3534             return std::make_unique<PointerLValue>(
3535                     *this,
3536                     member,
3537                     /*isMemoryObjectPointer=*/false,
3538                     this->getType(type,
3539                                   kDefaultTypeLayout,
3540                                   this->memoryLayoutForStorageClass(storageClass)),
3541                     precision,
3542                     storageClass);
3543         }
3544         case Expression::Kind::kSwizzle: {
3545             const Swizzle& swizzle = expr.as<Swizzle>();
3546             std::unique_ptr<LValue> lvalue = this->getLValue(*swizzle.base(), out);
3547             if (lvalue->applySwizzle(swizzle.components(), type)) {
3548                 return lvalue;
3549             }
3550             SpvId base = lvalue->getPointer();
3551             if (base == NA) {
3552                 fContext.fErrors->error(swizzle.fPosition,
3553                         "unable to retrieve lvalue from swizzle");
3554             }
3555             StorageClass storageClass = this->getStorageClass(*swizzle.base());
3556             if (swizzle.components().size() == 1) {
3557                 SpvId member = this->nextId(nullptr);
3558                 SpvId typeId = this->getPointerType(type, storageClass);
3559                 SpvId indexId = this->writeLiteral(swizzle.components()[0], *fContext.fTypes.fInt);
3560                 this->writeInstruction(SpvOpAccessChain, typeId, member, base, indexId, out);
3561                 return std::make_unique<PointerLValue>(*this, member,
3562                                                        /*isMemoryObjectPointer=*/false,
3563                                                        this->getType(type),
3564                                                        precision, storageClass);
3565             } else {
3566                 return std::make_unique<SwizzleLValue>(*this, base, swizzle.components(),
3567                                                        swizzle.base()->type(), type, storageClass);
3568             }
3569         }
3570         default: {
3571             // expr isn't actually an lvalue, create a placeholder variable for it. This case
3572             // happens due to the need to store values in temporary variables during function
3573             // calls (see comments in getFunctionParameterType); erroneous uses of rvalues as
3574             // lvalues should have been caught before code generation.
3575             //
3576             // This is with the exception of opaque handle types (textures/samplers) which are
3577             // always defined as UniformConstant pointers and don't need to be explicitly stored
3578             // into a temporary (which is handled explicitly in writeFunctionCallArgument).
3579             SpvId result = this->nextId(nullptr);
3580             SpvId pointerType = this->getPointerType(type, StorageClass::kFunction);
3581             this->writeInstruction(SpvOpVariable, pointerType, result, SpvStorageClassFunction,
3582                                    fVariableBuffer);
3583             this->writeOpStore(StorageClass::kFunction, result, this->writeExpression(expr, out),
3584                                out);
3585             return std::make_unique<PointerLValue>(*this, result, /*isMemoryObjectPointer=*/true,
3586                                                    this->getType(type), precision,
3587                                                    StorageClass::kFunction);
3588         }
3589     }
3590 }
3591 
identifier(std::string_view name)3592 std::unique_ptr<Expression> SPIRVCodeGenerator::identifier(std::string_view name) {
3593     std::unique_ptr<Expression> expr =
3594             fProgram.fSymbols->instantiateSymbolRef(fContext, name, Position());
3595     return expr ? std::move(expr)
3596                 : Poison::Make(Position(), fContext);
3597 }
3598 
writeVariableReference(const VariableReference & ref,OutputStream & out)3599 SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) {
3600     const Variable* variable = ref.variable();
3601     switch (variable->layout().fBuiltin) {
3602         case DEVICE_FRAGCOORDS_BUILTIN: {
3603             // Down below, we rewrite raw references to sk_FragCoord with expressions that reference
3604             // DEVICE_FRAGCOORDS_BUILTIN. This is a fake variable that means we need to directly
3605             // access the fragcoord; do so now.
3606             return this->getLValue(*this->identifier("sk_FragCoord"), out)->load(out);
3607         }
3608         case DEVICE_CLOCKWISE_BUILTIN: {
3609             // Down below, we rewrite raw references to sk_Clockwise with expressions that reference
3610             // DEVICE_CLOCKWISE_BUILTIN. This is a fake variable that means we need to directly
3611             // access front facing; do so now.
3612             return this->getLValue(*this->identifier("sk_Clockwise"), out)->load(out);
3613         }
3614         case SK_SECONDARYFRAGCOLOR_BUILTIN: {
3615             if (fCaps.fDualSourceBlendingSupport) {
3616                 return this->getLValue(*this->identifier("sk_SecondaryFragColor"), out)->load(out);
3617             } else {
3618                 fContext.fErrors->error(ref.position(), "'sk_SecondaryFragColor' not supported");
3619                 return NA;
3620             }
3621         }
3622         case SK_FRAGCOORD_BUILTIN: {
3623             if (fProgram.fConfig->fSettings.fForceNoRTFlip) {
3624                 return this->getLValue(*this->identifier("sk_FragCoord"), out)->load(out);
3625             }
3626 
3627             // Handle inserting use of uniform to flip y when referencing sk_FragCoord.
3628             this->addRTFlipUniform(ref.fPosition);
3629             // Use sk_RTAdjust to compute the flipped coordinate
3630             // Use a uniform to flip the Y coordinate. The new expression will be written in
3631             // terms of $device_FragCoords, which is a fake variable that means "access the
3632             // underlying fragcoords directly without flipping it".
3633             static constexpr char DEVICE_COORDS_NAME[] = "$device_FragCoords";
3634             if (!fProgram.fSymbols->find(DEVICE_COORDS_NAME)) {
3635                 AutoAttachPoolToThread attach(fProgram.fPool.get());
3636                 Layout layout;
3637                 layout.fBuiltin = DEVICE_FRAGCOORDS_BUILTIN;
3638                 auto coordsVar = Variable::Make(/*pos=*/Position(),
3639                                                 /*modifiersPosition=*/Position(),
3640                                                 layout,
3641                                                 ModifierFlag::kNone,
3642                                                 fContext.fTypes.fFloat4.get(),
3643                                                 DEVICE_COORDS_NAME,
3644                                                 /*mangledName=*/"",
3645                                                 /*builtin=*/true,
3646                                                 Variable::Storage::kGlobal);
3647                 fProgram.fSymbols->add(fContext, std::move(coordsVar));
3648             }
3649             std::unique_ptr<Expression> deviceCoord = this->identifier(DEVICE_COORDS_NAME);
3650             std::unique_ptr<Expression> rtFlip = this->identifier(SKSL_RTFLIP_NAME);
3651             SpvId rtFlipX = this->writeSwizzle(*rtFlip, {SwizzleComponent::X}, out);
3652             SpvId rtFlipY = this->writeSwizzle(*rtFlip, {SwizzleComponent::Y}, out);
3653             SpvId deviceCoordX  = this->writeSwizzle(*deviceCoord, {SwizzleComponent::X}, out);
3654             SpvId deviceCoordY  = this->writeSwizzle(*deviceCoord, {SwizzleComponent::Y}, out);
3655             SpvId deviceCoordZW = this->writeSwizzle(*deviceCoord, {SwizzleComponent::Z,
3656                                                                     SwizzleComponent::W}, out);
3657             // Compute `flippedY = u_RTFlip.y * $device_FragCoords.y`.
3658             SpvId flippedY = this->writeBinaryExpression(
3659                                      *fContext.fTypes.fFloat, rtFlipY, OperatorKind::STAR,
3660                                      *fContext.fTypes.fFloat, deviceCoordY,
3661                                      *fContext.fTypes.fFloat, out);
3662 
3663             // Compute `flippedY = u_RTFlip.x + flippedY`.
3664             flippedY = this->writeBinaryExpression(
3665                                *fContext.fTypes.fFloat, rtFlipX, OperatorKind::PLUS,
3666                                *fContext.fTypes.fFloat, flippedY,
3667                                *fContext.fTypes.fFloat, out);
3668 
3669             // Return `float4(deviceCoord.x, flippedY, deviceCoord.zw)`.
3670             return this->writeOpCompositeConstruct(*fContext.fTypes.fFloat4,
3671                                                    {deviceCoordX, flippedY, deviceCoordZW},
3672                                                    out);
3673         }
3674         case SK_CLOCKWISE_BUILTIN: {
3675             if (fProgram.fConfig->fSettings.fForceNoRTFlip) {
3676                 return this->getLValue(*this->identifier("sk_Clockwise"), out)->load(out);
3677             }
3678 
3679             // Apply RTFlip to sk_Clockwise.
3680             this->addRTFlipUniform(ref.fPosition);
3681             // Use a uniform to flip the Y coordinate. The new expression will be written in
3682             // terms of $device_Clockwise, which is a fake variable that means "access the
3683             // underlying FrontFacing directly".
3684             static constexpr char DEVICE_CLOCKWISE_NAME[] = "$device_Clockwise";
3685             if (!fProgram.fSymbols->find(DEVICE_CLOCKWISE_NAME)) {
3686                 AutoAttachPoolToThread attach(fProgram.fPool.get());
3687                 Layout layout;
3688                 layout.fBuiltin = DEVICE_CLOCKWISE_BUILTIN;
3689                 auto clockwiseVar = Variable::Make(/*pos=*/Position(),
3690                                                    /*modifiersPosition=*/Position(),
3691                                                    layout,
3692                                                    ModifierFlag::kNone,
3693                                                    fContext.fTypes.fBool.get(),
3694                                                    DEVICE_CLOCKWISE_NAME,
3695                                                    /*mangledName=*/"",
3696                                                    /*builtin=*/true,
3697                                                    Variable::Storage::kGlobal);
3698                 fProgram.fSymbols->add(fContext, std::move(clockwiseVar));
3699             }
3700             // FrontFacing in Vulkan is defined in terms of a top-down render target. In Skia,
3701             // we use the default convention of "counter-clockwise face is front".
3702 
3703             // Compute `positiveRTFlip = (rtFlip.y > 0)`.
3704             std::unique_ptr<Expression> rtFlip = this->identifier(SKSL_RTFLIP_NAME);
3705             SpvId rtFlipY = this->writeSwizzle(*rtFlip, {SwizzleComponent::Y}, out);
3706             SpvId zero = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
3707             SpvId positiveRTFlip = this->writeBinaryExpression(
3708                                            *fContext.fTypes.fFloat, rtFlipY, OperatorKind::GT,
3709                                            *fContext.fTypes.fFloat, zero,
3710                                            *fContext.fTypes.fBool, out);
3711 
3712             // Compute `positiveRTFlip ^^ $device_Clockwise`.
3713             std::unique_ptr<Expression> deviceClockwise = this->identifier(DEVICE_CLOCKWISE_NAME);
3714             SpvId deviceClockwiseID = this->writeExpression(*deviceClockwise, out);
3715             return this->writeBinaryExpression(
3716                            *fContext.fTypes.fBool, positiveRTFlip, OperatorKind::LOGICALXOR,
3717                            *fContext.fTypes.fBool, deviceClockwiseID,
3718                            *fContext.fTypes.fBool, out);
3719         }
3720         default: {
3721             // Constant-propagate variables that have a known compile-time value.
3722             if (const Expression* expr = ConstantFolder::GetConstantValueOrNull(ref)) {
3723                 return this->writeExpression(*expr, out);
3724             }
3725 
3726             // A reference to a sampler variable at global scope with synthesized texture/sampler
3727             // backing should construct a function-scope combined image-sampler from the synthesized
3728             // constituents. This is the case in which a sample intrinsic was invoked.
3729             //
3730             // Variable references to opaque handles (texture/sampler) that appear as the argument
3731             // of a user-defined function call are explicitly handled in writeFunctionCallArgument.
3732             if (fUseTextureSamplerPairs && variable->type().isSampler()) {
3733                 if (const auto* p = fSynthesizedSamplerMap.find(variable)) {
3734                     SpvId* imgPtr = fVariableMap.find((*p)->fTexture.get());
3735                     SpvId* samplerPtr = fVariableMap.find((*p)->fSampler.get());
3736                     SkASSERT(imgPtr);
3737                     SkASSERT(samplerPtr);
3738 
3739                     SpvId img = this->writeOpLoad(this->getType((*p)->fTexture->type()),
3740                                                   Precision::kDefault, *imgPtr, out);
3741                     SpvId sampler = this->writeOpLoad(this->getType((*p)->fSampler->type()),
3742                                                       Precision::kDefault,
3743                                                       *samplerPtr,
3744                                                       out);
3745                     SpvId result = this->nextId(nullptr);
3746                     this->writeInstruction(SpvOpSampledImage,
3747                                            this->getType(variable->type()),
3748                                            result,
3749                                            img,
3750                                            sampler,
3751                                            out);
3752                     return result;
3753                 }
3754                 SkDEBUGFAIL("sampler missing from fSynthesizedSamplerMap");
3755             }
3756 #ifdef SKSL_EXT
3757             const Variable* var = ref.as<VariableReference>().variable();
3758             if (var && (var->layout().fFlags & SkSL::LayoutFlag::kConstantId)) {
3759                 return fVariableMap[var];
3760             }
3761 #endif
3762             return this->getLValue(ref, out)->load(out);
3763         }
3764     }
3765 }
3766 
writeIndexExpression(const IndexExpression & expr,OutputStream & out)3767 SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) {
3768     if (expr.base()->type().isVector()) {
3769         SpvId base = this->writeExpression(*expr.base(), out);
3770         SpvId index = this->writeExpression(*expr.index(), out);
3771         SpvId result = this->nextId(nullptr);
3772         this->writeInstruction(SpvOpVectorExtractDynamic, this->getType(expr.type()), result, base,
3773                                index, out);
3774         return result;
3775     }
3776     return getLValue(expr, out)->load(out);
3777 }
3778 
writeFieldAccess(const FieldAccess & f,OutputStream & out)3779 SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) {
3780     return getLValue(f, out)->load(out);
3781 }
3782 
writeSwizzle(const Expression & baseExpr,const ComponentArray & components,OutputStream & out)3783 SpvId SPIRVCodeGenerator::writeSwizzle(const Expression& baseExpr,
3784                                        const ComponentArray& components,
3785                                        OutputStream& out) {
3786     size_t count = components.size();
3787     const Type& type = baseExpr.type().componentType().toCompound(fContext, count, /*rows=*/1);
3788     SpvId base = this->writeExpression(baseExpr, out);
3789     if (count == 1) {
3790         return this->writeOpCompositeExtract(type, base, components[0], out);
3791     }
3792 
3793     SpvId result = this->nextId(&type);
3794     this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
3795     this->writeWord(this->getType(type), out);
3796     this->writeWord(result, out);
3797     this->writeWord(base, out);
3798     this->writeWord(base, out);
3799     for (int component : components) {
3800         this->writeWord(component, out);
3801     }
3802     return result;
3803 }
3804 
writeSwizzle(const Swizzle & swizzle,OutputStream & out)3805 SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
3806     return this->writeSwizzle(*swizzle.base(), swizzle.components(), out);
3807 }
3808 
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,bool writeComponentwiseIfMatrix,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)3809 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType, const Type& operandType,
3810                                                SpvId lhs, SpvId rhs,
3811                                                bool writeComponentwiseIfMatrix,
3812                                                SpvOp_ ifFloat, SpvOp_ ifInt, SpvOp_ ifUInt,
3813                                                SpvOp_ ifBool, OutputStream& out) {
3814     SpvOp_ op = pick_by_type(operandType, ifFloat, ifInt, ifUInt, ifBool);
3815     if (op == SpvOpUndef) {
3816         fContext.fErrors->error(operandType.fPosition,
3817                 "unsupported operand for binary expression: " + operandType.description());
3818         return NA;
3819     }
3820     if (writeComponentwiseIfMatrix && operandType.isMatrix()) {
3821         return this->writeComponentwiseMatrixBinary(resultType, lhs, rhs, op, out);
3822     }
3823     SpvId result = this->nextId(&resultType);
3824     this->writeInstruction(op, this->getType(resultType), result, lhs, rhs, out);
3825     return result;
3826 }
3827 
writeBinaryOperationComponentwiseIfMatrix(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)3828 SpvId SPIRVCodeGenerator::writeBinaryOperationComponentwiseIfMatrix(const Type& resultType,
3829                                                                     const Type& operandType,
3830                                                                     SpvId lhs, SpvId rhs,
3831                                                                     SpvOp_ ifFloat, SpvOp_ ifInt,
3832                                                                     SpvOp_ ifUInt, SpvOp_ ifBool,
3833                                                                     OutputStream& out) {
3834     return this->writeBinaryOperation(resultType, operandType, lhs, rhs,
3835                                       /*writeComponentwiseIfMatrix=*/true,
3836                                       ifFloat, ifInt, ifUInt, ifBool, out);
3837 }
3838 
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)3839 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType, const Type& operandType,
3840                                                SpvId lhs, SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
3841                                                SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) {
3842     return this->writeBinaryOperation(resultType, operandType, lhs, rhs,
3843                                       /*writeComponentwiseIfMatrix=*/false,
3844                                       ifFloat, ifInt, ifUInt, ifBool, out);
3845 }
3846 
foldToBool(SpvId id,const Type & operandType,SpvOp op,OutputStream & out)3847 SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SpvOp op,
3848                                      OutputStream& out) {
3849     if (operandType.isVector()) {
3850         SpvId result = this->nextId(nullptr);
3851         this->writeInstruction(op, this->getType(*fContext.fTypes.fBool), result, id, out);
3852         return result;
3853     }
3854     return id;
3855 }
3856 
writeMatrixComparison(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ floatOperator,SpvOp_ intOperator,SpvOp_ vectorMergeOperator,SpvOp_ mergeOperator,OutputStream & out)3857 SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
3858                                                 SpvOp_ floatOperator, SpvOp_ intOperator,
3859                                                 SpvOp_ vectorMergeOperator, SpvOp_ mergeOperator,
3860                                                 OutputStream& out) {
3861     SpvOp_ compareOp = is_float(operandType) ? floatOperator : intOperator;
3862     SkASSERT(operandType.isMatrix());
3863     const Type& columnType = operandType.componentType().toCompound(fContext,
3864                                                                     operandType.rows(),
3865                                                                     1);
3866     SpvId bvecType = this->getType(fContext.fTypes.fBool->toCompound(fContext,
3867                                                                      operandType.rows(),
3868                                                                      1));
3869     SpvId boolType = this->getType(*fContext.fTypes.fBool);
3870     SpvId result = 0;
3871     for (int i = 0; i < operandType.columns(); i++) {
3872         SpvId columnL = this->writeOpCompositeExtract(columnType, lhs, i, out);
3873         SpvId columnR = this->writeOpCompositeExtract(columnType, rhs, i, out);
3874         SpvId compare = this->nextId(&operandType);
3875         this->writeInstruction(compareOp, bvecType, compare, columnL, columnR, out);
3876         SpvId merge = this->nextId(nullptr);
3877         this->writeInstruction(vectorMergeOperator, boolType, merge, compare, out);
3878         if (result != 0) {
3879             SpvId next = this->nextId(nullptr);
3880             this->writeInstruction(mergeOperator, boolType, next, result, merge, out);
3881             result = next;
3882         } else {
3883             result = merge;
3884         }
3885     }
3886     return result;
3887 }
3888 
writeComponentwiseMatrixUnary(const Type & operandType,SpvId operand,SpvOp_ op,OutputStream & out)3889 SpvId SPIRVCodeGenerator::writeComponentwiseMatrixUnary(const Type& operandType,
3890                                                         SpvId operand,
3891                                                         SpvOp_ op,
3892                                                         OutputStream& out) {
3893     SkASSERT(operandType.isMatrix());
3894     const Type& columnType = operandType.columnType(fContext);
3895     SpvId columnTypeId = this->getType(columnType);
3896 
3897     STArray<4, SpvId> columns;
3898     for (int i = 0; i < operandType.columns(); i++) {
3899         SpvId srcColumn = this->writeOpCompositeExtract(columnType, operand, i, out);
3900         SpvId dstColumn = this->nextId(&operandType);
3901         this->writeInstruction(op, columnTypeId, dstColumn, srcColumn, out);
3902         columns.push_back(dstColumn);
3903     }
3904 
3905     return this->writeOpCompositeConstruct(operandType, columns, out);
3906 }
3907 
writeComponentwiseMatrixBinary(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ op,OutputStream & out)3908 SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs,
3909                                                          SpvId rhs, SpvOp_ op, OutputStream& out) {
3910     SkASSERT(operandType.isMatrix());
3911     const Type& columnType = operandType.columnType(fContext);
3912     SpvId columnTypeId = this->getType(columnType);
3913 
3914     STArray<4, SpvId> columns;
3915     for (int i = 0; i < operandType.columns(); i++) {
3916         SpvId columnL = this->writeOpCompositeExtract(columnType, lhs, i, out);
3917         SpvId columnR = this->writeOpCompositeExtract(columnType, rhs, i, out);
3918         columns.push_back(this->nextId(&operandType));
3919         this->writeInstruction(op, columnTypeId, columns[i], columnL, columnR, out);
3920     }
3921     return this->writeOpCompositeConstruct(operandType, columns, out);
3922 }
3923 
writeReciprocal(const Type & type,SpvId value,OutputStream & out)3924 SpvId SPIRVCodeGenerator::writeReciprocal(const Type& type, SpvId value, OutputStream& out) {
3925     SkASSERT(type.isFloat());
3926     SpvId one = this->writeLiteral(1.0, type);
3927     SpvId reciprocal = this->nextId(&type);
3928     this->writeInstruction(SpvOpFDiv, this->getType(type), reciprocal, one, value, out);
3929     return reciprocal;
3930 }
3931 
splat(const Type & type,SpvId id,OutputStream & out)3932 SpvId SPIRVCodeGenerator::splat(const Type& type, SpvId id, OutputStream& out) {
3933     if (type.isScalar()) {
3934         // Scalars require no additional work; we can return the passed-in ID as is.
3935     } else {
3936         SkASSERT(type.isVector() || type.isMatrix());
3937         bool isMatrix = type.isMatrix();
3938 
3939         // Splat the input scalar across a vector.
3940         int vectorSize = (isMatrix ? type.rows() : type.columns());
3941         const Type& vectorType = type.componentType().toCompound(fContext, vectorSize, /*rows=*/1);
3942 
3943         STArray<4, SpvId> values;
3944         values.push_back_n(/*n=*/vectorSize, /*t=*/id);
3945         id = this->writeOpCompositeConstruct(vectorType, values, out);
3946 
3947         if (isMatrix) {
3948             // Splat the newly-synthesized vector into a matrix.
3949             STArray<4, SpvId> matArguments;
3950             matArguments.push_back_n(/*n=*/type.columns(), /*t=*/id);
3951             id = this->writeOpCompositeConstruct(type, matArguments, out);
3952         }
3953     }
3954 
3955     return id;
3956 }
3957 
types_match(const Type & a,const Type & b)3958 static bool types_match(const Type& a, const Type& b) {
3959     if (a.matches(b)) {
3960         return true;
3961     }
3962     return (a.typeKind() == b.typeKind()) &&
3963            (a.isScalar() || a.isVector() || a.isMatrix()) &&
3964            (a.columns() == b.columns() && a.rows() == b.rows()) &&
3965            a.componentType().numberKind() == b.componentType().numberKind();
3966 }
3967 
writeDecomposedMatrixVectorMultiply(const Type & leftType,SpvId lhs,const Type & rightType,SpvId rhs,const Type & resultType,OutputStream & out)3968 SpvId SPIRVCodeGenerator::writeDecomposedMatrixVectorMultiply(const Type& leftType,
3969                                                               SpvId lhs,
3970                                                               const Type& rightType,
3971                                                               SpvId rhs,
3972                                                               const Type& resultType,
3973                                                               OutputStream& out) {
3974     SpvId sum = NA;
3975     const Type& columnType = leftType.columnType(fContext);
3976     const Type& scalarType = rightType.componentType();
3977 
3978     for (int n = 0; n < leftType.rows(); ++n) {
3979         // Extract mat[N] from the matrix.
3980         SpvId matN = this->writeOpCompositeExtract(columnType, lhs, n, out);
3981 
3982         // Extract vec[N] from the vector.
3983         SpvId vecN = this->writeOpCompositeExtract(scalarType, rhs, n, out);
3984 
3985         // Multiply them together.
3986         SpvId product = this->writeBinaryExpression(columnType, matN, OperatorKind::STAR,
3987                                                     scalarType, vecN,
3988                                                     columnType, out);
3989 
3990         // Sum all the components together.
3991         if (sum == NA) {
3992             sum = product;
3993         } else {
3994             sum = this->writeBinaryExpression(columnType, sum, OperatorKind::PLUS,
3995                                               columnType, product,
3996                                               columnType, out);
3997         }
3998     }
3999 
4000     return sum;
4001 }
4002 
writeBinaryExpression(const Type & leftType,SpvId lhs,Operator op,const Type & rightType,SpvId rhs,const Type & resultType,OutputStream & out)4003 SpvId SPIRVCodeGenerator::writeBinaryExpression(const Type& leftType, SpvId lhs, Operator op,
4004                                                 const Type& rightType, SpvId rhs,
4005                                                 const Type& resultType, OutputStream& out) {
4006     // The comma operator ignores the type of the left-hand side entirely.
4007     if (op.kind() == Operator::Kind::COMMA) {
4008         return rhs;
4009     }
4010     // overall type we are operating on: float2, int, uint4...
4011     const Type* operandType;
4012     if (types_match(leftType, rightType)) {
4013         operandType = &leftType;
4014     } else {
4015         // IR allows mismatched types in expressions (e.g. float2 * float), but they need special
4016         // handling in SPIR-V
4017         if (leftType.isVector() && rightType.isNumber()) {
4018             if (resultType.componentType().isFloat()) {
4019                 switch (op.kind()) {
4020                     case Operator::Kind::SLASH: {
4021                         rhs = this->writeReciprocal(rightType, rhs, out);
4022                         [[fallthrough]];
4023                     }
4024                     case Operator::Kind::STAR: {
4025                         SpvId result = this->nextId(&resultType);
4026                         this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
4027                                                result, lhs, rhs, out);
4028                         return result;
4029                     }
4030                     default:
4031                         break;
4032                 }
4033             }
4034             // Vectorize the right-hand side.
4035             STArray<4, SpvId> arguments;
4036             arguments.push_back_n(/*n=*/leftType.columns(), /*t=*/rhs);
4037             rhs = this->writeOpCompositeConstruct(leftType, arguments, out);
4038             operandType = &leftType;
4039         } else if (rightType.isVector() && leftType.isNumber()) {
4040             if (resultType.componentType().isFloat()) {
4041                 if (op.kind() == Operator::Kind::STAR) {
4042                     SpvId result = this->nextId(&resultType);
4043                     this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
4044                                            result, rhs, lhs, out);
4045                     return result;
4046                 }
4047             }
4048             // Vectorize the left-hand side.
4049             STArray<4, SpvId> arguments;
4050             arguments.push_back_n(/*n=*/rightType.columns(), /*t=*/lhs);
4051             lhs = this->writeOpCompositeConstruct(rightType, arguments, out);
4052             operandType = &rightType;
4053         } else if (leftType.isMatrix()) {
4054             if (op.kind() == Operator::Kind::STAR) {
4055                 // When the rewriteMatrixVectorMultiply bit is set, we rewrite medium-precision
4056                 // matrix * vector multiplication as (mat[0]*vec[0] + ... + mat[N]*vec[N]).
4057                 if (fCaps.fRewriteMatrixVectorMultiply &&
4058                     rightType.isVector() &&
4059                     !resultType.highPrecision()) {
4060                     return this->writeDecomposedMatrixVectorMultiply(leftType, lhs, rightType, rhs,
4061                                                                      resultType, out);
4062                 }
4063 
4064                 // Matrix-times-vector and matrix-times-scalar have dedicated ops in SPIR-V.
4065                 SpvOp_ spvop;
4066                 if (rightType.isMatrix()) {
4067                     spvop = SpvOpMatrixTimesMatrix;
4068                 } else if (rightType.isVector()) {
4069                     spvop = SpvOpMatrixTimesVector;
4070                 } else {
4071                     SkASSERT(rightType.isScalar());
4072                     spvop = SpvOpMatrixTimesScalar;
4073                 }
4074                 SpvId result = this->nextId(&resultType);
4075                 this->writeInstruction(spvop, this->getType(resultType), result, lhs, rhs, out);
4076                 return result;
4077             } else {
4078                 // Matrix-op-vector is not supported in GLSL/SkSL for non-multiplication ops; we
4079                 // expect to have a scalar here.
4080                 SkASSERT(rightType.isScalar());
4081 
4082                 // Splat rhs across an entire matrix so we can reuse the matrix-op-matrix path.
4083                 SpvId rhsMatrix = this->splat(leftType, rhs, out);
4084 
4085                 // Perform this operation as matrix-op-matrix.
4086                 return this->writeBinaryExpression(leftType, lhs, op, leftType, rhsMatrix,
4087                                                    resultType, out);
4088             }
4089         } else if (rightType.isMatrix()) {
4090             if (op.kind() == Operator::Kind::STAR) {
4091                 // Matrix-times-vector and matrix-times-scalar have dedicated ops in SPIR-V.
4092                 SpvId result = this->nextId(&resultType);
4093                 if (leftType.isVector()) {
4094                     this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(resultType),
4095                                            result, lhs, rhs, out);
4096                 } else {
4097                     SkASSERT(leftType.isScalar());
4098                     this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(resultType),
4099                                            result, rhs, lhs, out);
4100                 }
4101                 return result;
4102             } else {
4103                 // Vector-op-matrix is not supported in GLSL/SkSL for non-multiplication ops; we
4104                 // expect to have a scalar here.
4105                 SkASSERT(leftType.isScalar());
4106 
4107                 // Splat lhs across an entire matrix so we can reuse the matrix-op-matrix path.
4108                 SpvId lhsMatrix = this->splat(rightType, lhs, out);
4109 
4110                 // Perform this operation as matrix-op-matrix.
4111                 return this->writeBinaryExpression(rightType, lhsMatrix, op, rightType, rhs,
4112                                                    resultType, out);
4113             }
4114         } else {
4115             fContext.fErrors->error(leftType.fPosition, "unsupported mixed-type expression");
4116             return NA;
4117         }
4118     }
4119 
4120     switch (op.kind()) {
4121         case Operator::Kind::EQEQ: {
4122             if (operandType->isMatrix()) {
4123                 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
4124                                                    SpvOpIEqual, SpvOpAll, SpvOpLogicalAnd, out);
4125             }
4126             if (operandType->isStruct()) {
4127                 return this->writeStructComparison(*operandType, lhs, op, rhs, out);
4128             }
4129             if (operandType->isArray()) {
4130                 return this->writeArrayComparison(*operandType, lhs, op, rhs, out);
4131             }
4132             SkASSERT(resultType.isBoolean());
4133             const Type* tmpType;
4134             if (operandType->isVector()) {
4135                 tmpType = &fContext.fTypes.fBool->toCompound(fContext,
4136                                                              operandType->columns(),
4137                                                              operandType->rows());
4138             } else {
4139                 tmpType = &resultType;
4140             }
4141             if (lhs == rhs) {
4142                 // This ignores the effects of NaN.
4143                 return this->writeOpConstantTrue(*fContext.fTypes.fBool);
4144             }
4145             return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
4146                                                                SpvOpFOrdEqual, SpvOpIEqual,
4147                                                                SpvOpIEqual, SpvOpLogicalEqual, out),
4148                                     *operandType, SpvOpAll, out);
4149         }
4150         case Operator::Kind::NEQ:
4151             if (operandType->isMatrix()) {
4152                 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFUnordNotEqual,
4153                                                    SpvOpINotEqual, SpvOpAny, SpvOpLogicalOr, out);
4154             }
4155             if (operandType->isStruct()) {
4156                 return this->writeStructComparison(*operandType, lhs, op, rhs, out);
4157             }
4158             if (operandType->isArray()) {
4159                 return this->writeArrayComparison(*operandType, lhs, op, rhs, out);
4160             }
4161             [[fallthrough]];
4162         case Operator::Kind::LOGICALXOR:
4163             SkASSERT(resultType.isBoolean());
4164             const Type* tmpType;
4165             if (operandType->isVector()) {
4166                 tmpType = &fContext.fTypes.fBool->toCompound(fContext,
4167                                                              operandType->columns(),
4168                                                              operandType->rows());
4169             } else {
4170                 tmpType = &resultType;
4171             }
4172             if (lhs == rhs) {
4173                 // This ignores the effects of NaN.
4174                 return this->writeOpConstantFalse(*fContext.fTypes.fBool);
4175             }
4176             return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
4177                                                                SpvOpFUnordNotEqual, SpvOpINotEqual,
4178                                                                SpvOpINotEqual, SpvOpLogicalNotEqual,
4179                                                                out),
4180                                     *operandType, SpvOpAny, out);
4181         case Operator::Kind::GT:
4182             SkASSERT(resultType.isBoolean());
4183             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
4184                                               SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
4185                                               SpvOpUGreaterThan, SpvOpUndef, out);
4186         case Operator::Kind::LT:
4187             SkASSERT(resultType.isBoolean());
4188             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
4189                                               SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
4190         case Operator::Kind::GTEQ:
4191             SkASSERT(resultType.isBoolean());
4192             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
4193                                               SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
4194                                               SpvOpUGreaterThanEqual, SpvOpUndef, out);
4195         case Operator::Kind::LTEQ:
4196             SkASSERT(resultType.isBoolean());
4197             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
4198                                               SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
4199                                               SpvOpULessThanEqual, SpvOpUndef, out);
4200         case Operator::Kind::PLUS:
4201             return this->writeBinaryOperationComponentwiseIfMatrix(resultType, *operandType,
4202                                                                    lhs, rhs, SpvOpFAdd, SpvOpIAdd,
4203                                                                    SpvOpIAdd, SpvOpUndef, out);
4204         case Operator::Kind::MINUS:
4205             return this->writeBinaryOperationComponentwiseIfMatrix(resultType, *operandType,
4206                                                                    lhs, rhs, SpvOpFSub, SpvOpISub,
4207                                                                    SpvOpISub, SpvOpUndef, out);
4208         case Operator::Kind::STAR:
4209             if (leftType.isMatrix() && rightType.isMatrix()) {
4210                 // matrix multiply
4211                 SpvId result = this->nextId(&resultType);
4212                 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
4213                                        lhs, rhs, out);
4214                 return result;
4215             }
4216             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
4217                                               SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
4218         case Operator::Kind::SLASH:
4219             return this->writeBinaryOperationComponentwiseIfMatrix(resultType, *operandType,
4220                                                                    lhs, rhs, SpvOpFDiv, SpvOpSDiv,
4221                                                                    SpvOpUDiv, SpvOpUndef, out);
4222         case Operator::Kind::PERCENT:
4223             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
4224                                               SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
4225         case Operator::Kind::SHL:
4226             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
4227                                               SpvOpShiftLeftLogical, SpvOpShiftLeftLogical,
4228                                               SpvOpUndef, out);
4229         case Operator::Kind::SHR:
4230             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
4231                                               SpvOpShiftRightArithmetic, SpvOpShiftRightLogical,
4232                                               SpvOpUndef, out);
4233         case Operator::Kind::BITWISEAND:
4234             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
4235                                               SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out);
4236         case Operator::Kind::BITWISEOR:
4237             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
4238                                               SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out);
4239         case Operator::Kind::BITWISEXOR:
4240             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
4241                                               SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
4242         default:
4243             fContext.fErrors->error(Position(), "unsupported token");
4244             return NA;
4245     }
4246 }
4247 
writeArrayComparison(const Type & arrayType,SpvId lhs,Operator op,SpvId rhs,OutputStream & out)4248 SpvId SPIRVCodeGenerator::writeArrayComparison(const Type& arrayType, SpvId lhs, Operator op,
4249                                                SpvId rhs, OutputStream& out) {
4250     // The inputs must be arrays, and the op must be == or !=.
4251     SkASSERT(op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ);
4252     SkASSERT(arrayType.isArray());
4253     const Type& componentType = arrayType.componentType();
4254     const int arraySize = arrayType.columns();
4255     SkASSERT(arraySize > 0);
4256 
4257     // Synthesize equality checks for each item in the array.
4258     const Type& boolType = *fContext.fTypes.fBool;
4259     SpvId allComparisons = NA;
4260     for (int index = 0; index < arraySize; ++index) {
4261         // Get the left and right item in the array.
4262         SpvId itemL = this->writeOpCompositeExtract(componentType, lhs, index, out);
4263         SpvId itemR = this->writeOpCompositeExtract(componentType, rhs, index, out);
4264         // Use `writeBinaryExpression` with the requested == or != operator on these items.
4265         SpvId comparison = this->writeBinaryExpression(componentType, itemL, op,
4266                                                        componentType, itemR, boolType, out);
4267         // Merge this comparison result with all the other comparisons we've done.
4268         allComparisons = this->mergeComparisons(comparison, allComparisons, op, out);
4269     }
4270     return allComparisons;
4271 }
4272 
writeStructComparison(const Type & structType,SpvId lhs,Operator op,SpvId rhs,OutputStream & out)4273 SpvId SPIRVCodeGenerator::writeStructComparison(const Type& structType, SpvId lhs, Operator op,
4274                                                 SpvId rhs, OutputStream& out) {
4275     // The inputs must be structs containing fields, and the op must be == or !=.
4276     SkASSERT(op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ);
4277     SkASSERT(structType.isStruct());
4278     SkSpan<const Field> fields = structType.fields();
4279     SkASSERT(!fields.empty());
4280 
4281     // Synthesize equality checks for each field in the struct.
4282     const Type& boolType = *fContext.fTypes.fBool;
4283     SpvId allComparisons = NA;
4284     for (int index = 0; index < (int)fields.size(); ++index) {
4285         // Get the left and right versions of this field.
4286         const Type& fieldType = *fields[index].fType;
4287 
4288         SpvId fieldL = this->writeOpCompositeExtract(fieldType, lhs, index, out);
4289         SpvId fieldR = this->writeOpCompositeExtract(fieldType, rhs, index, out);
4290         // Use `writeBinaryExpression` with the requested == or != operator on these fields.
4291         SpvId comparison = this->writeBinaryExpression(fieldType, fieldL, op, fieldType, fieldR,
4292                                                        boolType, out);
4293         // Merge this comparison result with all the other comparisons we've done.
4294         allComparisons = this->mergeComparisons(comparison, allComparisons, op, out);
4295     }
4296     return allComparisons;
4297 }
4298 
mergeComparisons(SpvId comparison,SpvId allComparisons,Operator op,OutputStream & out)4299 SpvId SPIRVCodeGenerator::mergeComparisons(SpvId comparison, SpvId allComparisons, Operator op,
4300                                            OutputStream& out) {
4301     // If this is the first entry, we don't need to merge comparison results with anything.
4302     if (allComparisons == NA) {
4303         return comparison;
4304     }
4305     // Use LogicalAnd or LogicalOr to combine the comparison with all the other comparisons.
4306     const Type& boolType = *fContext.fTypes.fBool;
4307     SpvId boolTypeId = this->getType(boolType);
4308     SpvId logicalOp = this->nextId(&boolType);
4309     switch (op.kind()) {
4310         case Operator::Kind::EQEQ:
4311             this->writeInstruction(SpvOpLogicalAnd, boolTypeId, logicalOp,
4312                                    comparison, allComparisons, out);
4313             break;
4314         case Operator::Kind::NEQ:
4315             this->writeInstruction(SpvOpLogicalOr, boolTypeId, logicalOp,
4316                                    comparison, allComparisons, out);
4317             break;
4318         default:
4319             SkDEBUGFAILF("mergeComparisons only supports == and !=, not %s", op.operatorName());
4320             return NA;
4321     }
4322     return logicalOp;
4323 }
4324 
4325 #ifdef SKSL_EXT
writeSpecConstBinaryExpression(const BinaryExpression & b,const Operator & op,SpvId lhs,SpvId rhs)4326 SpvId SPIRVCodeGenerator::writeSpecConstBinaryExpression(const BinaryExpression& b, const Operator& op,
4327                                                          SpvId lhs, SpvId rhs) {
4328     SpvId result = this->nextId(&(b.type()));
4329     switch (op.removeAssignment().kind()) {
4330         case Operator::Kind::EQEQ:
4331             this->writeInstruction(SpvOpSpecConstantOp, this->getType(b.type()), result,
4332                                    SpvOpIEqual, lhs, rhs, fConstantBuffer);
4333             break;
4334         case Operator::Kind::NEQ:
4335             this->writeInstruction(SpvOpSpecConstantOp, this->getType(b.type()), result,
4336                                    SpvOpINotEqual, lhs, rhs, fConstantBuffer);
4337             break;
4338         case Operator::Kind::LT:
4339             this->writeInstruction(SpvOpSpecConstantOp, this->getType(b.type()), result,
4340                                    SpvOpULessThan, lhs, rhs, fConstantBuffer);
4341             break;
4342         case Operator::Kind::LTEQ:
4343             this->writeInstruction(SpvOpSpecConstantOp, this->getType(b.type()), result,
4344                                    SpvOpULessThanEqual, lhs, rhs, fConstantBuffer);
4345             break;
4346         case Operator::Kind::GT:
4347             this->writeInstruction(SpvOpSpecConstantOp, this->getType(b.type()), result,
4348                                    SpvOpUGreaterThan, lhs, rhs, fConstantBuffer);
4349             break;
4350         case Operator::Kind::GTEQ:
4351             this->writeInstruction(SpvOpSpecConstantOp, this->getType(b.type()), result,
4352                                    SpvOpUGreaterThanEqual, lhs, rhs, fConstantBuffer);
4353             break;
4354         default:
4355             fContext.fErrors->error(b.fPosition, "spec constant does not support operator: " +
4356                                     std::string(op.operatorName()));
4357             return -1;
4358     }
4359     return result;
4360 }
4361 #endif
4362 
writeBinaryExpression(const BinaryExpression & b,OutputStream & out)4363 SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
4364     const Expression* left = b.left().get();
4365     const Expression* right = b.right().get();
4366     Operator op = b.getOperator();
4367 
4368     switch (op.kind()) {
4369         case Operator::Kind::EQ: {
4370             // Handles assignment.
4371             SpvId rhs = this->writeExpression(*right, out);
4372             this->getLValue(*left, out)->store(rhs, out);
4373             return rhs;
4374         }
4375         case Operator::Kind::LOGICALAND:
4376             // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS.
4377             return this->writeLogicalAnd(*b.left(), *b.right(), out);
4378 
4379         case Operator::Kind::LOGICALOR:
4380             // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS.
4381             return this->writeLogicalOr(*b.left(), *b.right(), out);
4382 
4383         default:
4384             break;
4385     }
4386 
4387     std::unique_ptr<LValue> lvalue;
4388     SpvId lhs;
4389     if (op.isAssignment()) {
4390         lvalue = this->getLValue(*left, out);
4391         lhs = lvalue->load(out);
4392     } else {
4393         lvalue = nullptr;
4394         lhs = this->writeExpression(*left, out);
4395     }
4396 
4397     SpvId rhs = this->writeExpression(*right, out);
4398 #ifdef SKSL_EXT
4399     if (left->kind() == Expression::Kind::kVariableReference) {
4400         VariableReference* rightRef = (VariableReference*) right;
4401         const Expression* expr = ConstantFolder::GetConstantValueForVariable(*rightRef);
4402         if (expr != rightRef) {
4403             VariableReference* ref = (VariableReference*) left;
4404             const Variable* var = ref->variable();
4405             if (var && (var->layout().fFlags & SkSL::LayoutFlag::kConstantId)) {
4406                 return writeSpecConstBinaryExpression(b, op, lhs, rhs);
4407             }
4408         }
4409     }
4410     if (right->kind() == Expression::Kind::kVariableReference) {
4411         VariableReference* leftRef = (VariableReference*) left;
4412         const Expression* expr = ConstantFolder::GetConstantValueForVariable(*leftRef);
4413         if (expr != leftRef) {
4414             VariableReference* ref = (VariableReference*) right;
4415             const Variable* var = ref->variable();
4416             if (var && (var->layout().fFlags & SkSL::LayoutFlag::kConstantId)) {
4417                 return writeSpecConstBinaryExpression(b, op, lhs, rhs);
4418             }
4419         }
4420     }
4421 #endif
4422     SpvId result = this->writeBinaryExpression(left->type(), lhs, op.removeAssignment(),
4423                                                right->type(), rhs, b.type(), out);
4424     if (lvalue) {
4425         lvalue->store(result, out);
4426     }
4427     return result;
4428 }
4429 
writeLogicalAnd(const Expression & left,const Expression & right,OutputStream & out)4430 SpvId SPIRVCodeGenerator::writeLogicalAnd(const Expression& left, const Expression& right,
4431                                           OutputStream& out) {
4432     SpvId falseConstant = this->writeLiteral(0.0, *fContext.fTypes.fBool);
4433     SpvId lhs = this->writeExpression(left, out);
4434 
4435     ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
4436 
4437     SpvId rhsLabel = this->nextId(nullptr);
4438     SpvId end = this->nextId(nullptr);
4439     SpvId lhsBlock = fCurrentBlock;
4440     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
4441     this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
4442     this->writeLabel(rhsLabel, kBranchIsOnPreviousLine, out);
4443     SpvId rhs = this->writeExpression(right, out);
4444     SpvId rhsBlock = fCurrentBlock;
4445     this->writeInstruction(SpvOpBranch, end, out);
4446     this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
4447     SpvId result = this->nextId(nullptr);
4448     this->writeInstruction(SpvOpPhi, this->getType(*fContext.fTypes.fBool), result, falseConstant,
4449                            lhsBlock, rhs, rhsBlock, out);
4450 
4451     return result;
4452 }
4453 
writeLogicalOr(const Expression & left,const Expression & right,OutputStream & out)4454 SpvId SPIRVCodeGenerator::writeLogicalOr(const Expression& left, const Expression& right,
4455                                          OutputStream& out) {
4456     SpvId trueConstant = this->writeLiteral(1.0, *fContext.fTypes.fBool);
4457     SpvId lhs = this->writeExpression(left, out);
4458 
4459     ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
4460 
4461     SpvId rhsLabel = this->nextId(nullptr);
4462     SpvId end = this->nextId(nullptr);
4463     SpvId lhsBlock = fCurrentBlock;
4464     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
4465     this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
4466     this->writeLabel(rhsLabel, kBranchIsOnPreviousLine, out);
4467     SpvId rhs = this->writeExpression(right, out);
4468     SpvId rhsBlock = fCurrentBlock;
4469     this->writeInstruction(SpvOpBranch, end, out);
4470     this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
4471     SpvId result = this->nextId(nullptr);
4472     this->writeInstruction(SpvOpPhi, this->getType(*fContext.fTypes.fBool), result, trueConstant,
4473                            lhsBlock, rhs, rhsBlock, out);
4474 
4475     return result;
4476 }
4477 
writeTernaryExpression(const TernaryExpression & t,OutputStream & out)4478 SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) {
4479     const Type& type = t.type();
4480     SpvId test = this->writeExpression(*t.test(), out);
4481     if (t.ifTrue()->type().columns() == 1 &&
4482         Analysis::IsCompileTimeConstant(*t.ifTrue()) &&
4483         Analysis::IsCompileTimeConstant(*t.ifFalse())) {
4484         // both true and false are constants, can just use OpSelect
4485         SpvId result = this->nextId(nullptr);
4486         SpvId trueId = this->writeExpression(*t.ifTrue(), out);
4487         SpvId falseId = this->writeExpression(*t.ifFalse(), out);
4488         this->writeInstruction(SpvOpSelect, this->getType(type), result, test, trueId, falseId,
4489                                out);
4490         return result;
4491     }
4492 
4493     ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
4494 
4495     // was originally using OpPhi to choose the result, but for some reason that is crashing on
4496     // Adreno. Switched to storing the result in a temp variable as glslang does.
4497     SpvId var = this->nextId(nullptr);
4498     this->writeInstruction(SpvOpVariable, this->getPointerType(type, StorageClass::kFunction),
4499                            var, SpvStorageClassFunction, fVariableBuffer);
4500     SpvId trueLabel = this->nextId(nullptr);
4501     SpvId falseLabel = this->nextId(nullptr);
4502     SpvId end = this->nextId(nullptr);
4503     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
4504     this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
4505     this->writeLabel(trueLabel, kBranchIsOnPreviousLine, out);
4506     this->writeOpStore(StorageClass::kFunction, var, this->writeExpression(*t.ifTrue(), out), out);
4507     this->writeInstruction(SpvOpBranch, end, out);
4508     this->writeLabel(falseLabel, kBranchIsAbove, conditionalOps, out);
4509     this->writeOpStore(StorageClass::kFunction, var, this->writeExpression(*t.ifFalse(), out), out);
4510     this->writeInstruction(SpvOpBranch, end, out);
4511     this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
4512     SpvId result = this->nextId(&type);
4513     this->writeInstruction(SpvOpLoad, this->getType(type), result, var, out);
4514 
4515     return result;
4516 }
4517 
writePrefixExpression(const PrefixExpression & p,OutputStream & out)4518 SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
4519     const Type& type = p.type();
4520     if (p.getOperator().kind() == Operator::Kind::MINUS) {
4521         SpvOp_ negateOp = pick_by_type(type, SpvOpFNegate, SpvOpSNegate, SpvOpSNegate, SpvOpUndef);
4522         SkASSERT(negateOp != SpvOpUndef);
4523         SpvId expr = this->writeExpression(*p.operand(), out);
4524         if (type.isMatrix()) {
4525             return this->writeComponentwiseMatrixUnary(type, expr, negateOp, out);
4526         }
4527         SpvId result = this->nextId(&type);
4528         SpvId typeId = this->getType(type);
4529         this->writeInstruction(negateOp, typeId, result, expr, out);
4530         return result;
4531     }
4532     switch (p.getOperator().kind()) {
4533         case Operator::Kind::PLUS:
4534             return this->writeExpression(*p.operand(), out);
4535 
4536         case Operator::Kind::PLUSPLUS: {
4537             std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
4538             SpvId one = this->writeLiteral(1.0, type.componentType());
4539             one = this->splat(type, one, out);
4540             SpvId result = this->writeBinaryOperationComponentwiseIfMatrix(type, type,
4541                                                                            lv->load(out), one,
4542                                                                            SpvOpFAdd, SpvOpIAdd,
4543                                                                            SpvOpIAdd, SpvOpUndef,
4544                                                                            out);
4545             lv->store(result, out);
4546             return result;
4547         }
4548         case Operator::Kind::MINUSMINUS: {
4549             std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
4550             SpvId one = this->writeLiteral(1.0, type.componentType());
4551             one = this->splat(type, one, out);
4552             SpvId result = this->writeBinaryOperationComponentwiseIfMatrix(type, type,
4553                                                                            lv->load(out), one,
4554                                                                            SpvOpFSub, SpvOpISub,
4555                                                                            SpvOpISub, SpvOpUndef,
4556                                                                            out);
4557             lv->store(result, out);
4558             return result;
4559         }
4560         case Operator::Kind::LOGICALNOT: {
4561             SkASSERT(p.operand()->type().isBoolean());
4562             SpvId result = this->nextId(nullptr);
4563             this->writeInstruction(SpvOpLogicalNot, this->getType(type), result,
4564                                    this->writeExpression(*p.operand(), out), out);
4565             return result;
4566         }
4567         case Operator::Kind::BITWISENOT: {
4568             SpvId result = this->nextId(nullptr);
4569             this->writeInstruction(SpvOpNot, this->getType(type), result,
4570                                    this->writeExpression(*p.operand(), out), out);
4571             return result;
4572         }
4573         default:
4574             SkDEBUGFAILF("unsupported prefix expression: %s",
4575                          p.description(OperatorPrecedence::kExpression).c_str());
4576             return NA;
4577     }
4578 }
4579 
writePostfixExpression(const PostfixExpression & p,OutputStream & out)4580 SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) {
4581     const Type& type = p.type();
4582     std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
4583     SpvId result = lv->load(out);
4584     SpvId one = this->writeLiteral(1.0, type.componentType());
4585     one = this->splat(type, one, out);
4586     switch (p.getOperator().kind()) {
4587         case Operator::Kind::PLUSPLUS: {
4588             SpvId temp = this->writeBinaryOperationComponentwiseIfMatrix(type, type, result, one,
4589                                                                          SpvOpFAdd, SpvOpIAdd,
4590                                                                          SpvOpIAdd, SpvOpUndef,
4591                                                                          out);
4592             lv->store(temp, out);
4593             return result;
4594         }
4595         case Operator::Kind::MINUSMINUS: {
4596             SpvId temp = this->writeBinaryOperationComponentwiseIfMatrix(type, type, result, one,
4597                                                                          SpvOpFSub, SpvOpISub,
4598                                                                          SpvOpISub, SpvOpUndef,
4599                                                                          out);
4600             lv->store(temp, out);
4601             return result;
4602         }
4603         default:
4604             SkDEBUGFAILF("unsupported postfix expression %s",
4605                          p.description(OperatorPrecedence::kExpression).c_str());
4606             return NA;
4607     }
4608 }
4609 
writeLiteral(const Literal & l)4610 SpvId SPIRVCodeGenerator::writeLiteral(const Literal& l) {
4611     return this->writeLiteral(l.value(), l.type());
4612 }
4613 
writeLiteral(double value,const Type & type)4614 SpvId SPIRVCodeGenerator::writeLiteral(double value, const Type& type) {
4615     switch (type.numberKind()) {
4616         case Type::NumberKind::kFloat: {
4617             float floatVal = value;
4618             int32_t valueBits;
4619             memcpy(&valueBits, &floatVal, sizeof(valueBits));
4620             return this->writeOpConstant(type, valueBits);
4621         }
4622         case Type::NumberKind::kBoolean: {
4623             return value ? this->writeOpConstantTrue(type)
4624                          : this->writeOpConstantFalse(type);
4625         }
4626         default: {
4627             return this->writeOpConstant(type, (SKSL_INT)value);
4628         }
4629     }
4630 }
4631 
writeFunctionStart(const FunctionDeclaration & f,OutputStream & out)4632 void SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) {
4633     SpvId result = fFunctionMap[{&f, fActiveSpecializationIndex}];
4634     SpvId returnTypeId = this->getType(f.returnType());
4635     SpvId functionTypeId = this->getFunctionType(f);
4636     this->writeInstruction(SpvOpFunction, returnTypeId, result,
4637                            SpvFunctionControlMaskNone, functionTypeId, out);
4638     std::string mangledName = f.mangledName();
4639 
4640     // For specialized functions, tack on `_param1_param2` to the function name.
4641     Analysis::GetParameterMappingsForFunction(
4642             f, fSpecializationInfo, fActiveSpecializationIndex,
4643             [&](int, const Variable*, const Expression* expr) {
4644                 std::string name = expr->description();
4645                 std::replace_if(name.begin(), name.end(), [](char c) { return !isalnum(c); }, '_');
4646 
4647                 mangledName += "_" + name;
4648             });
4649 
4650     this->writeInstruction(SpvOpName,
4651                            result,
4652                            std::string_view(mangledName.c_str(), mangledName.size()),
4653                            fNameBuffer);
4654     for (const Variable* parameter : f.parameters()) {
4655         const Variable* specializedVar = nullptr;
4656         if (fActiveSpecialization) {
4657             if (const Expression** specializedExpr = fActiveSpecialization->find(parameter)) {
4658                 if ((*specializedExpr)->is<FieldAccess>()) {
4659                     continue;
4660                 }
4661                 SkASSERT((*specializedExpr)->is<VariableReference>());
4662                 specializedVar = (*specializedExpr)->as<VariableReference>().variable();
4663             }
4664         }
4665 
4666         if (fUseTextureSamplerPairs && parameter->type().isSampler()) {
4667             auto [texture, sampler] = this->synthesizeTextureAndSampler(*parameter);
4668 
4669             SpvId textureId = this->nextId(nullptr);
4670             fVariableMap.set(texture, textureId);
4671 
4672             SpvId textureType = this->getFunctionParameterType(texture->type(), texture->layout());
4673             this->writeInstruction(SpvOpFunctionParameter, textureType, textureId, out);
4674 
4675             if (specializedVar) {
4676                 const auto* p = fSynthesizedSamplerMap.find(specializedVar);
4677                 SkASSERT(p);
4678                 const SpvId* uniformId = fVariableMap.find((*p)->fSampler.get());
4679                 SkASSERT(uniformId);
4680                 fVariableMap.set(sampler, *uniformId);
4681             } else {
4682                 SpvId samplerId = this->nextId(nullptr);
4683                 fVariableMap.set(sampler, samplerId);
4684 
4685                 SpvId samplerType =
4686                         this->getFunctionParameterType(sampler->type(), kDefaultTypeLayout);
4687                 this->writeInstruction(SpvOpFunctionParameter, samplerType, samplerId, out);
4688             }
4689         } else {
4690             if (specializedVar) {
4691                 const SpvId* uniformId = fVariableMap.find(specializedVar);
4692                 SkASSERT(uniformId);
4693                 fVariableMap.set(parameter, *uniformId);
4694             } else {
4695                 SpvId id = this->nextId(nullptr);
4696                 fVariableMap.set(parameter, id);
4697 
4698                 SpvId type = this->getFunctionParameterType(parameter->type(), parameter->layout());
4699                 this->writeInstruction(SpvOpFunctionParameter, type, id, out);
4700             }
4701         }
4702     }
4703 }
4704 
writeFunction(const FunctionDefinition & f,OutputStream & out)4705 void SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) {
4706     if (const Analysis::Specializations* specializations =
4707                 fSpecializationInfo.fSpecializationMap.find(&f.declaration())) {
4708         for (int i = 0; i < specializations->size(); i++) {
4709             this->writeFunctionInstantiation(f, i, &specializations->at(i), out);
4710         }
4711     } else {
4712         this->writeFunctionInstantiation(f,
4713                                          Analysis::kUnspecialized,
4714                                          /*specializedParams=*/nullptr,
4715                                          out);
4716     }
4717 }
4718 
writeFunctionInstantiation(const FunctionDefinition & f,Analysis::SpecializationIndex specializationIndex,const Analysis::SpecializedParameters * specializedParams,OutputStream & out)4719 void SPIRVCodeGenerator::writeFunctionInstantiation(
4720         const FunctionDefinition& f,
4721         Analysis::SpecializationIndex specializationIndex,
4722         const Analysis::SpecializedParameters* specializedParams,
4723         OutputStream& out) {
4724     ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
4725 
4726     fVariableBuffer.reset();
4727     fActiveSpecialization = specializedParams;
4728     fActiveSpecializationIndex = specializationIndex;
4729     this->writeFunctionStart(f.declaration(), out);
4730     fCurrentBlock = 0;
4731     this->writeLabel(this->nextId(nullptr), kBranchlessBlock, out);
4732     StringStream bodyBuffer;
4733     this->writeBlock(f.body()->as<Block>(), bodyBuffer);
4734     fActiveSpecialization = nullptr;
4735     fActiveSpecializationIndex = Analysis::kUnspecialized;
4736     write_stringstream(fVariableBuffer, out);
4737     if (f.declaration().isMain()) {
4738         write_stringstream(fGlobalInitializersBuffer, out);
4739     }
4740     write_stringstream(bodyBuffer, out);
4741     if (fCurrentBlock) {
4742         if (f.declaration().returnType().isVoid()) {
4743             this->writeInstruction(SpvOpReturn, out);
4744         } else {
4745             this->writeInstruction(SpvOpUnreachable, out);
4746         }
4747     }
4748     this->writeInstruction(SpvOpFunctionEnd, out);
4749     this->pruneConditionalOps(conditionalOps);
4750 }
4751 
writeLayout(const Layout & layout,SpvId target,Position pos)4752 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, Position pos) {
4753     bool isPushConstant = SkToBool(layout.fFlags & LayoutFlag::kPushConstant);
4754     if (layout.fLocation >= 0) {
4755         this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
4756                                fDecorationBuffer);
4757     }
4758     if (layout.fBinding >= 0) {
4759         if (isPushConstant) {
4760             fContext.fErrors->error(pos, "Can't apply 'binding' to push constants");
4761         } else {
4762             this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
4763                                    fDecorationBuffer);
4764         }
4765     }
4766     if (layout.fIndex >= 0) {
4767         this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
4768                                fDecorationBuffer);
4769     }
4770     if (layout.fSet >= 0) {
4771         if (isPushConstant) {
4772             fContext.fErrors->error(pos, "Can't apply 'set' to push constants");
4773         } else {
4774             this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
4775                                    fDecorationBuffer);
4776         }
4777     }
4778     if (layout.fInputAttachmentIndex >= 0) {
4779         this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
4780                                layout.fInputAttachmentIndex, fDecorationBuffer);
4781         fCapabilities |= (((uint64_t) 1) << SpvCapabilityInputAttachment);
4782     }
4783     if (layout.fBuiltin >= 0 && (layout.fBuiltin != SK_FRAGCOLOR_BUILTIN &&
4784                                  layout.fBuiltin != SK_SECONDARYFRAGCOLOR_BUILTIN)) {
4785             this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
4786                                    fDecorationBuffer);
4787     }
4788 }
4789 
writeFieldLayout(const Layout & layout,SpvId target,int member)4790 void SPIRVCodeGenerator::writeFieldLayout(const Layout& layout, SpvId target, int member) {
4791     // 'binding' and 'set' can not be applied to struct members
4792     SkASSERT(layout.fBinding == -1);
4793     SkASSERT(layout.fSet == -1);
4794     if (layout.fLocation >= 0) {
4795         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
4796                                layout.fLocation, fDecorationBuffer);
4797     }
4798     if (layout.fIndex >= 0) {
4799         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
4800                                layout.fIndex, fDecorationBuffer);
4801     }
4802     if (layout.fInputAttachmentIndex >= 0) {
4803         this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
4804                                layout.fInputAttachmentIndex, fDecorationBuffer);
4805     }
4806     if (layout.fBuiltin >= 0) {
4807         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
4808                                layout.fBuiltin, fDecorationBuffer);
4809     }
4810 }
4811 
memoryLayoutForStorageClass(StorageClass storageClass)4812 MemoryLayout SPIRVCodeGenerator::memoryLayoutForStorageClass(StorageClass storageClass) {
4813     return storageClass == StorageClass::kPushConstant ||
4814            storageClass == StorageClass::kStorageBuffer
4815                                 ? MemoryLayout(MemoryLayout::Standard::k430)
4816                                 : fDefaultMemoryLayout;
4817 }
4818 
memoryLayoutForVariable(const Variable & v) const4819 MemoryLayout SPIRVCodeGenerator::memoryLayoutForVariable(const Variable& v) const {
4820     bool pushConstant = SkToBool(v.layout().fFlags & LayoutFlag::kPushConstant);
4821     bool buffer = v.modifierFlags().isBuffer();
4822     return pushConstant || buffer ? MemoryLayout(MemoryLayout::Standard::k430)
4823                                   : fDefaultMemoryLayout;
4824 }
4825 
// Emits an interface block (a struct-typed global such as a uniform/push-constant block or a
// buffer) as a decorated struct type, a pointer type, and an OpVariable, and returns the
// variable's id.
//
// If the program needs an RT-flip uniform that hasn't been written yet and `appendRTFlip` is
// true, the block is instead re-emitted as a copy whose struct type has an extra float2 field
// (SKSL_RTFLIP_NAME) appended; the original variable is then mapped to the copy's id.
//
// On an unsupported type for the chosen memory layout, an error is reported and a fresh
// (unused) id is returned.
SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf, bool appendRTFlip) {
    MemoryLayout memoryLayout = this->memoryLayoutForVariable(*intf.var());
    // NOTE(review): this id is allocated up front; it goes unused on the error path below and is
    // replaced by the recursive call's id on the RT-flip path.
    SpvId result = this->nextId(nullptr);
    const Variable& intfVar = *intf.var();
    const Type& type = intfVar.type();
    if (!memoryLayout.isSupported(type)) {
        fContext.fErrors->error(type.fPosition, "type '" + type.displayName() +
                                                "' is not permitted here");
        return this->nextId(nullptr);
    }
    StorageClass storageClass =
            get_storage_class_for_global_variable(intfVar, StorageClass::kFunction);
    if (fProgram.fInterface.fRTFlipUniform != Program::Interface::kRTFlip_None && appendRTFlip &&
        !fWroteRTFlip && type.isStruct()) {
        // We can only have one interface block (because we use push_constant and that is limited
        // to one per program), so we need to append rtflip to this one rather than synthesize an
        // entirely new block when the variable is referenced. And we can't modify the existing
        // block, so we instead create a modified copy of it and write that.
        SkSpan<const Field> fieldSpan = type.fields();
        TArray<Field> fields(fieldSpan.data(), fieldSpan.size());
        // The appended field is placed at the configured RT-flip offset.
        fields.emplace_back(Position(),
                            Layout(LayoutFlag::kNone,
                                   /*location=*/-1,
                                   fProgram.fConfig->fSettings.fRTFlipOffset,
                                   /*binding=*/-1,
                                   /*index=*/-1,
                                   /*set=*/-1,
                                   /*builtin=*/-1,
                                   /*inputAttachmentIndex=*/-1),
                            ModifierFlag::kNone,
                            SKSL_RTFLIP_NAME,
                            fContext.fTypes.fFloat2.get());
        {
            // The synthesized type and variable are owned by the program's symbol table, and
            // the program's pool is attached to this thread while they are created.
            AutoAttachPoolToThread attach(fProgram.fPool.get());
            const Type* rtFlipStructType = fProgram.fSymbols->takeOwnershipOfSymbol(
                    Type::MakeStructType(fContext,
                                         type.fPosition,
                                         type.name(),
                                         std::move(fields),
                                         /*interfaceBlock=*/true));
            Variable* modifiedVar = fProgram.fSymbols->takeOwnershipOfSymbol(
                    Variable::Make(intfVar.fPosition,
                                   intfVar.modifiersPosition(),
                                   intfVar.layout(),
                                   intfVar.modifierFlags(),
                                   rtFlipStructType,
                                   intfVar.name(),
                                   /*mangledName=*/"",
                                   intfVar.isBuiltin(),
                                   intfVar.storage()));
            InterfaceBlock modifiedCopy(intf.fPosition, modifiedVar);
            // Recurse with appendRTFlip=false so the copy is emitted through the normal path.
            result = this->writeInterfaceBlock(modifiedCopy, /*appendRTFlip=*/false);
            // Register a FieldSymbol for the appended (last) field of the modified block.
            fProgram.fSymbols->add(fContext, std::make_unique<FieldSymbol>(
                    Position(), modifiedVar, rtFlipStructType->fields().size() - 1));
        }
        // References to the original variable resolve to the modified copy's OpVariable.
        fVariableMap.set(&intfVar, result);
        fWroteRTFlip = true;
        return result;
    }
    SpvId typeId = this->getType(type, kDefaultTypeLayout, memoryLayout);
    // Blocks bound to a builtin (fBuiltin != -1) do not receive a Block/BufferBlock decoration.
    if (intfVar.layout().fBuiltin == -1) {
        // Note: In SPIR-V 1.3, a storage buffer can be declared with the "StorageBuffer"
        // storage class and the "Block" decoration and the <1.3 approach we use here ("Uniform"
        // storage class and the "BufferBlock" decoration) is deprecated. Since we target SPIR-V
        // 1.0, we have to use the deprecated approach which is well supported in Vulkan and
        // addresses SkSL use cases (notably SkSL currently doesn't support pointer features that
        // would benefit from SPV_KHR_variable_pointers capabilities).
        // NOTE(review): under SKSL_EXT every block is decorated `Block`, including buffers; that
        // is only valid if buffers use the StorageBuffer storage class there — confirm.
#ifdef SKSL_EXT
        this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer);
#else
        bool isStorageBuffer = intfVar.modifierFlags().isBuffer();
        this->writeInstruction(SpvOpDecorate,
                               typeId,
                               isStorageBuffer ? SpvDecorationBufferBlock : SpvDecorationBlock,
                               fDecorationBuffer);
#endif
    }
    SpvId ptrType = this->nextId(nullptr);
    this->writeInstruction(SpvOpTypePointer, ptrType,
                           get_storage_class_spv_id(storageClass), typeId, fConstantBuffer);
    this->writeInstruction(SpvOpVariable, ptrType, result,
                           get_storage_class_spv_id(storageClass), fConstantBuffer);
    // Uniform and storage-buffer blocks without an explicit set fall back to the default
    // uniform descriptor set from the program settings.
    Layout layout = intfVar.layout();
    if ((storageClass == StorageClass::kUniform ||
                storageClass == StorageClass::kStorageBuffer) && layout.fSet < 0) {
        layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
    }
    this->writeLayout(layout, result, intfVar.fPosition);
    fVariableMap.set(&intfVar, result);
    return result;
}
4917 
4918 // This function determines whether to skip an OpVariable (of pointer type) declaration for
4919 // compile-time constant scalars and vectors which we turn into OpConstant/OpConstantComposite and
4920 // always reference by value.
4921 //
4922 // Accessing a matrix or array member with a dynamic index requires the use of OpAccessChain which
4923 // requires a base operand of pointer type. However, a vector can always be accessed by value using
4924 // OpVectorExtractDynamic (see writeIndexExpression).
4925 //
4926 // This is why we always emit an OpVariable for all non-scalar and non-vector types in case they get
4927 // accessed via a dynamic index.
is_vardecl_compile_time_constant(const VarDeclaration & varDecl)4928 static bool is_vardecl_compile_time_constant(const VarDeclaration& varDecl) {
4929     return varDecl.var()->modifierFlags().isConst() &&
4930 #ifdef SKSL_EXT
4931            (varDecl.var()->layout().fFlags & LayoutFlag::kConstantId) == LayoutFlag::kNone &&
4932 #endif
4933            (varDecl.var()->type().isScalar() || varDecl.var()->type().isVector()) &&
4934            (ConstantFolder::GetConstantValueOrNull(*varDecl.value()) ||
4935             Analysis::IsCompileTimeConstant(*varDecl.value()));
4936 }
4937 
writeGlobalVarDeclaration(ProgramKind kind,const VarDeclaration & varDecl)4938 bool SPIRVCodeGenerator::writeGlobalVarDeclaration(ProgramKind kind,
4939                                                    const VarDeclaration& varDecl) {
4940     const Variable* var = varDecl.var();
4941     const LayoutFlags backendFlags = var->layout().fFlags & LayoutFlag::kAllBackends;
4942     const LayoutFlags kPermittedBackendFlags =
4943             LayoutFlag::kVulkan | LayoutFlag::kWebGPU | LayoutFlag::kDirect3D;
4944     if (backendFlags & ~kPermittedBackendFlags) {
4945         fContext.fErrors->error(var->fPosition, "incompatible backend flag in SPIR-V codegen");
4946         return false;
4947     }
4948 
4949     // If this global variable is a compile-time constant then we'll emit OpConstant or
4950     // OpConstantComposite later when the variable is referenced. Avoid declaring an OpVariable now.
4951     if (is_vardecl_compile_time_constant(varDecl)) {
4952         return true;
4953     }
4954 
4955     StorageClass storageClass =
4956             get_storage_class_for_global_variable(*var, StorageClass::kPrivate);
4957     if (storageClass == StorageClass::kUniform || storageClass == StorageClass::kStorageBuffer) {
4958         // Top-level uniforms are emitted in writeUniformBuffer.
4959         fTopLevelUniforms.push_back(&varDecl);
4960         return true;
4961     }
4962 #ifdef SKSL_EXT
4963     if (var->layout().fFlags & LayoutFlag::kConstantId) {
4964         Layout layout = var->layout();
4965         const Type& type = var->type();
4966         SpvId id = this->nextId(&type);
4967         fVariableMap[var] = id;
4968         SpvId typeId = this->getType(type);
4969         if (type.isInteger() && varDecl.value()) {
4970             int tmp = (*varDecl.value()).as<Literal>().intValue();
4971             this->writeInstruction(SpvOpSpecConstant, typeId, id, tmp, fConstantBuffer);
4972         } else {
4973             fContext.fErrors->error(var->fPosition,
4974                 "spec const '" + std::string(var->name()) + "' must be an integer literal");
4975             return false;
4976         }
4977         this->writeInstruction(SpvOpName, id, var->name(), fNameBuffer);
4978         this->writeInstruction(SpvOpDecorate, id, SpvDecorationSpecId, layout.fConstantId,
4979                                fDecorationBuffer);
4980         return true;
4981     }
4982 #endif
4983     if (fUseTextureSamplerPairs && var->type().isSampler()) {
4984         if (var->layout().fTexture == -1 || var->layout().fSampler == -1) {
4985             fContext.fErrors->error(var->fPosition, "selected backend requires separate texture "
4986                                                     "and sampler indices");
4987             return false;
4988         }
4989         SkASSERT(storageClass == StorageClass::kUniformConstant);
4990 
4991         auto [texture, sampler] = this->synthesizeTextureAndSampler(*var);
4992         this->writeGlobalVar(kind, storageClass, *texture);
4993         this->writeGlobalVar(kind, storageClass, *sampler);
4994 
4995         return true;
4996     }
4997 #ifdef SKSL_EXT
4998     if (storageClass == StorageClass::kFunction) {
4999         SkASSERT(varDecl.value());
5000         this->getPointerType(var->type(), storageClass);
5001         fEmittingGlobalConstConstructor = true;
5002         SpvId value = this->writeExpression(*varDecl.value(), fConstantBuffer);
5003         fEmittingGlobalConstConstructor = false;
5004         fGlobalConstVariableValueMap[var] = value;
5005     } else {
5006 #endif
5007         SpvId id = this->writeGlobalVar(kind, storageClass, *var);
5008         if (id != NA && varDecl.value()) {
5009             SkASSERT(!fCurrentBlock);
5010             fCurrentBlock = NA;
5011             SpvId value = this->writeExpression(*varDecl.value(), fGlobalInitializersBuffer);
5012             this->writeOpStore(storageClass, id, value, fGlobalInitializersBuffer);
5013             fCurrentBlock = 0;
5014         }
5015 #ifdef SKSL_EXT
5016     }
5017 #endif
5018     return true;
5019 }
5020 
writeGlobalVar(ProgramKind kind,StorageClass storageClass,const Variable & var)5021 SpvId SPIRVCodeGenerator::writeGlobalVar(ProgramKind kind,
5022                                          StorageClass storageClass,
5023                                          const Variable& var) {
5024     Layout layout = var.layout();
5025     const ModifierFlags flags = var.modifierFlags();
5026     const Type* type = &var.type();
5027     switch (layout.fBuiltin) {
5028         case SK_FRAGCOLOR_BUILTIN:
5029         case SK_SECONDARYFRAGCOLOR_BUILTIN:
5030             if (!ProgramConfig::IsFragment(kind)) {
5031                 SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut);
5032                 return NA;
5033             }
5034             break;
5035 
5036         case SK_SAMPLEMASKIN_BUILTIN:
5037         case SK_SAMPLEMASK_BUILTIN:
5038             // SkSL exposes this as a `uint` but SPIR-V, like GLSL, uses an array of signed `uint`
5039             // decorated with SpvBuiltinSampleMask.
5040             type = fSynthetics.addArrayDimension(fContext, type, /*arraySize=*/1);
5041             layout.fBuiltin = SpvBuiltInSampleMask;
5042             break;
5043     }
5044 
5045     // Add this global to the variable map.
5046     SpvId id = this->nextId(type);
5047     fVariableMap.set(&var, id);
5048 
5049     if (layout.fSet < 0 && storageClass == StorageClass::kUniformConstant) {
5050         layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
5051     }
5052 
5053     SpvId typeId = this->getPointerType(*type,
5054                                         layout,
5055                                         this->memoryLayoutForStorageClass(storageClass),
5056                                         storageClass);
5057     this->writeInstruction(SpvOpVariable, typeId, id,
5058                            get_storage_class_spv_id(storageClass), fConstantBuffer);
5059     this->writeInstruction(SpvOpName, id, var.name(), fNameBuffer);
5060     this->writeLayout(layout, id, var.fPosition);
5061     if (flags & ModifierFlag::kFlat) {
5062         this->writeInstruction(SpvOpDecorate, id, SpvDecorationFlat, fDecorationBuffer);
5063     }
5064     if (flags & ModifierFlag::kNoPerspective) {
5065         this->writeInstruction(SpvOpDecorate, id, SpvDecorationNoPerspective,
5066                                fDecorationBuffer);
5067     }
5068     if (flags.isWriteOnly()) {
5069         this->writeInstruction(SpvOpDecorate, id, SpvDecorationNonReadable, fDecorationBuffer);
5070     } else if (flags.isReadOnly()) {
5071         this->writeInstruction(SpvOpDecorate, id, SpvDecorationNonWritable, fDecorationBuffer);
5072     }
5073 
5074     return id;
5075 }
5076 
writeVarDeclaration(const VarDeclaration & varDecl,OutputStream & out)5077 void SPIRVCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl, OutputStream& out) {
5078     // If this variable is a compile-time constant then we'll emit OpConstant or
5079     // OpConstantComposite later when the variable is referenced. Avoid declaring an OpVariable now.
5080     if (is_vardecl_compile_time_constant(varDecl)) {
5081         return;
5082     }
5083 
5084     const Variable* var = varDecl.var();
5085     SpvId id = this->nextId(&var->type());
5086     fVariableMap.set(var, id);
5087     SpvId type = this->getPointerType(var->type(), StorageClass::kFunction);
5088     this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
5089     this->writeInstruction(SpvOpName, id, var->name(), fNameBuffer);
5090     if (varDecl.value()) {
5091         SpvId value = this->writeExpression(*varDecl.value(), out);
5092         this->writeOpStore(StorageClass::kFunction, id, value, out);
5093     }
5094 }
5095 
writeStatement(const Statement & s,OutputStream & out)5096 void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) {
5097     switch (s.kind()) {
5098         case Statement::Kind::kNop:
5099             break;
5100         case Statement::Kind::kBlock:
5101             this->writeBlock(s.as<Block>(), out);
5102             break;
5103         case Statement::Kind::kExpression:
5104             this->writeExpression(*s.as<ExpressionStatement>().expression(), out);
5105             break;
5106         case Statement::Kind::kReturn:
5107             this->writeReturnStatement(s.as<ReturnStatement>(), out);
5108             break;
5109         case Statement::Kind::kVarDeclaration:
5110             this->writeVarDeclaration(s.as<VarDeclaration>(), out);
5111             break;
5112         case Statement::Kind::kIf:
5113             this->writeIfStatement(s.as<IfStatement>(), out);
5114             break;
5115         case Statement::Kind::kFor:
5116             this->writeForStatement(s.as<ForStatement>(), out);
5117             break;
5118         case Statement::Kind::kDo:
5119             this->writeDoStatement(s.as<DoStatement>(), out);
5120             break;
5121         case Statement::Kind::kSwitch:
5122             this->writeSwitchStatement(s.as<SwitchStatement>(), out);
5123             break;
5124         case Statement::Kind::kBreak:
5125             this->writeInstruction(SpvOpBranch, fBreakTarget.back(), out);
5126             break;
5127         case Statement::Kind::kContinue:
5128             this->writeInstruction(SpvOpBranch, fContinueTarget.back(), out);
5129             break;
5130         case Statement::Kind::kDiscard:
5131             this->writeInstruction(SpvOpKill, out);
5132             break;
5133         default:
5134             SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
5135             break;
5136     }
5137 }
5138 
writeBlock(const Block & b,OutputStream & out)5139 void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) {
5140     for (const std::unique_ptr<Statement>& stmt : b.children()) {
5141         this->writeStatement(*stmt, out);
5142     }
5143 }
5144 
getConditionalOpCounts()5145 SPIRVCodeGenerator::ConditionalOpCounts SPIRVCodeGenerator::getConditionalOpCounts() {
5146     return {fReachableOps.size(), fStoreOps.size()};
5147 }
5148 
pruneConditionalOps(ConditionalOpCounts ops)5149 void SPIRVCodeGenerator::pruneConditionalOps(ConditionalOpCounts ops) {
5150     // Remove ops which are no longer reachable.
5151     while (fReachableOps.size() > ops.numReachableOps) {
5152         SpvId prunableSpvId = fReachableOps.back();
5153         const Instruction* prunableOp = fSpvIdCache.find(prunableSpvId);
5154 
5155         if (prunableOp) {
5156             fOpCache.remove(*prunableOp);
5157             fSpvIdCache.remove(prunableSpvId);
5158         } else {
5159             SkDEBUGFAIL("reachable-op list contains unrecognized SpvId");
5160         }
5161 
5162         fReachableOps.pop_back();
5163     }
5164 
5165     // Remove any cached stores that occurred during the conditional block.
5166     while (fStoreOps.size() > ops.numStoreOps) {
5167         if (fStoreCache.find(fStoreOps.back())) {
5168             fStoreCache.remove(fStoreOps.back());
5169         }
5170         fStoreOps.pop_back();
5171     }
5172 }
5173 
writeIfStatement(const IfStatement & stmt,OutputStream & out)5174 void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) {
5175     SpvId test = this->writeExpression(*stmt.test(), out);
5176     SpvId ifTrue = this->nextId(nullptr);
5177     SpvId ifFalse = this->nextId(nullptr);
5178 
5179     ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
5180 
5181     if (stmt.ifFalse()) {
5182         SpvId end = this->nextId(nullptr);
5183         this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
5184         this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
5185         this->writeLabel(ifTrue, kBranchIsOnPreviousLine, out);
5186         this->writeStatement(*stmt.ifTrue(), out);
5187         if (fCurrentBlock) {
5188             this->writeInstruction(SpvOpBranch, end, out);
5189         }
5190         this->writeLabel(ifFalse, kBranchIsAbove, conditionalOps, out);
5191         this->writeStatement(*stmt.ifFalse(), out);
5192         if (fCurrentBlock) {
5193             this->writeInstruction(SpvOpBranch, end, out);
5194         }
5195         this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
5196     } else {
5197         this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
5198         this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
5199         this->writeLabel(ifTrue, kBranchIsOnPreviousLine, out);
5200         this->writeStatement(*stmt.ifTrue(), out);
5201         if (fCurrentBlock) {
5202             this->writeInstruction(SpvOpBranch, ifFalse, out);
5203         }
5204         this->writeLabel(ifFalse, kBranchIsAbove, conditionalOps, out);
5205     }
5206 }
5207 
// Emits a `for` loop as a SPIR-V structured loop:
//
//   <initializer>
//   OpBranch %header
//   %header: OpLoopMerge %end %next None; OpBranch %start
//   %start:  <test>; OpBranchConditional %test %body %end   (OpBranch %body if there is no test)
//   %body:   <statement>; OpBranch %next                    (only if fCurrentBlock is still set)
//   %next:   <next-expression>; OpBranch %header            (the continue target)
//   %end:                                                   (the merge / break target)
//
// %next and %end are pushed onto fContinueTarget / fBreakTarget for the duration of the loop,
// so `continue` and `break` statements in the body branch to them.
void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) {
    if (f.initializer()) {
        this->writeStatement(*f.initializer(), out);
    }

    // Snapshot taken after the initializer; passed to writeLabel() for the labels below that
    // are reached by branches from elsewhere in the loop.
    ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();

    // The store cache isn't trustworthy in the presence of branches; store caching only makes sense
    // in the context of linear straight-line execution. If we wanted to be more clever, we could
    // only invalidate store cache entries for variables affected by the loop body, but for now we
    // simply clear the entire cache whenever branching occurs.
    // NOTE(review): this function does not clear any cache itself; it passes `conditionalOps`
    // to writeLabel(). Confirm the comment above still matches writeLabel's behavior.
    SpvId header = this->nextId(nullptr);
    SpvId start = this->nextId(nullptr);
    SpvId body = this->nextId(nullptr);
    SpvId next = this->nextId(nullptr);
    fContinueTarget.push_back(next);
    SpvId end = this->nextId(nullptr);
    fBreakTarget.push_back(end);
    this->writeInstruction(SpvOpBranch, header, out);
    // The header is also the target of the back-edge emitted after %next, hence kBranchIsBelow.
    this->writeLabel(header, kBranchIsBelow, conditionalOps, out);
    this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
    this->writeInstruction(SpvOpBranch, start, out);
    this->writeLabel(start, kBranchIsOnPreviousLine, out);
    if (f.test()) {
        SpvId test = this->writeExpression(*f.test(), out);
        this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
    } else {
        this->writeInstruction(SpvOpBranch, body, out);
    }
    this->writeLabel(body, kBranchIsOnPreviousLine, out);
    this->writeStatement(*f.statement(), out);
    // Presumably fCurrentBlock is 0 when the body already ended its block (e.g. return/break).
    if (fCurrentBlock) {
        this->writeInstruction(SpvOpBranch, next, out);
    }
    this->writeLabel(next, kBranchIsAbove, conditionalOps, out);
    if (f.next()) {
        this->writeExpression(*f.next(), out);
    }
    this->writeInstruction(SpvOpBranch, header, out);
    this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
    fBreakTarget.pop_back();
    fContinueTarget.pop_back();
}
5251 
// Emits a `do { ... } while (test)` loop as a SPIR-V structured loop:
//
//   OpBranch %header
//   %header:         OpLoopMerge %end %continueTarget None; OpBranch %start
//   %start:          <statement>
//                    (if fCurrentBlock is still set) OpBranch %next
//   %next:           OpBranch %continueTarget
//   %continueTarget: <test>; OpBranchConditional %test %header %end
//   %end:
//
// %continueTarget and %end are pushed onto fContinueTarget / fBreakTarget while the body is
// emitted, so `continue` re-evaluates the test and `break` exits the loop.
void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) {
    ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();

    // The store cache isn't trustworthy in the presence of branches; store caching only makes sense
    // in the context of linear straight-line execution. If we wanted to be more clever, we could
    // only invalidate store cache entries for variables affected by the loop body, but for now we
    // simply clear the entire cache whenever branching occurs.
    // NOTE(review): no cache is cleared directly here; `conditionalOps` is handed to writeLabel().
    // Confirm the comment above still describes that behavior.
    SpvId header = this->nextId(nullptr);
    SpvId start = this->nextId(nullptr);
    SpvId next = this->nextId(nullptr);
    SpvId continueTarget = this->nextId(nullptr);
    fContinueTarget.push_back(continueTarget);
    SpvId end = this->nextId(nullptr);
    fBreakTarget.push_back(end);
    this->writeInstruction(SpvOpBranch, header, out);
    // The header is also reached by the back-edge emitted from %continueTarget below.
    this->writeLabel(header, kBranchIsBelow, conditionalOps, out);
    this->writeInstruction(SpvOpLoopMerge, end, continueTarget, SpvLoopControlMaskNone, out);
    this->writeInstruction(SpvOpBranch, start, out);
    this->writeLabel(start, kBranchIsOnPreviousLine, out);
    this->writeStatement(*d.statement(), out);
    // NOTE(review): the intermediate %next block only branches on to %continueTarget; its
    // purpose is not evident from this function.
    if (fCurrentBlock) {
        this->writeInstruction(SpvOpBranch, next, out);
        this->writeLabel(next, kBranchIsOnPreviousLine, out);
        this->writeInstruction(SpvOpBranch, continueTarget, out);
    }
    this->writeLabel(continueTarget, kBranchIsAbove, conditionalOps, out);
    SpvId test = this->writeExpression(*d.test(), out);
    this->writeInstruction(SpvOpBranchConditional, test, header, end, out);
    this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
    fBreakTarget.pop_back();
    fContinueTarget.pop_back();
}
5284 
writeSwitchStatement(const SwitchStatement & s,OutputStream & out)5285 void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) {
5286     SpvId value = this->writeExpression(*s.value(), out);
5287 
5288     ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
5289 
5290     // The store cache isn't trustworthy in the presence of branches; store caching only makes sense
5291     // in the context of linear straight-line execution. If we wanted to be more clever, we could
5292     // only invalidate store cache entries for variables affected by the switch body, but for now we
5293     // simply clear the entire cache whenever branching occurs.
5294     TArray<SpvId> labels;
5295     SpvId end = this->nextId(nullptr);
5296     SpvId defaultLabel = end;
5297     fBreakTarget.push_back(end);
5298     int size = 3;
5299     const StatementArray& cases = s.cases();
5300     for (const std::unique_ptr<Statement>& stmt : cases) {
5301         const SwitchCase& c = stmt->as<SwitchCase>();
5302         SpvId label = this->nextId(nullptr);
5303         labels.push_back(label);
5304         if (!c.isDefault()) {
5305             size += 2;
5306         } else {
5307             defaultLabel = label;
5308         }
5309     }
5310 
5311     // We should have exactly one label for each case.
5312     SkASSERT(labels.size() == cases.size());
5313 
5314     // Collapse adjacent switch-cases into one; that is, reduce `case 1: case 2: case 3:` into a
5315     // single OpLabel. The Tint SPIR-V reader does not support switch-case fallthrough, but it
5316     // does support multiple switch-cases branching to the same label.
5317     SkBitSet caseIsCollapsed(cases.size());
5318     for (int index = cases.size() - 2; index >= 0; index--) {
5319         if (cases[index]->as<SwitchCase>().statement()->isEmpty()) {
5320             caseIsCollapsed.set(index);
5321             labels[index] = labels[index + 1];
5322         }
5323     }
5324 
5325     labels.push_back(end);
5326 
5327     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
5328     this->writeOpCode(SpvOpSwitch, size, out);
5329     this->writeWord(value, out);
5330     this->writeWord(defaultLabel, out);
5331     for (int i = 0; i < cases.size(); ++i) {
5332         const SwitchCase& c = cases[i]->as<SwitchCase>();
5333         if (c.isDefault()) {
5334             continue;
5335         }
5336         this->writeWord(c.value(), out);
5337         this->writeWord(labels[i], out);
5338     }
5339     for (int i = 0; i < cases.size(); ++i) {
5340         if (caseIsCollapsed.test(i)) {
5341             continue;
5342         }
5343         const SwitchCase& c = cases[i]->as<SwitchCase>();
5344         if (i == 0) {
5345             this->writeLabel(labels[i], kBranchIsOnPreviousLine, out);
5346         } else {
5347             this->writeLabel(labels[i], kBranchIsAbove, conditionalOps, out);
5348         }
5349         this->writeStatement(*c.statement(), out);
5350         if (fCurrentBlock) {
5351             this->writeInstruction(SpvOpBranch, labels[i + 1], out);
5352         }
5353     }
5354     this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
5355     fBreakTarget.pop_back();
5356 }
5357 
writeReturnStatement(const ReturnStatement & r,OutputStream & out)5358 void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) {
5359     if (r.expression()) {
5360         this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.expression(), out),
5361                                out);
5362     } else {
5363         this->writeInstruction(SpvOpReturn, out);
5364     }
5365 }
5366 
5367 // Given any function, returns the top-level symbol table (OUTSIDE of the function's scope).
get_top_level_symbol_table(const FunctionDeclaration & anyFunc)5368 static SymbolTable* get_top_level_symbol_table(const FunctionDeclaration& anyFunc) {
5369     return anyFunc.definition()->body()->as<Block>().symbolTable()->fParent;
5370 }
5371 
// Builds a synthetic `_entrypoint` function that assigns the result of `main` to sk_FragColor.
//
// Returns a default-constructed (empty) adapter after reporting an error if main()'s return
// type doesn't match sk_FragColor's type, or if its single parameter is not a float2.
// Otherwise the adapter owns both the new FunctionDeclaration and its FunctionDefinition.
SPIRVCodeGenerator::EntrypointAdapter SPIRVCodeGenerator::writeEntrypointAdapter(
        const FunctionDeclaration& main) {
    // Our goal is to synthesize a tiny helper function which looks like this:
    //     void _entrypoint() { sk_FragColor = main(); }

    // Fish a symbol table out of main().
    SymbolTable* symbolTable = get_top_level_symbol_table(main);

    // Get `sk_FragColor` as a writable reference.
    const Symbol* skFragColorSymbol = symbolTable->find("sk_FragColor");
    SkASSERT(skFragColorSymbol);
    const Variable& skFragColorVar = skFragColorSymbol->as<Variable>();
    auto skFragColorRef = std::make_unique<VariableReference>(Position(), &skFragColorVar,
                                                              VariableReference::RefKind::kWrite);

    // TODO get secondary frag color as one as well?

    // Synthesize a call to the `main()` function.
    if (!main.returnType().matches(skFragColorRef->type())) {
        fContext.fErrors->error(main.fPosition, "SPIR-V does not support returning '" +
                main.returnType().description() + "' from main()");
        return {};
    }
    ExpressionArray args;
    // A one-parameter main() must take a float2; it is called with a constant float2(0, 0).
    // NOTE(review): main() with any other non-zero parameter count is not rejected here;
    // presumably that is validated elsewhere — confirm.
    if (main.parameters().size() == 1) {
        if (!main.parameters()[0]->type().matches(*fContext.fTypes.fFloat2)) {
            fContext.fErrors->error(main.fPosition,
                    "SPIR-V does not support parameter of type '" +
                    main.parameters()[0]->type().description() + "' to main()");
            return {};
        }
        double kZero[2] = {0.0, 0.0};
        args.push_back(ConstructorCompound::MakeFromConstants(fContext, Position{},
                                                              *fContext.fTypes.fFloat2, kZero));
    }
    auto callMainFn = FunctionCall::Make(fContext, Position(), &main.returnType(),
                                         main, std::move(args));

    // Synthesize `sk_FragColor = main()` as a BinaryExpression.
    auto assignmentStmt = std::make_unique<ExpressionStatement>(std::make_unique<BinaryExpression>(
            Position(),
            std::move(skFragColorRef),
            Operator::Kind::EQ,
            std::move(callMainFn),
            &main.returnType()));

    // Function bodies are always wrapped in a Block.
    StatementArray entrypointStmts;
    entrypointStmts.push_back(std::move(assignmentStmt));
    auto entrypointBlock = Block::Make(Position(), std::move(entrypointStmts),
                                       Block::Kind::kBracedScope, /*symbols=*/nullptr);
    // Declare an entrypoint function.
    EntrypointAdapter adapter;
    adapter.entrypointDecl =
            std::make_unique<FunctionDeclaration>(fContext,
                                                  Position(),
                                                  ModifierFlag::kNone,
                                                  "_entrypoint",
                                                  /*parameters=*/TArray<Variable*>{},
                                                  /*returnType=*/fContext.fTypes.fVoid.get(),
                                                  kNotIntrinsic);
    // Define it.
    adapter.entrypointDef = FunctionDefinition::Convert(fContext,
                                                        Position(),
                                                        *adapter.entrypointDecl,
                                                        std::move(entrypointBlock));

    // Link the declaration back to its definition; both are owned by the returned adapter.
    adapter.entrypointDecl->setDefinition(adapter.entrypointDef.get());
    return adapter;
}
5442 
writeUniformBuffer(SymbolTable * topLevelSymbolTable)5443 void SPIRVCodeGenerator::writeUniformBuffer(SymbolTable* topLevelSymbolTable) {
5444     SkASSERT(!fTopLevelUniforms.empty());
5445     static constexpr char kUniformBufferName[] = "_UniformBuffer";
5446 
5447     // Convert the list of top-level uniforms into a matching struct named _UniformBuffer, and build
5448     // a lookup table of variables to UniformBuffer field indices.
5449     TArray<Field> fields;
5450     fields.reserve_exact(fTopLevelUniforms.size());
5451     for (const VarDeclaration* topLevelUniform : fTopLevelUniforms) {
5452         const Variable* var = topLevelUniform->var();
5453         fTopLevelUniformMap.set(var, (int)fields.size());
5454         ModifierFlags flags = var->modifierFlags() & ~ModifierFlag::kUniform;
5455         fields.emplace_back(var->fPosition, var->layout(), flags, var->name(), &var->type());
5456     }
5457     fUniformBuffer.fStruct = Type::MakeStructType(fContext,
5458                                                   Position(),
5459                                                   kUniformBufferName,
5460                                                   std::move(fields),
5461                                                   /*interfaceBlock=*/true);
5462 
5463     // Create a global variable to contain this struct.
5464     Layout layout;
5465     layout.fBinding = fProgram.fConfig->fSettings.fDefaultUniformBinding;
5466     layout.fSet     = fProgram.fConfig->fSettings.fDefaultUniformSet;
5467 
5468     fUniformBuffer.fInnerVariable = Variable::Make(/*pos=*/Position(),
5469                                                    /*modifiersPosition=*/Position(),
5470                                                    layout,
5471                                                    ModifierFlag::kUniform,
5472                                                    fUniformBuffer.fStruct.get(),
5473                                                    kUniformBufferName,
5474                                                    /*mangledName=*/"",
5475                                                    /*builtin=*/false,
5476                                                    Variable::Storage::kGlobal);
5477 
5478     // Create an interface block object for this global variable.
5479     fUniformBuffer.fInterfaceBlock =
5480             std::make_unique<InterfaceBlock>(Position(), fUniformBuffer.fInnerVariable.get());
5481 
5482     // Generate an interface block and hold onto its ID.
5483     fUniformBufferId = this->writeInterfaceBlock(*fUniformBuffer.fInterfaceBlock);
5484 }
5485 
addRTFlipUniform(Position pos)5486 void SPIRVCodeGenerator::addRTFlipUniform(Position pos) {
5487     SkASSERT(!fProgram.fConfig->fSettings.fForceNoRTFlip);
5488 
5489     if (fWroteRTFlip) {
5490         return;
5491     }
5492     // Flip variable hasn't been written yet. This means we don't have an existing
5493     // interface block, so we're free to just synthesize one.
5494     fWroteRTFlip = true;
5495     TArray<Field> fields;
5496     if (fProgram.fConfig->fSettings.fRTFlipOffset < 0) {
5497         fContext.fErrors->error(pos, "RTFlipOffset is negative");
5498     }
5499     fields.emplace_back(pos,
5500                         Layout(LayoutFlag::kNone,
5501                                /*location=*/-1,
5502                                fProgram.fConfig->fSettings.fRTFlipOffset,
5503                                /*binding=*/-1,
5504                                /*index=*/-1,
5505                                /*set=*/-1,
5506                                /*builtin=*/-1,
5507                                /*inputAttachmentIndex=*/-1),
5508                         ModifierFlag::kNone,
5509                         SKSL_RTFLIP_NAME,
5510                         fContext.fTypes.fFloat2.get());
5511     std::string_view name = "sksl_synthetic_uniforms";
5512     const Type* intfStruct = fSynthetics.takeOwnershipOfSymbol(Type::MakeStructType(
5513             fContext, Position(), name, std::move(fields), /*interfaceBlock=*/true));
5514     bool usePushConstants = fProgram.fConfig->fSettings.fUseVulkanPushConstantsForGaneshRTAdjust;
5515     int binding = -1, set = -1;
5516     if (!usePushConstants) {
5517         binding = fProgram.fConfig->fSettings.fRTFlipBinding;
5518         if (binding == -1) {
5519             fContext.fErrors->error(pos, "layout(binding=...) is required in SPIR-V");
5520         }
5521         set = fProgram.fConfig->fSettings.fRTFlipSet;
5522         if (set == -1) {
5523             fContext.fErrors->error(pos, "layout(set=...) is required in SPIR-V");
5524         }
5525     }
5526     Layout layout(/*flags=*/usePushConstants ? LayoutFlag::kPushConstant : LayoutFlag::kNone,
5527                   /*location=*/-1,
5528                   /*offset=*/-1,
5529                   binding,
5530                   /*index=*/-1,
5531                   set,
5532                   /*builtin=*/-1,
5533                   /*inputAttachmentIndex=*/-1);
5534     Variable* intfVar =
5535             fSynthetics.takeOwnershipOfSymbol(Variable::Make(/*pos=*/Position(),
5536                                                              /*modifiersPosition=*/Position(),
5537                                                              layout,
5538                                                              ModifierFlag::kUniform,
5539                                                              intfStruct,
5540                                                              name,
5541                                                              /*mangledName=*/"",
5542                                                              /*builtin=*/false,
5543                                                              Variable::Storage::kGlobal));
5544     {
5545         AutoAttachPoolToThread attach(fProgram.fPool.get());
5546         fProgram.fSymbols->add(fContext,
5547                                std::make_unique<FieldSymbol>(Position(), intfVar, /*field=*/0));
5548     }
5549     InterfaceBlock intf(Position(), intfVar);
5550     this->writeInterfaceBlock(intf, false);
5551 }
5552 
synthesizeTextureAndSampler(const Variable & combinedSampler)5553 std::tuple<const Variable*, const Variable*> SPIRVCodeGenerator::synthesizeTextureAndSampler(
5554         const Variable& combinedSampler) {
5555     SkASSERT(fUseTextureSamplerPairs);
5556     SkASSERT(combinedSampler.type().typeKind() == Type::TypeKind::kSampler);
5557 
5558     if (std::unique_ptr<SynthesizedTextureSamplerPair>* existingData =
5559             fSynthesizedSamplerMap.find(&combinedSampler)) {
5560         return {(*existingData)->fTexture.get(), (*existingData)->fSampler.get()};
5561     }
5562 
5563     auto data = std::make_unique<SynthesizedTextureSamplerPair>();
5564 
5565     Layout texLayout = combinedSampler.layout();
5566     texLayout.fBinding = texLayout.fTexture;
5567     data->fTextureName = std::string(combinedSampler.name()) + "_texture";
5568 
5569     auto texture = Variable::Make(/*pos=*/Position(),
5570                                   /*modifiersPosition=*/Position(),
5571                                   texLayout,
5572                                   combinedSampler.modifierFlags(),
5573                                   &combinedSampler.type().textureType(),
5574                                   data->fTextureName,
5575                                   /*mangledName=*/"",
5576                                   /*builtin=*/false,
5577                                   Variable::Storage::kGlobal);
5578 
5579     Layout samplerLayout = combinedSampler.layout();
5580     samplerLayout.fBinding = samplerLayout.fSampler;
5581     samplerLayout.fFlags &= ~LayoutFlag::kAllPixelFormats;
5582     data->fSamplerName = std::string(combinedSampler.name()) + "_sampler";
5583 
5584     auto sampler = Variable::Make(/*pos=*/Position(),
5585                                   /*modifiersPosition=*/Position(),
5586                                   samplerLayout,
5587                                   combinedSampler.modifierFlags(),
5588                                   fContext.fTypes.fSampler.get(),
5589                                   data->fSamplerName,
5590                                   /*mangledName=*/"",
5591                                   /*builtin=*/false,
5592                                   Variable::Storage::kGlobal);
5593 
5594     const Variable* t = texture.get();
5595     const Variable* s = sampler.get();
5596     data->fTexture = std::move(texture);
5597     data->fSampler = std::move(sampler);
5598     fSynthesizedSamplerMap.set(&combinedSampler, std::move(data));
5599 
5600     return {t, s};
5601 }
5602 
writeInstructions(const Program & program,OutputStream & out)5603 void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) {
5604     Analysis::FindFunctionsToSpecialize(program, &fSpecializationInfo, [](const Variable& param) {
5605         return param.type().isSampler() || param.type().isUnsizedArray();
5606     });
5607 
5608     fGLSLExtendedInstructions = this->nextId(nullptr);
5609     StringStream body;
5610 
5611     // Do an initial pass over the program elements to establish some baseline info.
5612     const FunctionDeclaration* main = nullptr;
5613     int localSizeX = 1, localSizeY = 1, localSizeZ = 1;
5614     Position combinedSamplerPos;
5615     Position separateSamplerPos;
5616     for (const ProgramElement* e : program.elements()) {
5617         switch (e->kind()) {
5618             case ProgramElement::Kind::kFunction: {
5619                 // Assign SpvIds to functions.
5620                 const FunctionDefinition& funcDef = e->as<FunctionDefinition>();
5621                 const FunctionDeclaration& funcDecl = funcDef.declaration();
5622                 if (const Analysis::Specializations* specializations =
5623                             fSpecializationInfo.fSpecializationMap.find(&funcDecl)) {
5624                     for (int i = 0; i < specializations->size(); i++) {
5625                         fFunctionMap.set({&funcDecl, i}, this->nextId(nullptr));
5626                     }
5627                 } else {
5628                     fFunctionMap.set({&funcDecl, Analysis::kUnspecialized}, this->nextId(nullptr));
5629                 }
5630                 if (funcDecl.isMain()) {
5631                     main = &funcDecl;
5632                 }
5633                 break;
5634             }
5635             case ProgramElement::Kind::kGlobalVar: {
5636                 // Look for sampler variables and determine whether or not this program uses
5637                 // combined samplers or separate samplers. The layout backend will be marked as
5638                 // WebGPU for separate samplers, or Vulkan for combined samplers.
5639                 const GlobalVarDeclaration& decl = e->as<GlobalVarDeclaration>();
5640                 const Variable& var = *decl.varDeclaration().var();
5641                 if (var.type().isSampler()) {
5642                     if (var.layout().fFlags & LayoutFlag::kVulkan) {
5643                         combinedSamplerPos = decl.position();
5644                     }
5645                     if (var.layout().fFlags & (LayoutFlag::kWebGPU | LayoutFlag::kDirect3D)) {
5646                         separateSamplerPos = decl.position();
5647                     }
5648                 }
5649                 break;
5650             }
5651             case ProgramElement::Kind::kModifiers: {
5652                 // If this is a compute program, collect the local-size values. Dimensions that are
5653                 // not present will be assigned a value of 1.
5654                 if (ProgramConfig::IsCompute(program.fConfig->fKind)) {
5655                     const ModifiersDeclaration& modifiers = e->as<ModifiersDeclaration>();
5656                     if (modifiers.layout().fLocalSizeX >= 0) {
5657                         localSizeX = modifiers.layout().fLocalSizeX;
5658                     }
5659                     if (modifiers.layout().fLocalSizeY >= 0) {
5660                         localSizeY = modifiers.layout().fLocalSizeY;
5661                     }
5662                     if (modifiers.layout().fLocalSizeZ >= 0) {
5663                         localSizeZ = modifiers.layout().fLocalSizeZ;
5664                     }
5665                 }
5666                 break;
5667             }
5668             default:
5669                 break;
5670         }
5671     }
5672 
5673     // Make sure we have a main() function.
5674     if (!main) {
5675         fContext.fErrors->error(Position(), "program does not contain a main() function");
5676         return;
5677     }
5678     // Make sure our program's sampler usage is consistent.
5679     if (combinedSamplerPos.valid() && separateSamplerPos.valid()) {
5680         fContext.fErrors->error(Position(), "programs cannot contain a mixture of sampler types");
5681         fContext.fErrors->error(combinedSamplerPos, "combined sampler found here:");
5682         fContext.fErrors->error(separateSamplerPos, "separate sampler found here:");
5683         return;
5684     }
5685     fUseTextureSamplerPairs = separateSamplerPos.valid();
5686 
5687     // Emit interface blocks.
5688     std::set<SpvId> interfaceVars;
5689     for (const ProgramElement* e : program.elements()) {
5690         if (e->is<InterfaceBlock>()) {
5691             const InterfaceBlock& intf = e->as<InterfaceBlock>();
5692             SpvId id = this->writeInterfaceBlock(intf);
5693 
5694             if ((intf.var()->modifierFlags() & (ModifierFlag::kIn | ModifierFlag::kOut)) &&
5695                 intf.var()->layout().fBuiltin == -1) {
5696                 interfaceVars.insert(id);
5697             }
5698         }
5699     }
5700     // If MustDeclareFragmentFrontFacing is set, the front-facing flag (sk_Clockwise) needs to be
5701     // explicitly declared in the output, whether or not the program explicitly references it.
5702     // However, if the program naturally declares it, we don't want to include it a second time;
5703     // we keep track of the real global variable declarations to see if sk_Clockwise is emitted.
5704     const VarDeclaration* missingClockwiseDecl = nullptr;
5705     if (fCaps.fMustDeclareFragmentFrontFacing) {
5706         if (const Symbol* clockwise = program.fSymbols->findBuiltinSymbol("sk_Clockwise")) {
5707             missingClockwiseDecl = clockwise->as<Variable>().varDeclaration();
5708         }
5709     }
5710     // Emit global variable declarations.
5711     for (const ProgramElement* e : program.elements()) {
5712         if (e->is<GlobalVarDeclaration>()) {
5713             const VarDeclaration& decl = e->as<GlobalVarDeclaration>().varDeclaration();
5714             if (!this->writeGlobalVarDeclaration(program.fConfig->fKind, decl)) {
5715                 return;
5716             }
5717             if (missingClockwiseDecl == &decl) {
5718                 // We emitted an sk_Clockwise declaration naturally, so we don't need a workaround.
5719                 missingClockwiseDecl = nullptr;
5720             }
5721         }
5722     }
5723     // All the global variables have been declared. If sk_Clockwise was not naturally included in
5724     // the output, but MustDeclareFragmentFrontFacing was set, we need to bodge it in ourselves.
5725     if (missingClockwiseDecl) {
5726         if (!this->writeGlobalVarDeclaration(program.fConfig->fKind, *missingClockwiseDecl)) {
5727             return;
5728         }
5729         missingClockwiseDecl = nullptr;
5730     }
5731     // Emit top-level uniforms into a dedicated uniform buffer.
5732     if (!fTopLevelUniforms.empty()) {
5733         this->writeUniformBuffer(get_top_level_symbol_table(*main));
5734     }
5735     // If main() returns a half4, synthesize a tiny entrypoint function which invokes the real
5736     // main() and stores the result into sk_FragColor.
5737     EntrypointAdapter adapter;
5738     if (main->returnType().matches(*fContext.fTypes.fHalf4)) {
5739         adapter = this->writeEntrypointAdapter(*main);
5740         if (adapter.entrypointDecl) {
5741             fFunctionMap.set({adapter.entrypointDecl.get(), Analysis::kUnspecialized},
5742                              this->nextId(nullptr));
5743             this->writeFunction(*adapter.entrypointDef, body);
5744             main = adapter.entrypointDecl.get();
5745         }
5746     }
5747     // Emit all the functions.
5748     for (const ProgramElement* e : program.elements()) {
5749         if (e->is<FunctionDefinition>()) {
5750             this->writeFunction(e->as<FunctionDefinition>(), body);
5751         }
5752     }
5753     // Add global in/out variables to the list of interface variables.
5754     for (const auto& [var, spvId] : fVariableMap) {
5755         if (var->storage() == Variable::Storage::kGlobal &&
5756 #ifdef SKSL_EXT
5757             !(var->layout().fFlags & SkSL::LayoutFlag::kConstantId) &&
5758             ((var->modifierFlags() == ModifierFlag::kNone) ||
5759                 (var->modifierFlags() & (
5760                     ModifierFlag::kIn |
5761                     ModifierFlag::kOut |
5762                     ModifierFlag::kUniform |
5763                     ModifierFlag::kBuffer)))) {
5764 #else
5765             (var->modifierFlags() & (ModifierFlag::kIn | ModifierFlag::kOut))) {
5766 #endif
5767             interfaceVars.insert(spvId);
5768         }
5769     }
5770     this->writeCapabilities(out);
5771 #ifdef SKSL_EXT
5772     this->writeExtensions(out);
5773 #endif
5774     this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
5775     this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
5776     this->writeOpCode(SpvOpEntryPoint,
5777                       (SpvId)(3 + (main->name().length() + 4) / 4) + (int32_t)interfaceVars.size(),
5778                       out);
5779     if (ProgramConfig::IsVertex(program.fConfig->fKind)) {
5780         this->writeWord(SpvExecutionModelVertex, out);
5781     } else if (ProgramConfig::IsFragment(program.fConfig->fKind)) {
5782         this->writeWord(SpvExecutionModelFragment, out);
5783     } else if (ProgramConfig::IsCompute(program.fConfig->fKind)) {
5784         this->writeWord(SpvExecutionModelGLCompute, out);
5785     } else {
5786         SK_ABORT("cannot write this kind of program to SPIR-V\n");
5787     }
5788     const Analysis::SpecializedFunctionKey mainKey{main, Analysis::kUnspecialized};
5789     SpvId entryPoint = fFunctionMap[mainKey];
5790     this->writeWord(entryPoint, out);
5791     this->writeString(main->name(), out);
5792     for (int var : interfaceVars) {
5793         this->writeWord(var, out);
5794     }
5795     if (ProgramConfig::IsFragment(program.fConfig->fKind)) {
5796         this->writeInstruction(SpvOpExecutionMode,
5797                                fFunctionMap[mainKey],
5798                                SpvExecutionModeOriginUpperLeft,
5799                                out);
5800     } else if (ProgramConfig::IsCompute(program.fConfig->fKind)) {
5801         this->writeInstruction(SpvOpExecutionMode,
5802                                fFunctionMap[mainKey],
5803                                SpvExecutionModeLocalSize,
5804                                localSizeX, localSizeY, localSizeZ,
5805                                out);
5806     }
5807     for (const ProgramElement* e : program.elements()) {
5808         if (e->is<Extension>()) {
5809             this->writeInstruction(SpvOpSourceExtension, e->as<Extension>().name(), out);
5810         }
5811     }
5812 
5813     write_stringstream(fNameBuffer, out);
5814     write_stringstream(fDecorationBuffer, out);
5815     write_stringstream(fConstantBuffer, out);
5816     write_stringstream(body, out);
5817 }
5818 
5819 bool SPIRVCodeGenerator::generateCode() {
5820     SkASSERT(!fContext.fErrors->errorCount());
5821     this->writeWord(SpvMagicNumber, *fOut);
5822     this->writeWord(SpvVersion, *fOut);
5823     this->writeWord(SKSL_MAGIC, *fOut);
5824     StringStream buffer;
5825     this->writeInstructions(fProgram, buffer);
5826     this->writeWord(fIdCount, *fOut);
5827     this->writeWord(0, *fOut); // reserved, always zero
5828     write_stringstream(buffer, *fOut);
5829     return fContext.fErrors->errorCount() == 0;
5830 }
5831 
5832 bool ToSPIRV(Program& program,
5833              const ShaderCaps* caps,
5834              OutputStream& out,
5835              ValidateSPIRVProc validateSPIRV) {
5836     TRACE_EVENT0("skia.shaders", "SkSL::ToSPIRV");
5837     SkASSERT(caps != nullptr);
5838 
5839     program.fContext->fErrors->setSource(*program.fSource);
5840     bool result;
5841     if (validateSPIRV) {
5842         StringStream buffer;
5843         SPIRVCodeGenerator cg(program.fContext.get(), caps, &program, &buffer);
5844         result = cg.generateCode();
5845 
5846         if (result && program.fConfig->fSettings.fValidateSPIRV) {
5847             std::string_view binary = buffer.str();
5848             result = validateSPIRV(*program.fContext->fErrors, binary);
5849             out.write(binary.data(), binary.size());
5850         }
5851     } else {
5852         SPIRVCodeGenerator cg(program.fContext.get(), caps, &program, &out);
5853         result = cg.generateCode();
5854     }
5855     program.fContext->fErrors->setSource(std::string_view());
5856 
5857     return result;
5858 }
5859 
5860 bool ToSPIRV(Program& program,
5861              const ShaderCaps* caps,
5862              std::string* out,
5863              ValidateSPIRVProc validateSPIRV) {
5864     StringStream buffer;
5865     if (!ToSPIRV(program, caps, buffer, validateSPIRV)) {
5866         return false;
5867     }
5868     *out = buffer.str();
5869     return true;
5870 }
5871 
5872 }  // namespace SkSL
5873