• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/codegen/SkSLSPIRVCodeGenerator.h"
9 
10 #include "include/core/SkSpan.h"
11 #include "include/core/SkTypes.h"
12 #include "include/private/base/SkTArray.h"
13 #include "include/private/base/SkTo.h"
14 #include "src/base/SkEnumBitMask.h"
15 #include "src/core/SkChecksum.h"
16 #include "src/core/SkTHash.h"
17 #include "src/core/SkTraceEvent.h"
18 #include "src/sksl/GLSL.std.450.h"
19 #include "src/sksl/SkSLAnalysis.h"
20 #include "src/sksl/SkSLBuiltinTypes.h"
21 #include "src/sksl/SkSLCompiler.h"
22 #include "src/sksl/SkSLConstantFolder.h"
23 #include "src/sksl/SkSLContext.h"
24 #include "src/sksl/SkSLDefines.h"
25 #include "src/sksl/SkSLErrorReporter.h"
26 #include "src/sksl/SkSLIntrinsicList.h"
27 #include "src/sksl/SkSLMemoryLayout.h"
28 #include "src/sksl/SkSLOperator.h"
29 #include "src/sksl/SkSLOutputStream.h"
30 #include "src/sksl/SkSLPool.h"
31 #include "src/sksl/SkSLPosition.h"
32 #include "src/sksl/SkSLProgramSettings.h"
33 #include "src/sksl/SkSLStringStream.h"
34 #include "src/sksl/SkSLUtil.h"
35 #include "src/sksl/codegen/SkSLCodeGenerator.h"
36 #include "src/sksl/ir/SkSLBinaryExpression.h"
37 #include "src/sksl/ir/SkSLBlock.h"
38 #include "src/sksl/ir/SkSLConstructor.h"
39 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
40 #include "src/sksl/ir/SkSLConstructorCompound.h"
41 #include "src/sksl/ir/SkSLConstructorCompoundCast.h"
42 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
43 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
44 #include "src/sksl/ir/SkSLConstructorScalarCast.h"
45 #include "src/sksl/ir/SkSLConstructorSplat.h"
46 #include "src/sksl/ir/SkSLDoStatement.h"
47 #include "src/sksl/ir/SkSLExpression.h"
48 #include "src/sksl/ir/SkSLExpressionStatement.h"
49 #include "src/sksl/ir/SkSLExtension.h"
50 #include "src/sksl/ir/SkSLFieldAccess.h"
51 #include "src/sksl/ir/SkSLFieldSymbol.h"
52 #include "src/sksl/ir/SkSLForStatement.h"
53 #include "src/sksl/ir/SkSLFunctionCall.h"
54 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
55 #include "src/sksl/ir/SkSLFunctionDefinition.h"
56 #include "src/sksl/ir/SkSLIRNode.h"
57 #include "src/sksl/ir/SkSLIfStatement.h"
58 #include "src/sksl/ir/SkSLIndexExpression.h"
59 #include "src/sksl/ir/SkSLInterfaceBlock.h"
60 #include "src/sksl/ir/SkSLLayout.h"
61 #include "src/sksl/ir/SkSLLiteral.h"
62 #include "src/sksl/ir/SkSLModifierFlags.h"
63 #include "src/sksl/ir/SkSLModifiersDeclaration.h"
64 #include "src/sksl/ir/SkSLPoison.h"
65 #include "src/sksl/ir/SkSLPostfixExpression.h"
66 #include "src/sksl/ir/SkSLPrefixExpression.h"
67 #include "src/sksl/ir/SkSLProgram.h"
68 #include "src/sksl/ir/SkSLProgramElement.h"
69 #include "src/sksl/ir/SkSLReturnStatement.h"
70 #include "src/sksl/ir/SkSLSetting.h"
71 #include "src/sksl/ir/SkSLStatement.h"
72 #include "src/sksl/ir/SkSLSwitchCase.h"
73 #include "src/sksl/ir/SkSLSwitchStatement.h"
74 #include "src/sksl/ir/SkSLSwizzle.h"
75 #include "src/sksl/ir/SkSLSymbol.h"
76 #include "src/sksl/ir/SkSLSymbolTable.h"
77 #include "src/sksl/ir/SkSLTernaryExpression.h"
78 #include "src/sksl/ir/SkSLType.h"
79 #include "src/sksl/ir/SkSLVarDeclarations.h"
80 #include "src/sksl/ir/SkSLVariable.h"
81 #include "src/sksl/ir/SkSLVariableReference.h"
82 #include "src/sksl/spirv.h"
83 #include "src/sksl/transform/SkSLTransform.h"
84 #include "src/utils/SkBitSet.h"
85 
86 #include <cstdint>
87 #include <cstring>
88 #include <memory>
89 #include <set>
90 #include <string>
91 #include <string_view>
92 #include <tuple>
93 #include <utility>
94 #include <vector>
95 
96 #ifdef SK_ENABLE_SPIRV_VALIDATION
97 #include "spirv-tools/libspirv.hpp"
98 #endif
99 
100 using namespace skia_private;
101 
102 #define kLast_Capability SpvCapabilityMultiViewport
103 
104 constexpr int DEVICE_FRAGCOORDS_BUILTIN = -1000;
105 constexpr int DEVICE_CLOCKWISE_BUILTIN  = -1001;
106 static constexpr SkSL::Layout kDefaultTypeLayout;
107 
108 namespace SkSL {
109 
110 enum class ProgramKind : int8_t;
111 
112 class SPIRVCodeGenerator : public CodeGenerator {
113 public:
114     // We reserve an impossible SpvId as a sentinel. (NA meaning none, n/a, etc.)
115     static constexpr SpvId NA = (SpvId)-1;
116 
117     class LValue {
118     public:
~LValue()119         virtual ~LValue() {}
120 
121         // returns a pointer to the lvalue, if possible. If the lvalue cannot be directly referenced
122         // by a pointer (e.g. vector swizzles), returns NA.
getPointer()123         virtual SpvId getPointer() { return NA; }
124 
125         // Returns true if a valid pointer returned by getPointer represents a memory object
126         // (see https://github.com/KhronosGroup/SPIRV-Tools/issues/2892). Has no meaning if
127         // getPointer() returns NA.
isMemoryObjectPointer() const128         virtual bool isMemoryObjectPointer() const { return true; }
129 
130         // Applies a swizzle to the components of the LValue, if possible. This is used to create
131         // LValues that are swizzes-of-swizzles. Non-swizzle LValues can just return false.
applySwizzle(const ComponentArray & components,const Type & newType)132         virtual bool applySwizzle(const ComponentArray& components, const Type& newType) {
133             return false;
134         }
135 
136         // Returns the storage class of the lvalue.
137         virtual SpvStorageClass storageClass() const = 0;
138 
139         virtual SpvId load(OutputStream& out) = 0;
140 
141         virtual void store(SpvId value, OutputStream& out) = 0;
142     };
143 
SPIRVCodeGenerator(const Context * context,const ShaderCaps * caps,const Program * program,OutputStream * out)144     SPIRVCodeGenerator(const Context* context,
145                        const ShaderCaps* caps,
146                        const Program* program,
147                        OutputStream* out)
148             : INHERITED(context, caps, program, out) {}
149 
150     bool generateCode() override;
151 
152 private:
153     enum IntrinsicOpcodeKind {
154         kGLSL_STD_450_IntrinsicOpcodeKind,
155         kSPIRV_IntrinsicOpcodeKind,
156         kSpecial_IntrinsicOpcodeKind,
157         kInvalid_IntrinsicOpcodeKind,
158     };
159 
160     enum SpecialIntrinsic {
161         kAtan_SpecialIntrinsic,
162         kClamp_SpecialIntrinsic,
163         kMatrixCompMult_SpecialIntrinsic,
164         kMax_SpecialIntrinsic,
165         kMin_SpecialIntrinsic,
166         kMix_SpecialIntrinsic,
167         kMod_SpecialIntrinsic,
168         kDFdy_SpecialIntrinsic,
169         kSaturate_SpecialIntrinsic,
170         kSampledImage_SpecialIntrinsic,
171         kSmoothStep_SpecialIntrinsic,
172         kStep_SpecialIntrinsic,
173         kSubpassLoad_SpecialIntrinsic,
174         kTexture_SpecialIntrinsic,
175         kTextureGrad_SpecialIntrinsic,
176         kTextureLod_SpecialIntrinsic,
177         kTextureRead_SpecialIntrinsic,
178         kTextureWrite_SpecialIntrinsic,
179         kTextureWidth_SpecialIntrinsic,
180         kTextureHeight_SpecialIntrinsic,
181         kAtomicAdd_SpecialIntrinsic,
182         kAtomicLoad_SpecialIntrinsic,
183         kAtomicStore_SpecialIntrinsic,
184         kStorageBarrier_SpecialIntrinsic,
185         kWorkgroupBarrier_SpecialIntrinsic,
186     };
187 
188     enum class Precision {
189         kDefault,
190         kRelaxed,
191     };
192 
193     struct TempVar {
194         SpvId spvId;
195         const Type* type;
196         std::unique_ptr<SPIRVCodeGenerator::LValue> lvalue;
197     };
198 
199     /**
200      * Pass in the type to automatically add a RelaxedPrecision decoration for the id when
201      * appropriate, or null to never add one.
202      */
203     SpvId nextId(const Type* type);
204 
205     SpvId nextId(Precision precision);
206 
207     SpvId getType(const Type& type);
208 
209     SpvId getType(const Type& type, const Layout& typeLayout, const MemoryLayout& memoryLayout);
210 
211     SpvId getFunctionType(const FunctionDeclaration& function);
212 
213     SpvId getFunctionParameterType(const Type& parameterType, const Layout& parameterLayout);
214 
215     SpvId getPointerType(const Type& type, SpvStorageClass_ storageClass);
216 
217     SpvId getPointerType(const Type& type,
218                          const Layout& typeLayout,
219                          const MemoryLayout& memoryLayout,
220                          SpvStorageClass_ storageClass);
221 
222     skia_private::TArray<SpvId> getAccessChain(const Expression& expr, OutputStream& out);
223 
224     void writeLayout(const Layout& layout, SpvId target, Position pos);
225 
226     void writeFieldLayout(const Layout& layout, SpvId target, int member);
227 
228     SpvId writeStruct(const Type& type, const MemoryLayout& memoryLayout);
229 
230     void writeProgramElement(const ProgramElement& pe, OutputStream& out);
231 
232     SpvId writeInterfaceBlock(const InterfaceBlock& intf, bool appendRTFlip = true);
233 
234     SpvId writeFunctionStart(const FunctionDeclaration& f, OutputStream& out);
235 
236     SpvId writeFunctionDeclaration(const FunctionDeclaration& f, OutputStream& out);
237 
238     SpvId writeFunction(const FunctionDefinition& f, OutputStream& out);
239 
240     bool writeGlobalVarDeclaration(ProgramKind kind, const VarDeclaration& v);
241 
242     SpvId writeGlobalVar(ProgramKind kind, SpvStorageClass_, const Variable& v);
243 
244     void writeVarDeclaration(const VarDeclaration& var, OutputStream& out);
245 
246     SpvId writeVariableReference(const VariableReference& ref, OutputStream& out);
247 
248     int findUniformFieldIndex(const Variable& var) const;
249 
250     std::unique_ptr<LValue> getLValue(const Expression& value, OutputStream& out);
251 
252     SpvId writeExpression(const Expression& expr, OutputStream& out);
253 
254     SpvId writeIntrinsicCall(const FunctionCall& c, OutputStream& out);
255 
256     SpvId writeFunctionCallArgument(const FunctionCall& call,
257                                     int argIndex,
258                                     std::vector<TempVar>* tempVars,
259                                     OutputStream& out,
260                                     SpvId* outSynthesizedSamplerId = nullptr);
261 
262     void copyBackTempVars(const std::vector<TempVar>& tempVars, OutputStream& out);
263 
264     SpvId writeFunctionCall(const FunctionCall& c, OutputStream& out);
265 
266 
267     void writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
268                                       SpvId signedInst, SpvId unsignedInst,
269                                       const skia_private::TArray<SpvId>& args, OutputStream& out);
270 
271     /**
272      * Promotes an expression to a vector. If the expression is already a vector with vectorSize
273      * columns, returns it unmodified. If the expression is a scalar, either promotes it to a
274      * vector (if vectorSize > 1) or returns it unmodified (if vectorSize == 1). Asserts if the
275      * expression is already a vector and it does not have vectorSize columns.
276      */
277     SpvId vectorize(const Expression& expr, int vectorSize, OutputStream& out);
278 
279     /**
280      * Given a list of potentially mixed scalars and vectors, promotes the scalars to match the
281      * size of the vectors and returns the ids of the written expressions. e.g. given (float, vec2),
282      * returns (vec2(float), vec2). It is an error to use mismatched vector sizes, e.g. (float,
283      * vec2, vec3).
284      */
285     skia_private::TArray<SpvId> vectorize(const ExpressionArray& args, OutputStream& out);
286 
287     /**
288      * Given a SpvId of a scalar, splats it across the passed-in type (scalar, vector or matrix) and
289      * returns the SpvId of the new value.
290      */
291     SpvId splat(const Type& type, SpvId id, OutputStream& out);
292 
293     SpvId writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind, OutputStream& out);
294     SpvId writeAtomicIntrinsic(const FunctionCall& c,
295                                SpecialIntrinsic kind,
296                                SpvId resultId,
297                                OutputStream& out);
298 
299     SpvId castScalarToFloat(SpvId inputId, const Type& inputType, const Type& outputType,
300                             OutputStream& out);
301 
302     SpvId castScalarToSignedInt(SpvId inputId, const Type& inputType, const Type& outputType,
303                                 OutputStream& out);
304 
305     SpvId castScalarToUnsignedInt(SpvId inputId, const Type& inputType, const Type& outputType,
306                                   OutputStream& out);
307 
308     SpvId castScalarToBoolean(SpvId inputId, const Type& inputType, const Type& outputType,
309                               OutputStream& out);
310 
311     SpvId castScalarToType(SpvId inputExprId, const Type& inputType, const Type& outputType,
312                            OutputStream& out);
313 
314     /**
315      * Writes a potentially-different-sized copy of a matrix. Entries which do not exist in the
316      * source matrix are filled with zero; entries which do not exist in the destination matrix are
317      * ignored.
318      */
319     SpvId writeMatrixCopy(SpvId src, const Type& srcType, const Type& dstType, OutputStream& out);
320 
321     void addColumnEntry(const Type& columnType,
322                         skia_private::TArray<SpvId>* currentColumn,
323                         skia_private::TArray<SpvId>* columnIds,
324                         int rows, SpvId entry, OutputStream& out);
325 
326     SpvId writeConstructorCompound(const ConstructorCompound& c, OutputStream& out);
327 
328     SpvId writeMatrixConstructor(const ConstructorCompound& c, OutputStream& out);
329 
330     SpvId writeVectorConstructor(const ConstructorCompound& c, OutputStream& out);
331 
332     SpvId writeCompositeConstructor(const AnyConstructor& c, OutputStream& out);
333 
334     SpvId writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c, OutputStream& out);
335 
336     SpvId writeConstructorMatrixResize(const ConstructorMatrixResize& c, OutputStream& out);
337 
338     SpvId writeConstructorScalarCast(const ConstructorScalarCast& c, OutputStream& out);
339 
340     SpvId writeConstructorSplat(const ConstructorSplat& c, OutputStream& out);
341 
342     SpvId writeConstructorCompoundCast(const ConstructorCompoundCast& c, OutputStream& out);
343 
344     SpvId writeFieldAccess(const FieldAccess& f, OutputStream& out);
345 
346     SpvId writeSwizzle(const Expression& baseExpr,
347                        const ComponentArray& components,
348                        OutputStream& out);
349 
350     SpvId writeSwizzle(const Swizzle& swizzle, OutputStream& out);
351 
352     /**
353      * Folds the potentially-vector result of a logical operation down to a single bool. If
354      * operandType is a vector type, assumes that the intermediate result in id is a bvec of the
355      * same dimensions, and applys all() to it to fold it down to a single bool value. Otherwise,
356      * returns the original id value.
357      */
358     SpvId foldToBool(SpvId id, const Type& operandType, SpvOp op, OutputStream& out);
359 
360     SpvId writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs, SpvOp_ floatOperator,
361                                 SpvOp_ intOperator, SpvOp_ vectorMergeOperator,
362                                 SpvOp_ mergeOperator, OutputStream& out);
363 
364     SpvId writeStructComparison(const Type& structType, SpvId lhs, Operator op, SpvId rhs,
365                                 OutputStream& out);
366 
367     SpvId writeArrayComparison(const Type& structType, SpvId lhs, Operator op, SpvId rhs,
368                                OutputStream& out);
369 
370     // Used by writeStructComparison and writeArrayComparison to logically combine field-by-field
371     // comparisons into an overall comparison result.
372     // - `a.x == b.x` merged with `a.y == b.y` generates `(a.x == b.x) && (a.y == b.y)`
373     // - `a.x != b.x` merged with `a.y != b.y` generates `(a.x != b.x) || (a.y != b.y)`
374     SpvId mergeComparisons(SpvId comparison, SpvId allComparisons, Operator op, OutputStream& out);
375 
376     // When the RewriteMatrixVectorMultiply caps bit is set, we manually decompose the M*V
377     // multiplication into a sum of vector-scalar products.
378     SpvId writeDecomposedMatrixVectorMultiply(const Type& leftType,
379                                               SpvId lhs,
380                                               const Type& rightType,
381                                               SpvId rhs,
382                                               const Type& resultType,
383                                               OutputStream& out);
384 
385     SpvId writeComponentwiseMatrixUnary(const Type& operandType,
386                                         SpvId operand,
387                                         SpvOp_ op,
388                                         OutputStream& out);
389 
390     SpvId writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs, SpvId rhs,
391                                          SpvOp_ op, OutputStream& out);
392 
393     SpvId writeBinaryOperationComponentwiseIfMatrix(const Type& resultType, const Type& operandType,
394                                                     SpvId lhs, SpvId rhs,
395                                                     SpvOp_ ifFloat, SpvOp_ ifInt,
396                                                     SpvOp_ ifUInt, SpvOp_ ifBool,
397                                                     OutputStream& out);
398 
399     SpvId writeBinaryOperation(const Type& resultType, const Type& operandType, SpvId lhs,
400                                SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt, SpvOp_ ifUInt,
401                                SpvOp_ ifBool, OutputStream& out);
402 
403     SpvId writeBinaryOperation(const Type& resultType, const Type& operandType, SpvId lhs,
404                                SpvId rhs, bool writeComponentwiseIfMatrix, SpvOp_ ifFloat,
405                                SpvOp_ ifInt, SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out);
406 
407     SpvId writeReciprocal(const Type& type, SpvId value, OutputStream& out);
408 
409     SpvId writeBinaryExpression(const Type& leftType, SpvId lhs, Operator op,
410                                 const Type& rightType, SpvId rhs, const Type& resultType,
411                                 OutputStream& out);
412 
413     SpvId writeBinaryExpression(const BinaryExpression& b, OutputStream& out);
414 
415     SpvId writeTernaryExpression(const TernaryExpression& t, OutputStream& out);
416 
417     SpvId writeIndexExpression(const IndexExpression& expr, OutputStream& out);
418 
419     SpvId writeLogicalAnd(const Expression& left, const Expression& right, OutputStream& out);
420 
421     SpvId writeLogicalOr(const Expression& left, const Expression& right, OutputStream& out);
422 
423     SpvId writePrefixExpression(const PrefixExpression& p, OutputStream& out);
424 
425     SpvId writePostfixExpression(const PostfixExpression& p, OutputStream& out);
426 
427     SpvId writeLiteral(const Literal& f);
428 
429     SpvId writeLiteral(double value, const Type& type);
430 
431     void writeStatement(const Statement& s, OutputStream& out);
432 
433     void writeBlock(const Block& b, OutputStream& out);
434 
435     void writeIfStatement(const IfStatement& stmt, OutputStream& out);
436 
437     void writeForStatement(const ForStatement& f, OutputStream& out);
438 
439     void writeDoStatement(const DoStatement& d, OutputStream& out);
440 
441     void writeSwitchStatement(const SwitchStatement& s, OutputStream& out);
442 
443     void writeReturnStatement(const ReturnStatement& r, OutputStream& out);
444 
445     void writeCapabilities(OutputStream& out);
446 
447     void writeInstructions(const Program& program, OutputStream& out);
448 
449     void writeOpCode(SpvOp_ opCode, int length, OutputStream& out);
450 
451     void writeWord(int32_t word, OutputStream& out);
452 
453     void writeString(std::string_view s, OutputStream& out);
454 
455     void writeInstruction(SpvOp_ opCode, OutputStream& out);
456 
457     void writeInstruction(SpvOp_ opCode, std::string_view string, OutputStream& out);
458 
459     void writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out);
460 
461     void writeInstruction(SpvOp_ opCode, int32_t word1, std::string_view string,
462                           OutputStream& out);
463 
464     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, std::string_view string,
465                           OutputStream& out);
466 
467     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, OutputStream& out);
468 
469     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3,
470                           OutputStream& out);
471 
472     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
473                           OutputStream& out);
474 
475     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
476                           int32_t word5, OutputStream& out);
477 
478     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
479                           int32_t word5, int32_t word6, OutputStream& out);
480 
481     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
482                           int32_t word5, int32_t word6, int32_t word7, OutputStream& out);
483 
484     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
485                           int32_t word5, int32_t word6, int32_t word7, int32_t word8,
486                           OutputStream& out);
487 
488     // This form of writeInstruction can deduplicate redundant ops.
489     struct Word;
490     // 8 Words is enough for nearly all instructions (except variable-length instructions like
491     // OpAccessChain or OpConstantComposite).
492     using Words = skia_private::STArray<8, Word, true>;
493     SpvId writeInstruction(
494             SpvOp_ opCode, const skia_private::TArray<Word, true>& words, OutputStream& out);
495 
496     struct Instruction {
497         SpvId fOp;
498         int32_t fResultKind;
499         skia_private::STArray<8, int32_t>  fWords;
500 
501         bool operator==(const Instruction& that) const;
502         struct Hash;
503     };
504 
505     static Instruction BuildInstructionKey(SpvOp_ opCode,
506                                            const skia_private::TArray<Word, true>& words);
507 
508     // The writeOpXxxxx calls will simplify and deduplicate ops where possible.
509     SpvId writeOpConstantTrue(const Type& type);
510     SpvId writeOpConstantFalse(const Type& type);
511     SpvId writeOpConstant(const Type& type, int32_t valueBits);
512     SpvId writeOpConstantComposite(const Type& type, const skia_private::TArray<SpvId>& values);
513     SpvId writeOpCompositeConstruct(const Type& type, const skia_private::TArray<SpvId>&,
514                                     OutputStream& out);
515     SpvId writeOpCompositeExtract(const Type& type, SpvId base, int component, OutputStream& out);
516     SpvId writeOpCompositeExtract(const Type& type, SpvId base, int componentA, int componentB,
517                                   OutputStream& out);
518     SpvId writeOpLoad(SpvId type, Precision precision, SpvId pointer, OutputStream& out);
519     void writeOpStore(SpvStorageClass_ storageClass, SpvId pointer, SpvId value, OutputStream& out);
520 
521     // Converts the provided SpvId(s) into an array of scalar OpConstants, if it can be done.
522     bool toConstants(SpvId value, skia_private::TArray<SpvId>* constants);
523     bool toConstants(SkSpan<const SpvId> values, skia_private::TArray<SpvId>* constants);
524 
525     // Extracts the requested component SpvId from a composite instruction, if it can be done.
526     Instruction* resultTypeForInstruction(const Instruction& instr);
527     int numComponentsForVecInstruction(const Instruction& instr);
528     SpvId toComponent(SpvId id, int component);
529 
530     struct ConditionalOpCounts {
531         int numReachableOps;
532         int numStoreOps;
533     };
534     ConditionalOpCounts getConditionalOpCounts();
535     void pruneConditionalOps(ConditionalOpCounts ops);
536 
537     enum StraightLineLabelType {
538         // Use "BranchlessBlock" for blocks which are never explicitly branched-to at all. This
539         // happens at the start of a function, or when we find unreachable code.
540         kBranchlessBlock,
541 
542         // Use "BranchIsOnPreviousLine" when writing a label that comes immediately after its
543         // associated branch. Example usage:
544         // - SPIR-V does not implicitly fall through from one block to the next, so you may need to
545         //   use an OpBranch to explicitly jump to the next block, even when they are adjacent in
546         //   the code.
547         // - The block immediately following an OpBranchConditional or OpSwitch.
548         kBranchIsOnPreviousLine,
549     };
550 
551     enum BranchingLabelType {
552         // Use "BranchIsAbove" for labels which are referenced by OpBranch or OpBranchConditional
553         // ops that are above the label in the code--i.e., the branch skips forward in the code.
554         kBranchIsAbove,
555 
556         // Use "BranchIsBelow" for labels which are referenced by OpBranch or OpBranchConditional
557         // ops below the label in the code--i.e., the branch jumps backward in the code.
558         kBranchIsBelow,
559 
560         // Use "BranchesOnBothSides" for labels which have branches coming from both directions.
561         kBranchesOnBothSides,
562     };
563     void writeLabel(SpvId label, StraightLineLabelType type, OutputStream& out);
564     void writeLabel(SpvId label, BranchingLabelType type, ConditionalOpCounts ops,
565                     OutputStream& out);
566 
567     MemoryLayout memoryLayoutForStorageClass(SpvStorageClass_ storageClass);
568     MemoryLayout memoryLayoutForVariable(const Variable&) const;
569 
570     struct EntrypointAdapter {
571         std::unique_ptr<FunctionDefinition> entrypointDef;
572         std::unique_ptr<FunctionDeclaration> entrypointDecl;
573     };
574 
575     EntrypointAdapter writeEntrypointAdapter(const FunctionDeclaration& main);
576 
577     struct UniformBuffer {
578         std::unique_ptr<InterfaceBlock> fInterfaceBlock;
579         std::unique_ptr<Variable> fInnerVariable;
580         std::unique_ptr<Type> fStruct;
581     };
582 
583     void writeUniformBuffer(SymbolTable* topLevelSymbolTable);
584 
585     void addRTFlipUniform(Position pos);
586 
587     std::unique_ptr<Expression> identifier(std::string_view name);
588 
589     std::tuple<const Variable*, const Variable*> synthesizeTextureAndSampler(
590             const Variable& combinedSampler);
591 
592     const MemoryLayout fDefaultMemoryLayout{MemoryLayout::Standard::k140};
593 
594     uint64_t fCapabilities = 0;
595     SpvId fIdCount = 1;
596     SpvId fGLSLExtendedInstructions;
597     struct Intrinsic {
598         IntrinsicOpcodeKind opKind;
599         int32_t floatOp;
600         int32_t signedOp;
601         int32_t unsignedOp;
602         int32_t boolOp;
603     };
604     Intrinsic getIntrinsic(IntrinsicKind) const;
605     skia_private::THashMap<const FunctionDeclaration*, SpvId> fFunctionMap;
606     skia_private::THashMap<const Variable*, SpvId> fVariableMap;
607     skia_private::THashMap<const Type*, SpvId> fStructMap;
608     StringStream fGlobalInitializersBuffer;
609     StringStream fConstantBuffer;
610     StringStream fVariableBuffer;
611     StringStream fNameBuffer;
612     StringStream fDecorationBuffer;
613 
614     // Mapping from combined sampler declarations to synthesized texture/sampler variables.
615     // This is used when the sampler is declared as `layout(webgpu)` or `layout(direct3d)`.
616     bool fUseTextureSamplerPairs = false;
617     struct SynthesizedTextureSamplerPair {
618         // The names of the synthesized variables. The Variable objects themselves store string
619         // views referencing these strings. It is important for the std::string instances to have a
620         // fixed memory location after the string views get created, which is why
621         // `fSynthesizedSamplerMap` stores unique_ptr instead of values.
622         std::string fTextureName;
623         std::string fSamplerName;
624         std::unique_ptr<Variable> fTexture;
625         std::unique_ptr<Variable> fSampler;
626     };
627     skia_private::THashMap<const Variable*, std::unique_ptr<SynthesizedTextureSamplerPair>>
628             fSynthesizedSamplerMap;
629 
630     // These caches map SpvIds to Instructions, and vice-versa. This enables us to deduplicate code
631     // (by detecting an Instruction we've already issued and reusing the SpvId), and to introspect
632     // and simplify code we've already emitted  (by taking a SpvId from an Instruction and following
633     // it back to its source).
634 
635     // A map of instruction -> SpvId:
636     skia_private::THashMap<Instruction, SpvId, Instruction::Hash> fOpCache;
637     // A map of SpvId -> instruction:
638     skia_private::THashMap<SpvId, Instruction> fSpvIdCache;
639     // A map of SpvId -> value SpvId:
640     skia_private::THashMap<SpvId, SpvId> fStoreCache;
641 
642     // "Reachable" ops are instructions which can safely be accessed from the current block.
643     // For instance, if our SPIR-V contains `%3 = OpFAdd %1 %2`, we would be able to access and
644     // reuse that computation on following lines. However, if that Add operation occurred inside an
645     // `if` block, then its SpvId becomes inaccessible once we complete the if statement (since
646     // depending on the if condition, we may or may not have actually done that computation). The
647     // same logic applies to other control-flow blocks as well. Once an instruction becomes
648     // unreachable, we remove it from both op-caches.
649     skia_private::TArray<SpvId> fReachableOps;
650 
651     // The "store-ops" list contains a running list of all the pointers in the store cache. If a
652     // store occurs inside of a conditional block, once that block exits, we no longer know what is
653     // stored in that particular SpvId. At that point, we must remove any associated entry from the
654     // store cache.
655     skia_private::TArray<SpvId> fStoreOps;
656 
657     // label of the current block, or 0 if we are not in a block
658     SpvId fCurrentBlock = 0;
659     skia_private::TArray<SpvId> fBreakTarget;
660     skia_private::TArray<SpvId> fContinueTarget;
661     bool fWroteRTFlip = false;
662     // holds variables synthesized during output, for lifetime purposes
663     SymbolTable fSynthetics{/*builtin=*/true};
664     // Holds a list of uniforms that were declared as globals at the top-level instead of in an
665     // interface block.
666     UniformBuffer fUniformBuffer;
667     std::vector<const VarDeclaration*> fTopLevelUniforms;
668     skia_private::THashMap<const Variable*, int>
669             fTopLevelUniformMap;  // <var, UniformBuffer field index>
670     SpvId fUniformBufferId = NA;
671 
672     friend class PointerLValue;
673     friend class SwizzleLValue;
674 
675     using INHERITED = CodeGenerator;
676 };
677 
678 // Equality and hash operators for Instructions.
operator ==(const SPIRVCodeGenerator::Instruction & that) const679 bool SPIRVCodeGenerator::Instruction::operator==(const SPIRVCodeGenerator::Instruction& that) const {
680     return fOp         == that.fOp &&
681            fResultKind == that.fResultKind &&
682            fWords      == that.fWords;
683 }
684 
685 struct SPIRVCodeGenerator::Instruction::Hash {
operator ()SkSL::SPIRVCodeGenerator::Instruction::Hash686     uint32_t operator()(const SPIRVCodeGenerator::Instruction& key) const {
687         uint32_t hash = key.fResultKind;
688         hash = SkChecksum::Hash32(&key.fOp, sizeof(key.fOp), hash);
689         hash = SkChecksum::Hash32(key.fWords.data(), key.fWords.size() * sizeof(int32_t), hash);
690         return hash;
691     }
692 };
693 
694 // This class is used to pass values and result placeholder slots to writeInstruction.
695 struct SPIRVCodeGenerator::Word {
696     enum Kind {
697         kNone,  // intended for use as a sentinel, not part of any Instruction
698         kSpvId,
699         kNumber,
700         kDefaultPrecisionResult,
701         kRelaxedPrecisionResult,
702         kUniqueResult,
703         kKeyedResult,
704     };
705 
WordSkSL::SPIRVCodeGenerator::Word706     Word(SpvId id) : fValue(id), fKind(Kind::kSpvId) {}
WordSkSL::SPIRVCodeGenerator::Word707     Word(int32_t val, Kind kind) : fValue(val), fKind(kind) {}
708 
NumberSkSL::SPIRVCodeGenerator::Word709     static Word Number(int32_t val) {
710         return Word{val, Kind::kNumber};
711     }
712 
ResultSkSL::SPIRVCodeGenerator::Word713     static Word Result(const Type& type) {
714         return (type.hasPrecision() && !type.highPrecision()) ? RelaxedResult() : Result();
715     }
716 
RelaxedResultSkSL::SPIRVCodeGenerator::Word717     static Word RelaxedResult() {
718         return Word{(int32_t)NA, kRelaxedPrecisionResult};
719     }
720 
UniqueResultSkSL::SPIRVCodeGenerator::Word721     static Word UniqueResult() {
722         return Word{(int32_t)NA, kUniqueResult};
723     }
724 
ResultSkSL::SPIRVCodeGenerator::Word725     static Word Result() {
726         return Word{(int32_t)NA, kDefaultPrecisionResult};
727     }
728 
729     // Unlike a Result (where the result ID is always deduplicated to its first instruction) or a
730     // UniqueResult (which always produces a new instruction), a KeyedResult allows an instruction
731     // to be deduplicated among those that share the same `key`.
KeyedResultSkSL::SPIRVCodeGenerator::Word732     static Word KeyedResult(int32_t key) { return Word{key, Kind::kKeyedResult}; }
733 
isResultSkSL::SPIRVCodeGenerator::Word734     bool isResult() const { return fKind >= Kind::kDefaultPrecisionResult; }
735 
736     int32_t fValue;
737     Kind fKind;
738 };
739 
// Generator magic word for SPIR-V emitted by SkSL. Skia's generator number (31) occupies the top
// 16 bits. The low 16 bits are left at zero and could later version the SkSL generator.
// https://github.com/KhronosGroup/SPIRV-Headers/blob/master/include/spirv/spir-v.xml#L84
static constexpr int32_t SKSL_MAGIC = 31 << 16;
static_assert(SKSL_MAGIC == 0x001F0000, "generator number 31 must sit in the high 16 bits");
744 
getIntrinsic(IntrinsicKind ik) const745 SPIRVCodeGenerator::Intrinsic SPIRVCodeGenerator::getIntrinsic(IntrinsicKind ik) const {
746 
747 #define ALL_GLSL(x) Intrinsic{kGLSL_STD_450_IntrinsicOpcodeKind, GLSLstd450 ## x, \
748                               GLSLstd450 ## x, GLSLstd450 ## x, GLSLstd450 ## x}
749 #define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) Intrinsic{kGLSL_STD_450_IntrinsicOpcodeKind, \
750                                                        GLSLstd450 ## ifFloat,             \
751                                                        GLSLstd450 ## ifInt,               \
752                                                        GLSLstd450 ## ifUInt,              \
753                                                        SpvOpUndef}
754 #define ALL_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
755                                SpvOp ## x, SpvOp ## x, SpvOp ## x, SpvOp ## x}
756 #define BOOL_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
757                                 SpvOpUndef, SpvOpUndef, SpvOpUndef, SpvOp ## x}
758 #define FLOAT_SPIRV(x) Intrinsic{kSPIRV_IntrinsicOpcodeKind, \
759                                  SpvOp ## x, SpvOpUndef, SpvOpUndef, SpvOpUndef}
760 #define SPECIAL(x) Intrinsic{kSpecial_IntrinsicOpcodeKind, k ## x ## _SpecialIntrinsic, \
761                              k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic,  \
762                              k ## x ## _SpecialIntrinsic}
763 
764     switch (ik) {
765         case k_round_IntrinsicKind:          return ALL_GLSL(Round);
766         case k_roundEven_IntrinsicKind:      return ALL_GLSL(RoundEven);
767         case k_trunc_IntrinsicKind:          return ALL_GLSL(Trunc);
768         case k_abs_IntrinsicKind:            return BY_TYPE_GLSL(FAbs, SAbs, SAbs);
769         case k_sign_IntrinsicKind:           return BY_TYPE_GLSL(FSign, SSign, SSign);
770         case k_floor_IntrinsicKind:          return ALL_GLSL(Floor);
771         case k_ceil_IntrinsicKind:           return ALL_GLSL(Ceil);
772         case k_fract_IntrinsicKind:          return ALL_GLSL(Fract);
773         case k_radians_IntrinsicKind:        return ALL_GLSL(Radians);
774         case k_degrees_IntrinsicKind:        return ALL_GLSL(Degrees);
775         case k_sin_IntrinsicKind:            return ALL_GLSL(Sin);
776         case k_cos_IntrinsicKind:            return ALL_GLSL(Cos);
777         case k_tan_IntrinsicKind:            return ALL_GLSL(Tan);
778         case k_asin_IntrinsicKind:           return ALL_GLSL(Asin);
779         case k_acos_IntrinsicKind:           return ALL_GLSL(Acos);
780         case k_atan_IntrinsicKind:           return SPECIAL(Atan);
781         case k_sinh_IntrinsicKind:           return ALL_GLSL(Sinh);
782         case k_cosh_IntrinsicKind:           return ALL_GLSL(Cosh);
783         case k_tanh_IntrinsicKind:           return ALL_GLSL(Tanh);
784         case k_asinh_IntrinsicKind:          return ALL_GLSL(Asinh);
785         case k_acosh_IntrinsicKind:          return ALL_GLSL(Acosh);
786         case k_atanh_IntrinsicKind:          return ALL_GLSL(Atanh);
787         case k_pow_IntrinsicKind:            return ALL_GLSL(Pow);
788         case k_exp_IntrinsicKind:            return ALL_GLSL(Exp);
789         case k_log_IntrinsicKind:            return ALL_GLSL(Log);
790         case k_exp2_IntrinsicKind:           return ALL_GLSL(Exp2);
791         case k_log2_IntrinsicKind:           return ALL_GLSL(Log2);
792         case k_sqrt_IntrinsicKind:           return ALL_GLSL(Sqrt);
793         case k_inverse_IntrinsicKind:        return ALL_GLSL(MatrixInverse);
794         case k_outerProduct_IntrinsicKind:   return ALL_SPIRV(OuterProduct);
795         case k_transpose_IntrinsicKind:      return ALL_SPIRV(Transpose);
796         case k_isinf_IntrinsicKind:          return ALL_SPIRV(IsInf);
797         case k_isnan_IntrinsicKind:          return ALL_SPIRV(IsNan);
798         case k_inversesqrt_IntrinsicKind:    return ALL_GLSL(InverseSqrt);
799         case k_determinant_IntrinsicKind:    return ALL_GLSL(Determinant);
800         case k_matrixCompMult_IntrinsicKind: return SPECIAL(MatrixCompMult);
801         case k_matrixInverse_IntrinsicKind:  return ALL_GLSL(MatrixInverse);
802         case k_mod_IntrinsicKind:            return SPECIAL(Mod);
803         case k_modf_IntrinsicKind:           return ALL_GLSL(Modf);
804         case k_min_IntrinsicKind:            return SPECIAL(Min);
805         case k_max_IntrinsicKind:            return SPECIAL(Max);
806         case k_clamp_IntrinsicKind:          return SPECIAL(Clamp);
807         case k_saturate_IntrinsicKind:       return SPECIAL(Saturate);
808         case k_dot_IntrinsicKind:            return FLOAT_SPIRV(Dot);
809         case k_mix_IntrinsicKind:            return SPECIAL(Mix);
810         case k_step_IntrinsicKind:           return SPECIAL(Step);
811         case k_smoothstep_IntrinsicKind:     return SPECIAL(SmoothStep);
812         case k_fma_IntrinsicKind:            return ALL_GLSL(Fma);
813         case k_frexp_IntrinsicKind:          return ALL_GLSL(Frexp);
814         case k_ldexp_IntrinsicKind:          return ALL_GLSL(Ldexp);
815 
816 #define PACK(type) case k_pack##type##_IntrinsicKind:   return ALL_GLSL(Pack##type); \
817                    case k_unpack##type##_IntrinsicKind: return ALL_GLSL(Unpack##type)
818         PACK(Snorm4x8);
819         PACK(Unorm4x8);
820         PACK(Snorm2x16);
821         PACK(Unorm2x16);
822         PACK(Half2x16);
823 #undef PACK
824 
825         case k_length_IntrinsicKind:        return ALL_GLSL(Length);
826         case k_distance_IntrinsicKind:      return ALL_GLSL(Distance);
827         case k_cross_IntrinsicKind:         return ALL_GLSL(Cross);
828         case k_normalize_IntrinsicKind:     return ALL_GLSL(Normalize);
829         case k_faceforward_IntrinsicKind:   return ALL_GLSL(FaceForward);
830         case k_reflect_IntrinsicKind:       return ALL_GLSL(Reflect);
831         case k_refract_IntrinsicKind:       return ALL_GLSL(Refract);
832         case k_bitCount_IntrinsicKind:      return ALL_SPIRV(BitCount);
833         case k_findLSB_IntrinsicKind:       return ALL_GLSL(FindILsb);
834         case k_findMSB_IntrinsicKind:       return BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
835         case k_dFdx_IntrinsicKind:          return FLOAT_SPIRV(DPdx);
836         case k_dFdy_IntrinsicKind:          return SPECIAL(DFdy);
837         case k_fwidth_IntrinsicKind:        return FLOAT_SPIRV(Fwidth);
838 
839         case k_sample_IntrinsicKind:      return SPECIAL(Texture);
840         case k_sampleGrad_IntrinsicKind:  return SPECIAL(TextureGrad);
841         case k_sampleLod_IntrinsicKind:   return SPECIAL(TextureLod);
842         case k_subpassLoad_IntrinsicKind: return SPECIAL(SubpassLoad);
843 
844         case k_textureRead_IntrinsicKind:  return SPECIAL(TextureRead);
845         case k_textureWrite_IntrinsicKind:  return SPECIAL(TextureWrite);
846         case k_textureWidth_IntrinsicKind:  return SPECIAL(TextureWidth);
847         case k_textureHeight_IntrinsicKind:  return SPECIAL(TextureHeight);
848 
849         case k_floatBitsToInt_IntrinsicKind:  return ALL_SPIRV(Bitcast);
850         case k_floatBitsToUint_IntrinsicKind: return ALL_SPIRV(Bitcast);
851         case k_intBitsToFloat_IntrinsicKind:  return ALL_SPIRV(Bitcast);
852         case k_uintBitsToFloat_IntrinsicKind: return ALL_SPIRV(Bitcast);
853 
854         case k_any_IntrinsicKind:   return BOOL_SPIRV(Any);
855         case k_all_IntrinsicKind:   return BOOL_SPIRV(All);
856         case k_not_IntrinsicKind:   return BOOL_SPIRV(LogicalNot);
857 
858         case k_equal_IntrinsicKind:
859             return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
860                              SpvOpFOrdEqual,
861                              SpvOpIEqual,
862                              SpvOpIEqual,
863                              SpvOpLogicalEqual};
864         case k_notEqual_IntrinsicKind:
865             return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
866                              SpvOpFUnordNotEqual,
867                              SpvOpINotEqual,
868                              SpvOpINotEqual,
869                              SpvOpLogicalNotEqual};
870         case k_lessThan_IntrinsicKind:
871             return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
872                              SpvOpFOrdLessThan,
873                              SpvOpSLessThan,
874                              SpvOpULessThan,
875                              SpvOpUndef};
876         case k_lessThanEqual_IntrinsicKind:
877             return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
878                              SpvOpFOrdLessThanEqual,
879                              SpvOpSLessThanEqual,
880                              SpvOpULessThanEqual,
881                              SpvOpUndef};
882         case k_greaterThan_IntrinsicKind:
883             return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
884                              SpvOpFOrdGreaterThan,
885                              SpvOpSGreaterThan,
886                              SpvOpUGreaterThan,
887                              SpvOpUndef};
888         case k_greaterThanEqual_IntrinsicKind:
889             return Intrinsic{kSPIRV_IntrinsicOpcodeKind,
890                              SpvOpFOrdGreaterThanEqual,
891                              SpvOpSGreaterThanEqual,
892                              SpvOpUGreaterThanEqual,
893                              SpvOpUndef};
894 
895         case k_atomicAdd_IntrinsicKind:   return SPECIAL(AtomicAdd);
896         case k_atomicLoad_IntrinsicKind:  return SPECIAL(AtomicLoad);
897         case k_atomicStore_IntrinsicKind: return SPECIAL(AtomicStore);
898 
899         case k_storageBarrier_IntrinsicKind:   return SPECIAL(StorageBarrier);
900         case k_workgroupBarrier_IntrinsicKind: return SPECIAL(WorkgroupBarrier);
901         default:
902             return Intrinsic{kInvalid_IntrinsicOpcodeKind, 0, 0, 0, 0};
903     }
904 }
905 
writeWord(int32_t word,OutputStream & out)906 void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
907     out.write((const char*) &word, sizeof(word));
908 }
909 
is_float(const Type & type)910 static bool is_float(const Type& type) {
911     return (type.isScalar() || type.isVector() || type.isMatrix()) &&
912            type.componentType().isFloat();
913 }
914 
is_signed(const Type & type)915 static bool is_signed(const Type& type) {
916     return (type.isScalar() || type.isVector()) && type.componentType().isSigned();
917 }
918 
is_unsigned(const Type & type)919 static bool is_unsigned(const Type& type) {
920     return (type.isScalar() || type.isVector()) && type.componentType().isUnsigned();
921 }
922 
is_bool(const Type & type)923 static bool is_bool(const Type& type) {
924     return (type.isScalar() || type.isVector()) && type.componentType().isBoolean();
925 }
926 
927 template <typename T>
pick_by_type(const Type & type,T ifFloat,T ifInt,T ifUInt,T ifBool)928 static T pick_by_type(const Type& type, T ifFloat, T ifInt, T ifUInt, T ifBool) {
929     if (is_float(type)) {
930         return ifFloat;
931     }
932     if (is_signed(type)) {
933         return ifInt;
934     }
935     if (is_unsigned(type)) {
936         return ifUInt;
937     }
938     if (is_bool(type)) {
939         return ifBool;
940     }
941     SkDEBUGFAIL("unrecognized type");
942     return ifFloat;
943 }
944 
is_out(ModifierFlags f)945 static bool is_out(ModifierFlags f) {
946     return SkToBool(f & ModifierFlag::kOut);
947 }
948 
is_in(ModifierFlags f)949 static bool is_in(ModifierFlags f) {
950     if (f & ModifierFlag::kIn) {
951         return true;  // `in` and `inout` both count
952     }
953     // If neither in/out flag is set, the type is implicitly `in`.
954     return !SkToBool(f & ModifierFlag::kOut);
955 }
956 
is_control_flow_op(SpvOp_ op)957 static bool is_control_flow_op(SpvOp_ op) {
958     switch (op) {
959         case SpvOpReturn:
960         case SpvOpReturnValue:
961         case SpvOpKill:
962         case SpvOpSwitch:
963         case SpvOpBranch:
964         case SpvOpBranchConditional:
965             return true;
966         default:
967             return false;
968     }
969 }
970 
is_globally_reachable_op(SpvOp_ op)971 static bool is_globally_reachable_op(SpvOp_ op) {
972     switch (op) {
973         case SpvOpConstant:
974         case SpvOpConstantTrue:
975         case SpvOpConstantFalse:
976         case SpvOpConstantComposite:
977         case SpvOpTypeVoid:
978         case SpvOpTypeInt:
979         case SpvOpTypeFloat:
980         case SpvOpTypeBool:
981         case SpvOpTypeVector:
982         case SpvOpTypeMatrix:
983         case SpvOpTypeArray:
984         case SpvOpTypePointer:
985         case SpvOpTypeFunction:
986         case SpvOpTypeRuntimeArray:
987         case SpvOpTypeStruct:
988         case SpvOpTypeImage:
989         case SpvOpTypeSampledImage:
990         case SpvOpTypeSampler:
991         case SpvOpVariable:
992         case SpvOpFunction:
993         case SpvOpFunctionParameter:
994         case SpvOpFunctionEnd:
995         case SpvOpExecutionMode:
996         case SpvOpMemoryModel:
997         case SpvOpCapability:
998         case SpvOpExtInstImport:
999         case SpvOpEntryPoint:
1000         case SpvOpSource:
1001         case SpvOpSourceExtension:
1002         case SpvOpName:
1003         case SpvOpMemberName:
1004         case SpvOpDecorate:
1005         case SpvOpMemberDecorate:
1006             return true;
1007         default:
1008             return false;
1009     }
1010 }
1011 
writeOpCode(SpvOp_ opCode,int length,OutputStream & out)1012 void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) {
1013     SkASSERT(opCode != SpvOpLoad || &out != &fConstantBuffer);
1014     SkASSERT(opCode != SpvOpUndef);
1015     bool foundDeadCode = false;
1016     if (is_control_flow_op(opCode)) {
1017         // This instruction causes us to leave the current block.
1018         foundDeadCode = (fCurrentBlock == 0);
1019         fCurrentBlock = 0;
1020     } else if (!is_globally_reachable_op(opCode)) {
1021         foundDeadCode = (fCurrentBlock == 0);
1022     }
1023 
1024     if (foundDeadCode) {
1025         // We just encountered dead code--an instruction that don't have an associated block.
1026         // Synthesize a label if this happens; this is necessary to satisfy the validator.
1027         this->writeLabel(this->nextId(nullptr), kBranchlessBlock, out);
1028     }
1029 
1030     this->writeWord((length << 16) | opCode, out);
1031 }
1032 
writeLabel(SpvId label,StraightLineLabelType,OutputStream & out)1033 void SPIRVCodeGenerator::writeLabel(SpvId label, StraightLineLabelType, OutputStream& out) {
1034     // The straight-line label type is not important; in any case, no caches are invalidated.
1035     SkASSERT(!fCurrentBlock);
1036     fCurrentBlock = label;
1037     this->writeInstruction(SpvOpLabel, label, out);
1038 }
1039 
writeLabel(SpvId label,BranchingLabelType type,ConditionalOpCounts ops,OutputStream & out)1040 void SPIRVCodeGenerator::writeLabel(SpvId label, BranchingLabelType type,
1041                                     ConditionalOpCounts ops, OutputStream& out) {
1042     switch (type) {
1043         case kBranchIsBelow:
1044         case kBranchesOnBothSides:
1045             // With a backward or bidirectional branch, we haven't seen the code between the label
1046             // and the branch yet, so any stored value is potentially suspect. Without scanning
1047             // ahead to check, the only safe option is to ditch the store cache entirely.
1048             fStoreCache.reset();
1049             [[fallthrough]];
1050 
1051         case kBranchIsAbove:
1052             // With a forward branch, we can rely on stores that we had cached at the start of the
1053             // statement/expression, if they haven't been touched yet. Anything newer than that is
1054             // pruned.
1055             this->pruneConditionalOps(ops);
1056             break;
1057     }
1058 
1059     // Emit the label.
1060     this->writeLabel(label, kBranchlessBlock, out);
1061 }
1062 
writeInstruction(SpvOp_ opCode,OutputStream & out)1063 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) {
1064     this->writeOpCode(opCode, 1, out);
1065 }
1066 
writeInstruction(SpvOp_ opCode,int32_t word1,OutputStream & out)1067 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) {
1068     this->writeOpCode(opCode, 2, out);
1069     this->writeWord(word1, out);
1070 }
1071 
writeString(std::string_view s,OutputStream & out)1072 void SPIRVCodeGenerator::writeString(std::string_view s, OutputStream& out) {
1073     out.write(s.data(), s.length());
1074     switch (s.length() % 4) {
1075         case 1:
1076             out.write8(0);
1077             [[fallthrough]];
1078         case 2:
1079             out.write8(0);
1080             [[fallthrough]];
1081         case 3:
1082             out.write8(0);
1083             break;
1084         default:
1085             this->writeWord(0, out);
1086             break;
1087     }
1088 }
1089 
writeInstruction(SpvOp_ opCode,std::string_view string,OutputStream & out)1090 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, std::string_view string,
1091                                           OutputStream& out) {
1092     this->writeOpCode(opCode, 1 + (string.length() + 4) / 4, out);
1093     this->writeString(string, out);
1094 }
1095 
writeInstruction(SpvOp_ opCode,int32_t word1,std::string_view string,OutputStream & out)1096 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, std::string_view string,
1097                                           OutputStream& out) {
1098     this->writeOpCode(opCode, 2 + (string.length() + 4) / 4, out);
1099     this->writeWord(word1, out);
1100     this->writeString(string, out);
1101 }
1102 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,std::string_view string,OutputStream & out)1103 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1104                                           std::string_view string, OutputStream& out) {
1105     this->writeOpCode(opCode, 3 + (string.length() + 4) / 4, out);
1106     this->writeWord(word1, out);
1107     this->writeWord(word2, out);
1108     this->writeString(string, out);
1109 }
1110 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,OutputStream & out)1111 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1112                                           OutputStream& out) {
1113     this->writeOpCode(opCode, 3, out);
1114     this->writeWord(word1, out);
1115     this->writeWord(word2, out);
1116 }
1117 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,OutputStream & out)1118 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1119                                           int32_t word3, OutputStream& out) {
1120     this->writeOpCode(opCode, 4, out);
1121     this->writeWord(word1, out);
1122     this->writeWord(word2, out);
1123     this->writeWord(word3, out);
1124 }
1125 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,OutputStream & out)1126 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1127                                           int32_t word3, int32_t word4, OutputStream& out) {
1128     this->writeOpCode(opCode, 5, out);
1129     this->writeWord(word1, out);
1130     this->writeWord(word2, out);
1131     this->writeWord(word3, out);
1132     this->writeWord(word4, out);
1133 }
1134 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,OutputStream & out)1135 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1136                                           int32_t word3, int32_t word4, int32_t word5,
1137                                           OutputStream& out) {
1138     this->writeOpCode(opCode, 6, out);
1139     this->writeWord(word1, out);
1140     this->writeWord(word2, out);
1141     this->writeWord(word3, out);
1142     this->writeWord(word4, out);
1143     this->writeWord(word5, out);
1144 }
1145 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,OutputStream & out)1146 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1147                                           int32_t word3, int32_t word4, int32_t word5,
1148                                           int32_t word6, OutputStream& out) {
1149     this->writeOpCode(opCode, 7, out);
1150     this->writeWord(word1, out);
1151     this->writeWord(word2, out);
1152     this->writeWord(word3, out);
1153     this->writeWord(word4, out);
1154     this->writeWord(word5, out);
1155     this->writeWord(word6, out);
1156 }
1157 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,OutputStream & out)1158 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1159                                           int32_t word3, int32_t word4, int32_t word5,
1160                                           int32_t word6, int32_t word7, OutputStream& out) {
1161     this->writeOpCode(opCode, 8, out);
1162     this->writeWord(word1, out);
1163     this->writeWord(word2, out);
1164     this->writeWord(word3, out);
1165     this->writeWord(word4, out);
1166     this->writeWord(word5, out);
1167     this->writeWord(word6, out);
1168     this->writeWord(word7, out);
1169 }
1170 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,int32_t word8,OutputStream & out)1171 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
1172                                           int32_t word3, int32_t word4, int32_t word5,
1173                                           int32_t word6, int32_t word7, int32_t word8,
1174                                           OutputStream& out) {
1175     this->writeOpCode(opCode, 9, out);
1176     this->writeWord(word1, out);
1177     this->writeWord(word2, out);
1178     this->writeWord(word3, out);
1179     this->writeWord(word4, out);
1180     this->writeWord(word5, out);
1181     this->writeWord(word6, out);
1182     this->writeWord(word7, out);
1183     this->writeWord(word8, out);
1184 }
1185 
BuildInstructionKey(SpvOp_ opCode,const TArray<Word> & words)1186 SPIRVCodeGenerator::Instruction SPIRVCodeGenerator::BuildInstructionKey(SpvOp_ opCode,
1187                                                                         const TArray<Word>& words) {
1188     // Assemble a cache key for this instruction.
1189     Instruction key;
1190     key.fOp = opCode;
1191     key.fWords.resize(words.size());
1192     key.fResultKind = Word::Kind::kNone;
1193 
1194     for (int index = 0; index < words.size(); ++index) {
1195         const Word& word = words[index];
1196         key.fWords[index] = word.fValue;
1197         if (word.isResult()) {
1198             SkASSERT(key.fResultKind == Word::Kind::kNone);
1199             key.fResultKind = word.fKind;
1200         }
1201     }
1202 
1203     return key;
1204 }
1205 
writeInstruction(SpvOp_ opCode,const TArray<Word> & words,OutputStream & out)1206 SpvId SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode,
1207                                            const TArray<Word>& words,
1208                                            OutputStream& out) {
1209     // writeOpLoad and writeOpStore have dedicated code.
1210     SkASSERT(opCode != SpvOpLoad);
1211     SkASSERT(opCode != SpvOpStore);
1212 
1213     // If this instruction exists in our op cache, return the cached SpvId.
1214     Instruction key = BuildInstructionKey(opCode, words);
1215     if (SpvId* cachedOp = fOpCache.find(key)) {
1216         return *cachedOp;
1217     }
1218 
1219     SpvId result = NA;
1220     Precision precision = Precision::kDefault;
1221 
1222     switch (key.fResultKind) {
1223         case Word::Kind::kUniqueResult:
1224             // The instruction returns a SpvId, but we do not want deduplication.
1225             result = this->nextId(Precision::kDefault);
1226             fSpvIdCache.set(result, key);
1227             break;
1228 
1229         case Word::Kind::kNone:
1230             // The instruction doesn't return a SpvId, but we can still cache and deduplicate it.
1231             fOpCache.set(key, result);
1232             break;
1233 
1234         case Word::Kind::kRelaxedPrecisionResult:
1235             precision = Precision::kRelaxed;
1236             [[fallthrough]];
1237 
1238         case Word::Kind::kKeyedResult:
1239             [[fallthrough]];
1240 
1241         case Word::Kind::kDefaultPrecisionResult:
1242             // Consume a new SpvId.
1243             result = this->nextId(precision);
1244             fOpCache.set(key, result);
1245             fSpvIdCache.set(result, key);
1246 
1247             // Globally-reachable ops are not subject to the whims of flow control.
1248             if (!is_globally_reachable_op(opCode)) {
1249                 fReachableOps.push_back(result);
1250             }
1251             break;
1252 
1253         default:
1254             SkDEBUGFAIL("unexpected result kind");
1255             break;
1256     }
1257 
1258     // Write the requested instruction.
1259     this->writeOpCode(opCode, words.size() + 1, out);
1260     for (const Word& word : words) {
1261         if (word.isResult()) {
1262             SkASSERT(result != NA);
1263             this->writeWord(result, out);
1264         } else {
1265             this->writeWord(word.fValue, out);
1266         }
1267     }
1268 
1269     // Return the result.
1270     return result;
1271 }
1272 
writeOpLoad(SpvId type,Precision precision,SpvId pointer,OutputStream & out)1273 SpvId SPIRVCodeGenerator::writeOpLoad(SpvId type,
1274                                       Precision precision,
1275                                       SpvId pointer,
1276                                       OutputStream& out) {
1277     // Look for this pointer in our load-cache.
1278     if (SpvId* cachedOp = fStoreCache.find(pointer)) {
1279         return *cachedOp;
1280     }
1281 
1282     // Write the requested OpLoad instruction.
1283     SpvId result = this->nextId(precision);
1284     this->writeInstruction(SpvOpLoad, type, result, pointer, out);
1285     return result;
1286 }
1287 
writeOpStore(SpvStorageClass_ storageClass,SpvId pointer,SpvId value,OutputStream & out)1288 void SPIRVCodeGenerator::writeOpStore(SpvStorageClass_ storageClass,
1289                                       SpvId pointer,
1290                                       SpvId value,
1291                                       OutputStream& out) {
1292     // Write the uncached SpvOpStore directly.
1293     this->writeInstruction(SpvOpStore, pointer, value, out);
1294 
1295     if (storageClass == SpvStorageClassFunction) {
1296         // Insert a pointer-to-SpvId mapping into the load cache. A writeOpLoad to this pointer will
1297         // return the cached value as-is.
1298         fStoreCache.set(pointer, value);
1299         fStoreOps.push_back(pointer);
1300     }
1301 }
1302 
writeOpConstantTrue(const Type & type)1303 SpvId SPIRVCodeGenerator::writeOpConstantTrue(const Type& type) {
1304     return this->writeInstruction(SpvOpConstantTrue,
1305                                   Words{this->getType(type), Word::Result()},
1306                                   fConstantBuffer);
1307 }
1308 
writeOpConstantFalse(const Type & type)1309 SpvId SPIRVCodeGenerator::writeOpConstantFalse(const Type& type) {
1310     return this->writeInstruction(SpvOpConstantFalse,
1311                                   Words{this->getType(type), Word::Result()},
1312                                   fConstantBuffer);
1313 }
1314 
writeOpConstant(const Type & type,int32_t valueBits)1315 SpvId SPIRVCodeGenerator::writeOpConstant(const Type& type, int32_t valueBits) {
1316     return this->writeInstruction(
1317             SpvOpConstant,
1318             Words{this->getType(type), Word::Result(), Word::Number(valueBits)},
1319             fConstantBuffer);
1320 }
1321 
writeOpConstantComposite(const Type & type,const TArray<SpvId> & values)1322 SpvId SPIRVCodeGenerator::writeOpConstantComposite(const Type& type,
1323                                                    const TArray<SpvId>& values) {
1324     SkASSERT(values.size() == (type.isStruct() ? SkToInt(type.fields().size()) : type.columns()));
1325 
1326     Words words;
1327     words.push_back(this->getType(type));
1328     words.push_back(Word::Result());
1329     for (SpvId value : values) {
1330         words.push_back(value);
1331     }
1332     return this->writeInstruction(SpvOpConstantComposite, words, fConstantBuffer);
1333 }
1334 
toConstants(SpvId value,TArray<SpvId> * constants)1335 bool SPIRVCodeGenerator::toConstants(SpvId value, TArray<SpvId>* constants) {
1336     Instruction* instr = fSpvIdCache.find(value);
1337     if (!instr) {
1338         return false;
1339     }
1340     switch (instr->fOp) {
1341         case SpvOpConstant:
1342         case SpvOpConstantTrue:
1343         case SpvOpConstantFalse:
1344             constants->push_back(value);
1345             return true;
1346 
1347         case SpvOpConstantComposite: // OpConstantComposite ResultType ResultID Constituents...
1348             // Start at word 2 to skip past ResultType and ResultID.
1349             for (int i = 2; i < instr->fWords.size(); ++i) {
1350                 if (!this->toConstants(instr->fWords[i], constants)) {
1351                     return false;
1352                 }
1353             }
1354             return true;
1355 
1356         default:
1357             return false;
1358     }
1359 }
1360 
toConstants(SkSpan<const SpvId> values,TArray<SpvId> * constants)1361 bool SPIRVCodeGenerator::toConstants(SkSpan<const SpvId> values, TArray<SpvId>* constants) {
1362     for (SpvId value : values) {
1363         if (!this->toConstants(value, constants)) {
1364             return false;
1365         }
1366     }
1367     return true;
1368 }
1369 
writeOpCompositeConstruct(const Type & type,const TArray<SpvId> & values,OutputStream & out)1370 SpvId SPIRVCodeGenerator::writeOpCompositeConstruct(const Type& type,
1371                                                     const TArray<SpvId>& values,
1372                                                     OutputStream& out) {
1373     // If this is a vector composed entirely of literals, write a constant-composite instead.
1374     if (type.isVector()) {
1375         STArray<4, SpvId> constants;
1376         if (this->toConstants(SkSpan(values), &constants)) {
1377             // Create a vector from literals.
1378             return this->writeOpConstantComposite(type, constants);
1379         }
1380     }
1381 
1382     // If this is a matrix composed entirely of literals, constant-composite them instead.
1383     if (type.isMatrix()) {
1384         STArray<16, SpvId> constants;
1385         if (this->toConstants(SkSpan(values), &constants)) {
1386             // Create each matrix column.
1387             SkASSERT(type.isMatrix());
1388             const Type& vecType = type.columnType(fContext);
1389             STArray<4, SpvId> columnIDs;
1390             for (int index=0; index < type.columns(); ++index) {
1391                 STArray<4, SpvId> columnConstants(&constants[index * type.rows()],
1392                                                     type.rows());
1393                 columnIDs.push_back(this->writeOpConstantComposite(vecType, columnConstants));
1394             }
1395             // Compose the matrix from its columns.
1396             return this->writeOpConstantComposite(type, columnIDs);
1397         }
1398     }
1399 
1400     Words words;
1401     words.push_back(this->getType(type));
1402     words.push_back(Word::Result(type));
1403     for (SpvId value : values) {
1404         words.push_back(value);
1405     }
1406 
1407     return this->writeInstruction(SpvOpCompositeConstruct, words, out);
1408 }
1409 
resultTypeForInstruction(const Instruction & instr)1410 SPIRVCodeGenerator::Instruction* SPIRVCodeGenerator::resultTypeForInstruction(
1411         const Instruction& instr) {
1412     // This list should contain every op that we cache that has a result and result-type.
1413     // (If one is missing, we will not find some optimization opportunities.)
1414     // Generally, the result type of an op is in the 0th word, but I'm not sure if this is
1415     // universally true, so it's configurable on a per-op basis.
1416     int resultTypeWord;
1417     switch (instr.fOp) {
1418         case SpvOpConstant:
1419         case SpvOpConstantTrue:
1420         case SpvOpConstantFalse:
1421         case SpvOpConstantComposite:
1422         case SpvOpCompositeConstruct:
1423         case SpvOpCompositeExtract:
1424         case SpvOpLoad:
1425             resultTypeWord = 0;
1426             break;
1427 
1428         default:
1429             return nullptr;
1430     }
1431 
1432     Instruction* typeInstr = fSpvIdCache.find(instr.fWords[resultTypeWord]);
1433     SkASSERT(typeInstr);
1434     return typeInstr;
1435 }
1436 
numComponentsForVecInstruction(const Instruction & instr)1437 int SPIRVCodeGenerator::numComponentsForVecInstruction(const Instruction& instr) {
1438     // If an instruction is in the op cache, its type should be as well.
1439     Instruction* typeInstr = this->resultTypeForInstruction(instr);
1440     SkASSERT(typeInstr);
1441     SkASSERT(typeInstr->fOp == SpvOpTypeVector || typeInstr->fOp == SpvOpTypeFloat ||
1442              typeInstr->fOp == SpvOpTypeInt || typeInstr->fOp == SpvOpTypeBool);
1443 
1444     // For vectors, extract their column count. Scalars have one component by definition.
1445     //   SpvOpTypeVector ResultID ComponentType NumComponents
1446     return (typeInstr->fOp == SpvOpTypeVector) ? typeInstr->fWords[2]
1447                                                : 1;
1448 }
1449 
toComponent(SpvId id,int component)1450 SpvId SPIRVCodeGenerator::toComponent(SpvId id, int component) {
1451     Instruction* instr = fSpvIdCache.find(id);
1452     if (!instr) {
1453         return NA;
1454     }
1455     if (instr->fOp == SpvOpConstantComposite) {
1456         // SpvOpConstantComposite ResultType ResultID [components...]
1457         // Add 2 to the component index to skip past ResultType and ResultID.
1458         return instr->fWords[2 + component];
1459     }
1460     if (instr->fOp == SpvOpCompositeConstruct) {
1461         // SpvOpCompositeConstruct ResultType ResultID [components...]
1462         // Vectors have special rules; check to see if we are composing a vector.
1463         Instruction* composedType = fSpvIdCache.find(instr->fWords[0]);
1464         SkASSERT(composedType);
1465 
1466         // When composing a non-vector, each instruction word maps 1:1 to the component index.
1467         // We can just extract out the associated component directly.
1468         if (composedType->fOp != SpvOpTypeVector) {
1469             return instr->fWords[2 + component];
1470         }
1471 
1472         // When composing a vector, components can be either scalars or vectors.
1473         // This means we need to check the op type on each component. (+2 to skip ResultType/Result)
1474         for (int index = 2; index < instr->fWords.size(); ++index) {
1475             int32_t currentWord = instr->fWords[index];
1476 
1477             // Retrieve the sub-instruction pointed to by OpCompositeConstruct.
1478             Instruction* subinstr = fSpvIdCache.find(currentWord);
1479             if (!subinstr) {
1480                 return NA;
1481             }
1482             // If this subinstruction contains the component we're looking for...
1483             int numComponents = this->numComponentsForVecInstruction(*subinstr);
1484             if (component < numComponents) {
1485                 if (numComponents == 1) {
1486                     // ... it's a scalar. Return it.
1487                     SkASSERT(component == 0);
1488                     return currentWord;
1489                 } else {
1490                     // ... it's a vector. Recurse into it.
1491                     return this->toComponent(currentWord, component);
1492                 }
1493             }
1494             // This sub-instruction doesn't contain our component. Keep walking forward.
1495             component -= numComponents;
1496         }
1497         SkDEBUGFAIL("component index goes past the end of this composite value");
1498         return NA;
1499     }
1500     return NA;
1501 }
1502 
writeOpCompositeExtract(const Type & type,SpvId base,int component,OutputStream & out)1503 SpvId SPIRVCodeGenerator::writeOpCompositeExtract(const Type& type,
1504                                                   SpvId base,
1505                                                   int component,
1506                                                   OutputStream& out) {
1507     // If the base op is a composite, we can extract from it directly.
1508     SpvId result = this->toComponent(base, component);
1509     if (result != NA) {
1510         return result;
1511     }
1512     return this->writeInstruction(
1513             SpvOpCompositeExtract,
1514             {this->getType(type), Word::Result(type), base, Word::Number(component)},
1515             out);
1516 }
1517 
writeOpCompositeExtract(const Type & type,SpvId base,int componentA,int componentB,OutputStream & out)1518 SpvId SPIRVCodeGenerator::writeOpCompositeExtract(const Type& type,
1519                                                   SpvId base,
1520                                                   int componentA,
1521                                                   int componentB,
1522                                                   OutputStream& out) {
1523     // If the base op is a composite, we can extract from it directly.
1524     SpvId result = this->toComponent(base, componentA);
1525     if (result != NA) {
1526         return this->writeOpCompositeExtract(type, result, componentB, out);
1527     }
1528     return this->writeInstruction(SpvOpCompositeExtract,
1529                                   {this->getType(type),
1530                                    Word::Result(type),
1531                                    base,
1532                                    Word::Number(componentA),
1533                                    Word::Number(componentB)},
1534                                   out);
1535 }
1536 
writeCapabilities(OutputStream & out)1537 void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) {
1538     for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
1539         if (fCapabilities & bit) {
1540             this->writeInstruction(SpvOpCapability, (SpvId) i, out);
1541         }
1542     }
1543     this->writeInstruction(SpvOpCapability, SpvCapabilityShader, out);
1544 }
1545 
nextId(const Type * type)1546 SpvId SPIRVCodeGenerator::nextId(const Type* type) {
1547     return this->nextId(type && type->hasPrecision() && !type->highPrecision()
1548                 ? Precision::kRelaxed
1549                 : Precision::kDefault);
1550 }
1551 
nextId(Precision precision)1552 SpvId SPIRVCodeGenerator::nextId(Precision precision) {
1553     if (precision == Precision::kRelaxed && !fProgram.fConfig->fSettings.fForceHighPrecision) {
1554         this->writeInstruction(SpvOpDecorate, fIdCount, SpvDecorationRelaxedPrecision,
1555                                fDecorationBuffer);
1556     }
1557     return fIdCount++;
1558 }
1559 
writeStruct(const Type & type,const MemoryLayout & memoryLayout)1560 SpvId SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout) {
1561     // If we've already written out this struct, return its existing SpvId.
1562     if (SpvId* cachedStructId = fStructMap.find(&type)) {
1563         return *cachedStructId;
1564     }
1565 
1566     // Write all of the field types first, so we don't inadvertently write them while we're in the
1567     // middle of writing the struct instruction.
1568     Words words;
1569     words.push_back(Word::UniqueResult());
1570     for (const auto& f : type.fields()) {
1571         words.push_back(this->getType(*f.fType, f.fLayout, memoryLayout));
1572     }
1573     SpvId resultId = this->writeInstruction(SpvOpTypeStruct, words, fConstantBuffer);
1574     this->writeInstruction(SpvOpName, resultId, type.name(), fNameBuffer);
1575     fStructMap.set(&type, resultId);
1576 
1577     size_t offset = 0;
1578     for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
1579         const Field& field = type.fields()[i];
1580         if (!memoryLayout.isSupported(*field.fType)) {
1581             fContext.fErrors->error(type.fPosition, "type '" + field.fType->displayName() +
1582                                                     "' is not permitted here");
1583             return resultId;
1584         }
1585         size_t size = memoryLayout.size(*field.fType);
1586         size_t alignment = memoryLayout.alignment(*field.fType);
1587         const Layout& fieldLayout = field.fLayout;
1588         if (fieldLayout.fOffset >= 0) {
1589             if (fieldLayout.fOffset < (int) offset) {
1590                 fContext.fErrors->error(field.fPosition, "offset of field '" +
1591                         std::string(field.fName) + "' must be at least " + std::to_string(offset));
1592             }
1593             if (fieldLayout.fOffset % alignment) {
1594                 fContext.fErrors->error(field.fPosition,
1595                                         "offset of field '" + std::string(field.fName) +
1596                                         "' must be a multiple of " + std::to_string(alignment));
1597             }
1598             offset = fieldLayout.fOffset;
1599         } else {
1600             size_t mod = offset % alignment;
1601             if (mod) {
1602                 offset += alignment - mod;
1603             }
1604         }
1605         this->writeInstruction(SpvOpMemberName, resultId, i, field.fName, fNameBuffer);
1606         this->writeFieldLayout(fieldLayout, resultId, i);
1607         if (field.fLayout.fBuiltin < 0) {
1608             this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
1609                                    (SpvId) offset, fDecorationBuffer);
1610         }
1611         if (field.fType->isMatrix()) {
1612             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
1613                                    fDecorationBuffer);
1614             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
1615                                    (SpvId) memoryLayout.stride(*field.fType),
1616                                    fDecorationBuffer);
1617         }
1618         if (!field.fType->highPrecision()) {
1619             this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i,
1620                                    SpvDecorationRelaxedPrecision, fDecorationBuffer);
1621         }
1622         offset += size;
1623         if ((field.fType->isArray() || field.fType->isStruct()) && offset % alignment != 0) {
1624             offset += alignment - offset % alignment;
1625         }
1626     }
1627 
1628     return resultId;
1629 }
1630 
getType(const Type & type)1631 SpvId SPIRVCodeGenerator::getType(const Type& type) {
1632     return this->getType(type, kDefaultTypeLayout, fDefaultMemoryLayout);
1633 }
1634 
layout_flags_to_image_format(LayoutFlags flags)1635 static SpvImageFormat layout_flags_to_image_format(LayoutFlags flags) {
1636     flags &= LayoutFlag::kAllPixelFormats;
1637     switch (flags.value()) {
1638         case (int)LayoutFlag::kRGBA8:
1639             return SpvImageFormatRgba8;
1640 
1641         case (int)LayoutFlag::kRGBA32F:
1642             return SpvImageFormatRgba32f;
1643 
1644         case (int)LayoutFlag::kR32F:
1645             return SpvImageFormatR32f;
1646 
1647         default:
1648             return SpvImageFormatUnknown;
1649     }
1650 
1651     SkUNREACHABLE;
1652 }
1653 
getType(const Type & rawType,const Layout & typeLayout,const MemoryLayout & memoryLayout)1654 SpvId SPIRVCodeGenerator::getType(const Type& rawType,
1655                                   const Layout& typeLayout,
1656                                   const MemoryLayout& memoryLayout) {
1657     const Type* type = &rawType;
1658 
1659     switch (type->typeKind()) {
1660         case Type::TypeKind::kVoid: {
1661             return this->writeInstruction(SpvOpTypeVoid, Words{Word::Result()}, fConstantBuffer);
1662         }
1663         case Type::TypeKind::kScalar:
1664         case Type::TypeKind::kLiteral: {
1665             if (type->isBoolean()) {
1666                 return this->writeInstruction(SpvOpTypeBool, {Word::Result()}, fConstantBuffer);
1667             }
1668             if (type->isSigned()) {
1669                 return this->writeInstruction(
1670                         SpvOpTypeInt,
1671                         Words{Word::Result(), Word::Number(32), Word::Number(1)},
1672                         fConstantBuffer);
1673             }
1674             if (type->isUnsigned()) {
1675                 return this->writeInstruction(
1676                         SpvOpTypeInt,
1677                         Words{Word::Result(), Word::Number(32), Word::Number(0)},
1678                         fConstantBuffer);
1679             }
1680             if (type->isFloat()) {
1681                 return this->writeInstruction(
1682                         SpvOpTypeFloat,
1683                         Words{Word::Result(), Word::Number(32)},
1684                         fConstantBuffer);
1685             }
1686             SkDEBUGFAILF("unrecognized scalar type '%s'", type->description().c_str());
1687             return NA;
1688         }
1689         case Type::TypeKind::kVector: {
1690             SpvId scalarTypeId = this->getType(type->componentType(), typeLayout, memoryLayout);
1691             return this->writeInstruction(
1692                     SpvOpTypeVector,
1693                     Words{Word::Result(), scalarTypeId, Word::Number(type->columns())},
1694                     fConstantBuffer);
1695         }
1696         case Type::TypeKind::kMatrix: {
1697             SpvId vectorTypeId = this->getType(IndexExpression::IndexType(fContext, *type),
1698                                                typeLayout,
1699                                                memoryLayout);
1700             return this->writeInstruction(
1701                     SpvOpTypeMatrix,
1702                     Words{Word::Result(), vectorTypeId, Word::Number(type->columns())},
1703                     fConstantBuffer);
1704         }
1705         case Type::TypeKind::kArray: {
1706             if (!memoryLayout.isSupported(*type)) {
1707                 fContext.fErrors->error(type->fPosition, "type '" + type->displayName() +
1708                                                          "' is not permitted here");
1709                 return NA;
1710             }
1711             size_t stride = memoryLayout.stride(*type);
1712             SpvId typeId = this->getType(type->componentType(), typeLayout, memoryLayout);
1713             SpvId result = NA;
1714             if (type->isUnsizedArray()) {
1715                 result = this->writeInstruction(SpvOpTypeRuntimeArray,
1716                                                 Words{Word::KeyedResult(stride), typeId},
1717                                                 fConstantBuffer);
1718             } else {
1719                 SpvId countId = this->writeLiteral(type->columns(), *fContext.fTypes.fInt);
1720                 result = this->writeInstruction(SpvOpTypeArray,
1721                                                 Words{Word::KeyedResult(stride), typeId, countId},
1722                                                 fConstantBuffer);
1723             }
1724             this->writeInstruction(SpvOpDecorate,
1725                                    {result, SpvDecorationArrayStride, Word::Number(stride)},
1726                                    fDecorationBuffer);
1727             return result;
1728         }
1729         case Type::TypeKind::kStruct: {
1730             return this->writeStruct(*type, memoryLayout);
1731         }
1732         case Type::TypeKind::kSeparateSampler: {
1733             return this->writeInstruction(SpvOpTypeSampler, Words{Word::Result()}, fConstantBuffer);
1734         }
1735         case Type::TypeKind::kSampler: {
1736             if (SpvDimBuffer == type->dimensions()) {
1737                 fCapabilities |= 1ULL << SpvCapabilitySampledBuffer;
1738             }
1739             SpvId imageTypeId = this->getType(type->textureType(), typeLayout, memoryLayout);
1740             return this->writeInstruction(SpvOpTypeSampledImage,
1741                                           Words{Word::Result(), imageTypeId},
1742                                           fConstantBuffer);
1743         }
1744         case Type::TypeKind::kTexture: {
1745             SpvId floatTypeId = this->getType(*fContext.fTypes.fFloat,
1746                                               kDefaultTypeLayout,
1747                                               memoryLayout);
1748 
1749             bool sampled = (type->textureAccess() == Type::TextureAccess::kSample);
1750             SpvImageFormat format = (!sampled && type->dimensions() != SpvDimSubpassData)
1751                                             ? layout_flags_to_image_format(typeLayout.fFlags)
1752                                             : SpvImageFormatUnknown;
1753 
1754             return this->writeInstruction(SpvOpTypeImage,
1755                                           Words{Word::Result(),
1756                                                 floatTypeId,
1757                                                 Word::Number(type->dimensions()),
1758                                                 Word::Number(type->isDepth()),
1759                                                 Word::Number(type->isArrayedTexture()),
1760                                                 Word::Number(type->isMultisampled()),
1761                                                 Word::Number(sampled ? 1 : 2),
1762                                                 format},
1763                                           fConstantBuffer);
1764         }
1765         case Type::TypeKind::kAtomic: {
1766             // SkSL currently only supports the atomicUint type.
1767             SkASSERT(type->matches(*fContext.fTypes.fAtomicUInt));
1768             // SPIR-V doesn't have atomic types. Rather, it allows atomic operations on primitive
1769             // types. The SPIR-V type of an SkSL atomic is simply the underlying type.
1770             return this->writeInstruction(SpvOpTypeInt,
1771                                           Words{Word::Result(), Word::Number(32), Word::Number(0)},
1772                                           fConstantBuffer);
1773         }
1774         default: {
1775             SkDEBUGFAILF("invalid type: %s", type->description().c_str());
1776             return NA;
1777         }
1778     }
1779 }
1780 
getFunctionType(const FunctionDeclaration & function)1781 SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
1782     Words words;
1783     words.push_back(Word::Result());
1784     words.push_back(this->getType(function.returnType()));
1785     for (const Variable* parameter : function.parameters()) {
1786         if (fUseTextureSamplerPairs && parameter->type().isSampler()) {
1787             words.push_back(this->getFunctionParameterType(parameter->type().textureType(),
1788                                                            parameter->layout()));
1789             words.push_back(this->getFunctionParameterType(*fContext.fTypes.fSampler,
1790                                                            kDefaultTypeLayout));
1791         } else {
1792             words.push_back(this->getFunctionParameterType(parameter->type(), parameter->layout()));
1793         }
1794     }
1795     return this->writeInstruction(SpvOpTypeFunction, words, fConstantBuffer);
1796 }
1797 
getFunctionParameterType(const Type & parameterType,const Layout & parameterLayout)1798 SpvId SPIRVCodeGenerator::getFunctionParameterType(const Type& parameterType,
1799                                                    const Layout& parameterLayout) {
1800     // glslang treats all function arguments as pointers whether they need to be or
1801     // not. I was initially puzzled by this until I ran bizarre failures with certain
1802     // patterns of function calls and control constructs, as exemplified by this minimal
1803     // failure case:
1804     //
1805     // void sphere(float x) {
1806     // }
1807     //
1808     // void map() {
1809     //     sphere(1.0);
1810     // }
1811     //
1812     // void main() {
1813     //     for (int i = 0; i < 1; i++) {
1814     //         map();
1815     //     }
1816     // }
1817     //
1818     // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
1819     // crashes. Making it take a float* and storing the argument in a temporary variable,
1820     // as glslang does, fixes it.
1821     //
1822     // The consensus among shader compiler authors seems to be that GPU driver generally don't
1823     // handle value-based parameters consistently. It is highly likely that they fit their
1824     // implementations to conform to glslang. We take care to do so ourselves.
1825     //
1826     // Our implementation first stores every parameter value into a function storage-class pointer
1827     // before calling a function. The exception is for opaque handle types (samplers and textures)
1828     // which must be stored in a pointer with UniformConstant storage-class. This prevents
1829     // unnecessary temporaries (becuase opaque handles are always rooted in a pointer variable),
1830     // matches glslang's behavior, and translates into WGSL more easily when targeting Dawn.
1831     SpvStorageClass_ storageClass;
1832     if (parameterType.typeKind() == Type::TypeKind::kSampler ||
1833         parameterType.typeKind() == Type::TypeKind::kSeparateSampler ||
1834         parameterType.typeKind() == Type::TypeKind::kTexture) {
1835         storageClass = SpvStorageClassUniformConstant;
1836     } else {
1837         storageClass = SpvStorageClassFunction;
1838     }
1839     return this->getPointerType(parameterType,
1840                                 parameterLayout,
1841                                 this->memoryLayoutForStorageClass(storageClass),
1842                                 storageClass);
1843 }
1844 
getPointerType(const Type & type,SpvStorageClass_ storageClass)1845 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) {
1846     return this->getPointerType(type,
1847                                 kDefaultTypeLayout,
1848                                 this->memoryLayoutForStorageClass(storageClass),
1849                                 storageClass);
1850 }
1851 
getPointerType(const Type & type,const Layout & typeLayout,const MemoryLayout & memoryLayout,SpvStorageClass_ storageClass)1852 SpvId SPIRVCodeGenerator::getPointerType(const Type& type,
1853                                          const Layout& typeLayout,
1854                                          const MemoryLayout& memoryLayout,
1855                                          SpvStorageClass_ storageClass) {
1856     return this->writeInstruction(SpvOpTypePointer,
1857                                   Words{Word::Result(),
1858                                         Word::Number(storageClass),
1859                                         this->getType(type, typeLayout, memoryLayout)},
1860                                   fConstantBuffer);
1861 }
1862 
writeExpression(const Expression & expr,OutputStream & out)1863 SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) {
1864     switch (expr.kind()) {
1865         case Expression::Kind::kBinary:
1866             return this->writeBinaryExpression(expr.as<BinaryExpression>(), out);
1867         case Expression::Kind::kConstructorArrayCast:
1868             return this->writeExpression(*expr.as<ConstructorArrayCast>().argument(), out);
1869         case Expression::Kind::kConstructorArray:
1870         case Expression::Kind::kConstructorStruct:
1871             return this->writeCompositeConstructor(expr.asAnyConstructor(), out);
1872         case Expression::Kind::kConstructorDiagonalMatrix:
1873             return this->writeConstructorDiagonalMatrix(expr.as<ConstructorDiagonalMatrix>(), out);
1874         case Expression::Kind::kConstructorMatrixResize:
1875             return this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(), out);
1876         case Expression::Kind::kConstructorScalarCast:
1877             return this->writeConstructorScalarCast(expr.as<ConstructorScalarCast>(), out);
1878         case Expression::Kind::kConstructorSplat:
1879             return this->writeConstructorSplat(expr.as<ConstructorSplat>(), out);
1880         case Expression::Kind::kConstructorCompound:
1881             return this->writeConstructorCompound(expr.as<ConstructorCompound>(), out);
1882         case Expression::Kind::kConstructorCompoundCast:
1883             return this->writeConstructorCompoundCast(expr.as<ConstructorCompoundCast>(), out);
1884         case Expression::Kind::kEmpty:
1885             return NA;
1886         case Expression::Kind::kFieldAccess:
1887             return this->writeFieldAccess(expr.as<FieldAccess>(), out);
1888         case Expression::Kind::kFunctionCall:
1889             return this->writeFunctionCall(expr.as<FunctionCall>(), out);
1890         case Expression::Kind::kLiteral:
1891             return this->writeLiteral(expr.as<Literal>());
1892         case Expression::Kind::kPrefix:
1893             return this->writePrefixExpression(expr.as<PrefixExpression>(), out);
1894         case Expression::Kind::kPostfix:
1895             return this->writePostfixExpression(expr.as<PostfixExpression>(), out);
1896         case Expression::Kind::kSwizzle:
1897             return this->writeSwizzle(expr.as<Swizzle>(), out);
1898         case Expression::Kind::kVariableReference:
1899             return this->writeVariableReference(expr.as<VariableReference>(), out);
1900         case Expression::Kind::kTernary:
1901             return this->writeTernaryExpression(expr.as<TernaryExpression>(), out);
1902         case Expression::Kind::kIndex:
1903             return this->writeIndexExpression(expr.as<IndexExpression>(), out);
1904         case Expression::Kind::kSetting:
1905             return this->writeExpression(*expr.as<Setting>().toLiteral(fCaps), out);
1906         default:
1907             SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
1908             break;
1909     }
1910     return NA;
1911 }
1912 
writeIntrinsicCall(const FunctionCall & c,OutputStream & out)1913 SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) {
1914     const FunctionDeclaration& function = c.function();
1915     Intrinsic intrinsic = this->getIntrinsic(function.intrinsicKind());
1916     if (intrinsic.opKind == kInvalid_IntrinsicOpcodeKind) {
1917         fContext.fErrors->error(c.fPosition, "unsupported intrinsic '" + function.description() +
1918                 "'");
1919         return NA;
1920     }
1921     const ExpressionArray& arguments = c.arguments();
1922     int32_t intrinsicId = intrinsic.floatOp;
1923     if (!arguments.empty()) {
1924         const Type& type = arguments[0]->type();
1925         if (intrinsic.opKind == kSpecial_IntrinsicOpcodeKind) {
1926             // Keep the default float op.
1927         } else {
1928             intrinsicId = pick_by_type(type, intrinsic.floatOp, intrinsic.signedOp,
1929                                        intrinsic.unsignedOp, intrinsic.boolOp);
1930         }
1931     }
1932     switch (intrinsic.opKind) {
1933         case kGLSL_STD_450_IntrinsicOpcodeKind: {
1934             SpvId result = this->nextId(&c.type());
1935             TArray<SpvId> argumentIds;
1936             argumentIds.reserve_exact(arguments.size());
1937             std::vector<TempVar> tempVars;
1938             for (int i = 0; i < arguments.size(); i++) {
1939                 argumentIds.push_back(this->writeFunctionCallArgument(c, i, &tempVars, out));
1940             }
1941             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) argumentIds.size(), out);
1942             this->writeWord(this->getType(c.type()), out);
1943             this->writeWord(result, out);
1944             this->writeWord(fGLSLExtendedInstructions, out);
1945             this->writeWord(intrinsicId, out);
1946             for (SpvId id : argumentIds) {
1947                 this->writeWord(id, out);
1948             }
1949             this->copyBackTempVars(tempVars, out);
1950             return result;
1951         }
1952         case kSPIRV_IntrinsicOpcodeKind: {
1953             // GLSL supports dot(float, float), but SPIR-V does not. Convert it to FMul
1954             if (intrinsicId == SpvOpDot && arguments[0]->type().isScalar()) {
1955                 intrinsicId = SpvOpFMul;
1956             }
1957             SpvId result = this->nextId(&c.type());
1958             TArray<SpvId> argumentIds;
1959             argumentIds.reserve_exact(arguments.size());
1960             std::vector<TempVar> tempVars;
1961             for (int i = 0; i < arguments.size(); i++) {
1962                 argumentIds.push_back(this->writeFunctionCallArgument(c, i, &tempVars, out));
1963             }
1964             if (!c.type().isVoid()) {
1965                 this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
1966                 this->writeWord(this->getType(c.type()), out);
1967                 this->writeWord(result, out);
1968             } else {
1969                 this->writeOpCode((SpvOp_) intrinsicId, 1 + (int32_t) arguments.size(), out);
1970             }
1971             for (SpvId id : argumentIds) {
1972                 this->writeWord(id, out);
1973             }
1974             this->copyBackTempVars(tempVars, out);
1975             return result;
1976         }
1977         case kSpecial_IntrinsicOpcodeKind:
1978             return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
1979         default:
1980             fContext.fErrors->error(c.fPosition, "unsupported intrinsic '" +
1981                     function.description() + "'");
1982             return NA;
1983     }
1984 }
1985 
vectorize(const Expression & arg,int vectorSize,OutputStream & out)1986 SpvId SPIRVCodeGenerator::vectorize(const Expression& arg, int vectorSize, OutputStream& out) {
1987     SkASSERT(vectorSize >= 1 && vectorSize <= 4);
1988     const Type& argType = arg.type();
1989     if (argType.isScalar() && vectorSize > 1) {
1990         ConstructorSplat splat{arg.fPosition,
1991                                argType.toCompound(fContext, vectorSize, /*rows=*/1),
1992                                arg.clone()};
1993         return this->writeConstructorSplat(splat, out);
1994     }
1995 
1996     SkASSERT(vectorSize == argType.columns());
1997     return this->writeExpression(arg, out);
1998 }
1999 
vectorize(const ExpressionArray & args,OutputStream & out)2000 TArray<SpvId> SPIRVCodeGenerator::vectorize(const ExpressionArray& args, OutputStream& out) {
2001     int vectorSize = 1;
2002     for (const auto& a : args) {
2003         if (a->type().isVector()) {
2004             if (vectorSize > 1) {
2005                 SkASSERT(a->type().columns() == vectorSize);
2006             } else {
2007                 vectorSize = a->type().columns();
2008             }
2009         }
2010     }
2011     TArray<SpvId> result;
2012     result.reserve_exact(args.size());
2013     for (const auto& arg : args) {
2014         result.push_back(this->vectorize(*arg, vectorSize, out));
2015     }
2016     return result;
2017 }
2018 
writeGLSLExtendedInstruction(const Type & type,SpvId id,SpvId floatInst,SpvId signedInst,SpvId unsignedInst,const TArray<SpvId> & args,OutputStream & out)2019 void SPIRVCodeGenerator::writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
2020                                                       SpvId signedInst, SpvId unsignedInst,
2021                                                       const TArray<SpvId>& args,
2022                                                       OutputStream& out) {
2023     this->writeOpCode(SpvOpExtInst, 5 + args.size(), out);
2024     this->writeWord(this->getType(type), out);
2025     this->writeWord(id, out);
2026     this->writeWord(fGLSLExtendedInstructions, out);
2027     this->writeWord(pick_by_type(type, floatInst, signedInst, unsignedInst, NA), out);
2028     for (SpvId a : args) {
2029         this->writeWord(a, out);
2030     }
2031 }
2032 
writeSpecialIntrinsic(const FunctionCall & c,SpecialIntrinsic kind,OutputStream & out)2033 SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
2034                                                 OutputStream& out) {
2035     const ExpressionArray& arguments = c.arguments();
2036     const Type& callType = c.type();
2037     SpvId result = this->nextId(nullptr);
2038     switch (kind) {
2039         case kAtan_SpecialIntrinsic: {
2040             STArray<2, SpvId> argumentIds;
2041             for (const std::unique_ptr<Expression>& arg : arguments) {
2042                 argumentIds.push_back(this->writeExpression(*arg, out));
2043             }
2044             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) argumentIds.size(), out);
2045             this->writeWord(this->getType(callType), out);
2046             this->writeWord(result, out);
2047             this->writeWord(fGLSLExtendedInstructions, out);
2048             this->writeWord(argumentIds.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
2049             for (SpvId id : argumentIds) {
2050                 this->writeWord(id, out);
2051             }
2052             break;
2053         }
2054         case kSampledImage_SpecialIntrinsic: {
2055             SkASSERT(arguments.size() == 2);
2056             SpvId img = this->writeExpression(*arguments[0], out);
2057             SpvId sampler = this->writeExpression(*arguments[1], out);
2058             this->writeInstruction(SpvOpSampledImage,
2059                                    this->getType(callType),
2060                                    result,
2061                                    img,
2062                                    sampler,
2063                                    out);
2064             break;
2065         }
2066         case kSubpassLoad_SpecialIntrinsic: {
2067             SpvId img = this->writeExpression(*arguments[0], out);
2068             ExpressionArray args;
2069             args.reserve_exact(2);
2070             args.push_back(Literal::MakeInt(fContext, Position(), /*value=*/0));
2071             args.push_back(Literal::MakeInt(fContext, Position(), /*value=*/0));
2072             ConstructorCompound ctor(Position(), *fContext.fTypes.fInt2, std::move(args));
2073             SpvId coords = this->writeExpression(ctor, out);
2074             if (arguments.size() == 1) {
2075                 this->writeInstruction(SpvOpImageRead,
2076                                        this->getType(callType),
2077                                        result,
2078                                        img,
2079                                        coords,
2080                                        out);
2081             } else {
2082                 SkASSERT(arguments.size() == 2);
2083                 SpvId sample = this->writeExpression(*arguments[1], out);
2084                 this->writeInstruction(SpvOpImageRead,
2085                                        this->getType(callType),
2086                                        result,
2087                                        img,
2088                                        coords,
2089                                        SpvImageOperandsSampleMask,
2090                                        sample,
2091                                        out);
2092             }
2093             break;
2094         }
2095         case kTexture_SpecialIntrinsic: {
2096             SpvOp_ op = SpvOpImageSampleImplicitLod;
2097             const Type& arg1Type = arguments[1]->type();
2098             switch (arguments[0]->type().dimensions()) {
2099                 case SpvDim1D:
2100                     if (arg1Type.matches(*fContext.fTypes.fFloat2)) {
2101                         op = SpvOpImageSampleProjImplicitLod;
2102                     } else {
2103                         SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat));
2104                     }
2105                     break;
2106                 case SpvDim2D:
2107                     if (arg1Type.matches(*fContext.fTypes.fFloat3)) {
2108                         op = SpvOpImageSampleProjImplicitLod;
2109                     } else {
2110                         SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat2));
2111                     }
2112                     break;
2113                 case SpvDim3D:
2114                     if (arg1Type.matches(*fContext.fTypes.fFloat4)) {
2115                         op = SpvOpImageSampleProjImplicitLod;
2116                     } else {
2117                         SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat3));
2118                     }
2119                     break;
2120                 case SpvDimCube:   // fall through
2121                 case SpvDimRect:   // fall through
2122                 case SpvDimBuffer: // fall through
2123                 case SpvDimSubpassData:
2124                     break;
2125             }
2126             SpvId type = this->getType(callType);
2127             SpvId sampler = this->writeExpression(*arguments[0], out);
2128             SpvId uv = this->writeExpression(*arguments[1], out);
2129             if (arguments.size() == 3) {
2130                 this->writeInstruction(op, type, result, sampler, uv,
2131                                        SpvImageOperandsBiasMask,
2132                                        this->writeExpression(*arguments[2], out),
2133                                        out);
2134             } else {
2135                 SkASSERT(arguments.size() == 2);
2136                 if (fProgram.fConfig->fSettings.fSharpenTextures) {
2137                     SpvId lodBias = this->writeLiteral(kSharpenTexturesBias,
2138                                                        *fContext.fTypes.fFloat);
2139                     this->writeInstruction(op, type, result, sampler, uv,
2140                                            SpvImageOperandsBiasMask, lodBias, out);
2141                 } else {
2142                     this->writeInstruction(op, type, result, sampler, uv,
2143                                            out);
2144                 }
2145             }
2146             break;
2147         }
2148         case kTextureGrad_SpecialIntrinsic: {
2149             SpvOp_ op = SpvOpImageSampleExplicitLod;
2150             SkASSERT(arguments.size() == 4);
2151             SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2152             SkASSERT(arguments[1]->type().matches(*fContext.fTypes.fFloat2));
2153             SkASSERT(arguments[2]->type().matches(*fContext.fTypes.fFloat2));
2154             SkASSERT(arguments[3]->type().matches(*fContext.fTypes.fFloat2));
2155             SpvId type = this->getType(callType);
2156             SpvId sampler = this->writeExpression(*arguments[0], out);
2157             SpvId uv = this->writeExpression(*arguments[1], out);
2158             SpvId dPdx = this->writeExpression(*arguments[2], out);
2159             SpvId dPdy = this->writeExpression(*arguments[3], out);
2160             this->writeInstruction(op, type, result, sampler, uv, SpvImageOperandsGradMask,
2161                                    dPdx, dPdy, out);
2162             break;
2163         }
2164         case kTextureLod_SpecialIntrinsic: {
2165             SpvOp_ op = SpvOpImageSampleExplicitLod;
2166             SkASSERT(arguments.size() == 3);
2167             SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2168             SkASSERT(arguments[2]->type().matches(*fContext.fTypes.fFloat));
2169             const Type& arg1Type = arguments[1]->type();
2170             if (arg1Type.matches(*fContext.fTypes.fFloat3)) {
2171                 op = SpvOpImageSampleProjExplicitLod;
2172             } else {
2173                 SkASSERT(arg1Type.matches(*fContext.fTypes.fFloat2));
2174             }
2175             SpvId type = this->getType(callType);
2176             SpvId sampler = this->writeExpression(*arguments[0], out);
2177             SpvId uv = this->writeExpression(*arguments[1], out);
2178             this->writeInstruction(op, type, result, sampler, uv,
2179                                    SpvImageOperandsLodMask,
2180                                    this->writeExpression(*arguments[2], out),
2181                                    out);
2182             break;
2183         }
2184         case kTextureRead_SpecialIntrinsic: {
2185             SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2186             SkASSERT(arguments[1]->type().matches(*fContext.fTypes.fUInt2));
2187 
2188             SpvId type = this->getType(callType);
2189             SpvId image = this->writeExpression(*arguments[0], out);
2190             SpvId coord = this->writeExpression(*arguments[1], out);
2191 
2192             const Type& arg0Type = arguments[0]->type();
2193             SkASSERT(arg0Type.typeKind() == Type::TypeKind::kTexture);
2194 
2195             switch (arg0Type.textureAccess()) {
2196                 case Type::TextureAccess::kSample:
2197                     this->writeInstruction(SpvOpImageFetch, type, result, image, coord,
2198                                            SpvImageOperandsLodMask,
2199                                            this->writeOpConstant(*fContext.fTypes.fInt, 0),
2200                                            out);
2201                     break;
2202                 case Type::TextureAccess::kRead:
2203                 case Type::TextureAccess::kReadWrite:
2204                     this->writeInstruction(SpvOpImageRead, type, result, image, coord, out);
2205                     break;
2206                 case Type::TextureAccess::kWrite:
2207                 default:
2208                     SkDEBUGFAIL("'textureRead' called on writeonly texture type");
2209                     break;
2210             }
2211 
2212             break;
2213         }
2214         case kTextureWrite_SpecialIntrinsic: {
2215             SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2216             SkASSERT(arguments[1]->type().matches(*fContext.fTypes.fUInt2));
2217             SkASSERT(arguments[2]->type().matches(*fContext.fTypes.fHalf4));
2218 
2219             SpvId image = this->writeExpression(*arguments[0], out);
2220             SpvId coord = this->writeExpression(*arguments[1], out);
2221             SpvId texel = this->writeExpression(*arguments[2], out);
2222 
2223             this->writeInstruction(SpvOpImageWrite, image, coord, texel, out);
2224             break;
2225         }
2226         case kTextureWidth_SpecialIntrinsic:
2227         case kTextureHeight_SpecialIntrinsic: {
2228             SkASSERT(arguments[0]->type().dimensions() == SpvDim2D);
2229             fCapabilities |= 1ULL << SpvCapabilityImageQuery;
2230 
2231             SpvId dimsType = this->getType(*fContext.fTypes.fUInt2);
2232             SpvId dims = this->nextId(nullptr);
2233             SpvId image = this->writeExpression(*arguments[0], out);
2234             this->writeInstruction(SpvOpImageQuerySize, dimsType, dims, image, out);
2235 
2236             SpvId type = this->getType(callType);
2237             int32_t index = (kind == kTextureWidth_SpecialIntrinsic) ? 0 : 1;
2238             this->writeInstruction(SpvOpCompositeExtract, type, result, dims, index, out);
2239             break;
2240         }
2241         case kMod_SpecialIntrinsic: {
2242             TArray<SpvId> args = this->vectorize(arguments, out);
2243             SkASSERT(args.size() == 2);
2244             const Type& operandType = arguments[0]->type();
2245             SpvOp_ op = pick_by_type(operandType, SpvOpFMod, SpvOpSMod, SpvOpUMod, SpvOpUndef);
2246             SkASSERT(op != SpvOpUndef);
2247             this->writeOpCode(op, 5, out);
2248             this->writeWord(this->getType(operandType), out);
2249             this->writeWord(result, out);
2250             this->writeWord(args[0], out);
2251             this->writeWord(args[1], out);
2252             break;
2253         }
2254         case kDFdy_SpecialIntrinsic: {
2255             SpvId fn = this->writeExpression(*arguments[0], out);
2256             this->writeOpCode(SpvOpDPdy, 4, out);
2257             this->writeWord(this->getType(callType), out);
2258             this->writeWord(result, out);
2259             this->writeWord(fn, out);
2260             if (!fProgram.fConfig->fSettings.fForceNoRTFlip) {
2261                 this->addRTFlipUniform(c.fPosition);
2262                 ComponentArray componentArray;
2263                 for (int index = 0; index < callType.columns(); ++index) {
2264                     componentArray.push_back(SwizzleComponent::Y);
2265                 }
2266                 SpvId rtFlipY = this->writeSwizzle(*this->identifier(SKSL_RTFLIP_NAME),
2267                                                    componentArray, out);
2268                 SpvId flipped = this->nextId(&callType);
2269                 this->writeInstruction(SpvOpFMul, this->getType(callType), flipped, result,
2270                                        rtFlipY, out);
2271                 result = flipped;
2272             }
2273             break;
2274         }
2275         case kClamp_SpecialIntrinsic: {
2276             TArray<SpvId> args = this->vectorize(arguments, out);
2277             SkASSERT(args.size() == 3);
2278             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FClamp, GLSLstd450SClamp,
2279                                                GLSLstd450UClamp, args, out);
2280             break;
2281         }
2282         case kMax_SpecialIntrinsic: {
2283             TArray<SpvId> args = this->vectorize(arguments, out);
2284             SkASSERT(args.size() == 2);
2285             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMax, GLSLstd450SMax,
2286                                                GLSLstd450UMax, args, out);
2287             break;
2288         }
2289         case kMin_SpecialIntrinsic: {
2290             TArray<SpvId> args = this->vectorize(arguments, out);
2291             SkASSERT(args.size() == 2);
2292             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMin, GLSLstd450SMin,
2293                                                GLSLstd450UMin, args, out);
2294             break;
2295         }
2296         case kMix_SpecialIntrinsic: {
2297             TArray<SpvId> args = this->vectorize(arguments, out);
2298             SkASSERT(args.size() == 3);
2299             if (arguments[2]->type().componentType().isBoolean()) {
2300                 // Use OpSelect to implement Boolean mix().
2301                 SpvId falseId     = this->writeExpression(*arguments[0], out);
2302                 SpvId trueId      = this->writeExpression(*arguments[1], out);
2303                 SpvId conditionId = this->writeExpression(*arguments[2], out);
2304                 this->writeInstruction(SpvOpSelect, this->getType(arguments[0]->type()), result,
2305                                        conditionId, trueId, falseId, out);
2306             } else {
2307                 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMix, SpvOpUndef,
2308                                                    SpvOpUndef, args, out);
2309             }
2310             break;
2311         }
2312         case kSaturate_SpecialIntrinsic: {
2313             SkASSERT(arguments.size() == 1);
2314             ExpressionArray finalArgs;
2315             finalArgs.reserve_exact(3);
2316             finalArgs.push_back(arguments[0]->clone());
2317             finalArgs.push_back(Literal::MakeFloat(fContext, Position(), /*value=*/0));
2318             finalArgs.push_back(Literal::MakeFloat(fContext, Position(), /*value=*/1));
2319             TArray<SpvId> spvArgs = this->vectorize(finalArgs, out);
2320             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FClamp, GLSLstd450SClamp,
2321                                                GLSLstd450UClamp, spvArgs, out);
2322             break;
2323         }
2324         case kSmoothStep_SpecialIntrinsic: {
2325             TArray<SpvId> args = this->vectorize(arguments, out);
2326             SkASSERT(args.size() == 3);
2327             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450SmoothStep, SpvOpUndef,
2328                                                SpvOpUndef, args, out);
2329             break;
2330         }
2331         case kStep_SpecialIntrinsic: {
2332             TArray<SpvId> args = this->vectorize(arguments, out);
2333             SkASSERT(args.size() == 2);
2334             this->writeGLSLExtendedInstruction(callType, result, GLSLstd450Step, SpvOpUndef,
2335                                                SpvOpUndef, args, out);
2336             break;
2337         }
2338         case kMatrixCompMult_SpecialIntrinsic: {
2339             SkASSERT(arguments.size() == 2);
2340             SpvId lhs = this->writeExpression(*arguments[0], out);
2341             SpvId rhs = this->writeExpression(*arguments[1], out);
2342             result = this->writeComponentwiseMatrixBinary(callType, lhs, rhs, SpvOpFMul, out);
2343             break;
2344         }
2345         case kAtomicAdd_SpecialIntrinsic:
2346         case kAtomicLoad_SpecialIntrinsic:
2347         case kAtomicStore_SpecialIntrinsic:
2348             result = this->writeAtomicIntrinsic(c, kind, result, out);
2349             break;
2350         case kStorageBarrier_SpecialIntrinsic:
2351         case kWorkgroupBarrier_SpecialIntrinsic: {
2352             // Both barrier types operate in the workgroup execution and memory scope and differ
2353             // only in memory semantics. storageBarrier() is not a device-scope barrier.
2354             SpvId scopeId =
2355                     this->writeOpConstant(*fContext.fTypes.fUInt, (int32_t)SpvScopeWorkgroup);
2356             int32_t memSemMask = (kind == kStorageBarrier_SpecialIntrinsic)
2357                                          ? SpvMemorySemanticsAcquireReleaseMask |
2358                                                    SpvMemorySemanticsUniformMemoryMask
2359                                          : SpvMemorySemanticsAcquireReleaseMask |
2360                                                    SpvMemorySemanticsWorkgroupMemoryMask;
2361             SpvId memorySemanticsId = this->writeOpConstant(*fContext.fTypes.fUInt, memSemMask);
2362             this->writeInstruction(SpvOpControlBarrier,
2363                                    scopeId,  // execution scope
2364                                    scopeId,  // memory scope
2365                                    memorySemanticsId,
2366                                    out);
2367             break;
2368         }
2369     }
2370     return result;
2371 }
2372 
writeAtomicIntrinsic(const FunctionCall & c,SpecialIntrinsic kind,SpvId resultId,OutputStream & out)2373 SpvId SPIRVCodeGenerator::writeAtomicIntrinsic(const FunctionCall& c,
2374                                                SpecialIntrinsic kind,
2375                                                SpvId resultId,
2376                                                OutputStream& out) {
2377     const ExpressionArray& arguments = c.arguments();
2378     SkASSERT(!arguments.empty());
2379 
2380     std::unique_ptr<LValue> atomicPtr = this->getLValue(*arguments[0], out);
2381     SpvId atomicPtrId = atomicPtr->getPointer();
2382     if (atomicPtrId == NA) {
2383         SkDEBUGFAILF("atomic intrinsic expected a pointer argument: %s",
2384                      arguments[0]->description().c_str());
2385         return NA;
2386     }
2387 
2388     SpvId memoryScopeId = NA;
2389     {
2390         // In SkSL, the atomicUint type can only be declared as a workgroup variable or SSBO block
2391         // member. The two memory scopes that these map to are "workgroup" and "device",
2392         // respectively.
2393         SpvScope memoryScope;
2394         switch (atomicPtr->storageClass()) {
2395             case SpvStorageClassUniform:
2396                 // We encode storage buffers in the uniform address space (with the BufferBlock
2397                 // decorator).
2398                 memoryScope = SpvScopeDevice;
2399                 break;
2400             case SpvStorageClassWorkgroup:
2401                 memoryScope = SpvScopeWorkgroup;
2402                 break;
2403             default:
2404                 SkDEBUGFAILF("atomic argument has invalid storage class: %d",
2405                              atomicPtr->storageClass());
2406                 return NA;
2407         }
2408         memoryScopeId = this->writeOpConstant(*fContext.fTypes.fUInt, (int32_t)memoryScope);
2409     }
2410 
2411     SpvId relaxedMemoryOrderId =
2412             this->writeOpConstant(*fContext.fTypes.fUInt, (int32_t)SpvMemorySemanticsMaskNone);
2413 
2414     switch (kind) {
2415         case kAtomicAdd_SpecialIntrinsic:
2416             SkASSERT(arguments.size() == 2);
2417             this->writeInstruction(SpvOpAtomicIAdd,
2418                                    this->getType(c.type()),
2419                                    resultId,
2420                                    atomicPtrId,
2421                                    memoryScopeId,
2422                                    relaxedMemoryOrderId,
2423                                    this->writeExpression(*arguments[1], out),
2424                                    out);
2425             break;
2426         case kAtomicLoad_SpecialIntrinsic:
2427             SkASSERT(arguments.size() == 1);
2428             this->writeInstruction(SpvOpAtomicLoad,
2429                                    this->getType(c.type()),
2430                                    resultId,
2431                                    atomicPtrId,
2432                                    memoryScopeId,
2433                                    relaxedMemoryOrderId,
2434                                    out);
2435             break;
2436         case kAtomicStore_SpecialIntrinsic:
2437             SkASSERT(arguments.size() == 2);
2438             this->writeInstruction(SpvOpAtomicStore,
2439                                    atomicPtrId,
2440                                    memoryScopeId,
2441                                    relaxedMemoryOrderId,
2442                                    this->writeExpression(*arguments[1], out),
2443                                    out);
2444             break;
2445         default:
2446             SkUNREACHABLE;
2447     }
2448 
2449     return resultId;
2450 }
2451 
writeFunctionCallArgument(const FunctionCall & call,int argIndex,std::vector<TempVar> * tempVars,OutputStream & out,SpvId * outSynthesizedSamplerId)2452 SpvId SPIRVCodeGenerator::writeFunctionCallArgument(const FunctionCall& call,
2453                                                     int argIndex,
2454                                                     std::vector<TempVar>* tempVars,
2455                                                     OutputStream& out,
2456                                                     SpvId* outSynthesizedSamplerId) {
2457     const FunctionDeclaration& funcDecl = call.function();
2458     const Expression& arg = *call.arguments()[argIndex];
2459     ModifierFlags paramFlags = funcDecl.parameters()[argIndex]->modifierFlags();
2460 
2461     // ID of temporary variable that we will use to hold this argument, or 0 if it is being
2462     // passed directly
2463     SpvId tmpVar;
2464     // if we need a temporary var to store this argument, this is the value to store in the var
2465     SpvId tmpValueId = NA;
2466 
2467     if (is_out(paramFlags)) {
2468         std::unique_ptr<LValue> lv = this->getLValue(arg, out);
2469         // We handle out params with a temp var that we copy back to the original variable at the
2470         // end of the call. GLSL guarantees that the original variable will be unchanged until the
2471         // end of the call, and also that out params are written back to their original variables in
2472         // a specific order (left-to-right), so it's unsafe to pass a pointer to the original value.
2473         if (is_in(paramFlags)) {
2474             tmpValueId = lv->load(out);
2475         }
2476         tmpVar = this->nextId(&arg.type());
2477         tempVars->push_back(TempVar{tmpVar, &arg.type(), std::move(lv)});
2478     } else if (funcDecl.isIntrinsic()) {
2479         // Unlike user function calls, non-out intrinsic arguments don't need pointer parameters.
2480         return this->writeExpression(arg, out);
2481     } else if (arg.is<VariableReference>() &&
2482                (arg.type().typeKind() == Type::TypeKind::kSampler ||
2483                 arg.type().typeKind() == Type::TypeKind::kSeparateSampler ||
2484                 arg.type().typeKind() == Type::TypeKind::kTexture)) {
2485         // Opaque handle (sampler/texture) arguments are always declared as pointers but never
2486         // stored in intermediates when calling user-defined functions.
2487         //
2488         // The case for intrinsics (which take opaque arguments by value) is handled above just like
2489         // regular pointers.
2490         //
2491         // See getFunctionParameterType for further explanation.
2492         const Variable* var = arg.as<VariableReference>().variable();
2493 
2494         // In Dawn-mode the texture and sampler arguments are forwarded to the helper function.
2495         if (fUseTextureSamplerPairs && var->type().isSampler()) {
2496             if (const auto* p = fSynthesizedSamplerMap.find(var)) {
2497                 SkASSERT(outSynthesizedSamplerId);
2498 
2499                 SpvId* img = fVariableMap.find((*p)->fTexture.get());
2500                 SpvId* sampler = fVariableMap.find((*p)->fSampler.get());
2501                 SkASSERT(img);
2502                 SkASSERT(sampler);
2503 
2504                 *outSynthesizedSamplerId = *sampler;
2505                 return *img;
2506             }
2507             SkDEBUGFAIL("sampler missing from fSynthesizedSamplerMap");
2508         }
2509 
2510         SpvId* entry = fVariableMap.find(var);
2511         SkASSERTF(entry, "%s", arg.description().c_str());
2512         return *entry;
2513     } else {
2514         // We always use pointer parameters when calling user functions.
2515         // See getFunctionParameterType for further explanation.
2516         tmpValueId = this->writeExpression(arg, out);
2517         tmpVar = this->nextId(nullptr);
2518     }
2519     this->writeInstruction(SpvOpVariable,
2520                            this->getPointerType(arg.type(), SpvStorageClassFunction),
2521                            tmpVar,
2522                            SpvStorageClassFunction,
2523                            fVariableBuffer);
2524     if (tmpValueId != NA) {
2525         this->writeOpStore(SpvStorageClassFunction, tmpVar, tmpValueId, out);
2526     }
2527     return tmpVar;
2528 }
2529 
copyBackTempVars(const std::vector<TempVar> & tempVars,OutputStream & out)2530 void SPIRVCodeGenerator::copyBackTempVars(const std::vector<TempVar>& tempVars, OutputStream& out) {
2531     for (const TempVar& tempVar : tempVars) {
2532         SpvId load = this->nextId(tempVar.type);
2533         this->writeInstruction(SpvOpLoad, this->getType(*tempVar.type), load, tempVar.spvId, out);
2534         tempVar.lvalue->store(load, out);
2535     }
2536 }
2537 
writeFunctionCall(const FunctionCall & c,OutputStream & out)2538 SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) {
2539     const FunctionDeclaration& function = c.function();
2540     if (function.isIntrinsic() && !function.definition()) {
2541         return this->writeIntrinsicCall(c, out);
2542     }
2543     const ExpressionArray& arguments = c.arguments();
2544     SpvId* entry = fFunctionMap.find(&function);
2545     if (!entry) {
2546         fContext.fErrors->error(c.fPosition, "function '" + function.description() +
2547                 "' is not defined");
2548         return NA;
2549     }
2550     // Temp variables are used to write back out-parameters after the function call is complete.
2551     std::vector<TempVar> tempVars;
2552     TArray<SpvId> argumentIds;
2553     argumentIds.reserve_exact(arguments.size());
2554     for (int i = 0; i < arguments.size(); i++) {
2555         SpvId samplerId = NA;
2556         argumentIds.push_back(this->writeFunctionCallArgument(c, i, &tempVars, out, &samplerId));
2557         if (samplerId != NA) {
2558             argumentIds.push_back(samplerId);
2559         }
2560     }
2561     SpvId result = this->nextId(nullptr);
2562     this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t)argumentIds.size(), out);
2563     this->writeWord(this->getType(c.type()), out);
2564     this->writeWord(result, out);
2565     this->writeWord(*entry, out);
2566     for (SpvId id : argumentIds) {
2567         this->writeWord(id, out);
2568     }
2569     // Now that the call is complete, we copy temp out-variables back to their real lvalues.
2570     this->copyBackTempVars(tempVars, out);
2571     return result;
2572 }
2573 
castScalarToType(SpvId inputExprId,const Type & inputType,const Type & outputType,OutputStream & out)2574 SpvId SPIRVCodeGenerator::castScalarToType(SpvId inputExprId,
2575                                            const Type& inputType,
2576                                            const Type& outputType,
2577                                            OutputStream& out) {
2578     if (outputType.isFloat()) {
2579         return this->castScalarToFloat(inputExprId, inputType, outputType, out);
2580     }
2581     if (outputType.isSigned()) {
2582         return this->castScalarToSignedInt(inputExprId, inputType, outputType, out);
2583     }
2584     if (outputType.isUnsigned()) {
2585         return this->castScalarToUnsignedInt(inputExprId, inputType, outputType, out);
2586     }
2587     if (outputType.isBoolean()) {
2588         return this->castScalarToBoolean(inputExprId, inputType, outputType, out);
2589     }
2590 
2591     fContext.fErrors->error(Position(), "unsupported cast: " + inputType.description() + " to " +
2592             outputType.description());
2593     return inputExprId;
2594 }
2595 
castScalarToFloat(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)2596 SpvId SPIRVCodeGenerator::castScalarToFloat(SpvId inputId, const Type& inputType,
2597                                             const Type& outputType, OutputStream& out) {
2598     // Casting a float to float is a no-op.
2599     if (inputType.isFloat()) {
2600         return inputId;
2601     }
2602 
2603     // Given the input type, generate the appropriate instruction to cast to float.
2604     SpvId result = this->nextId(&outputType);
2605     if (inputType.isBoolean()) {
2606         // Use OpSelect to convert the boolean argument to a literal 1.0 or 0.0.
2607         const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fFloat);
2608         const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
2609         this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
2610                                inputId, oneID, zeroID, out);
2611     } else if (inputType.isSigned()) {
2612         this->writeInstruction(SpvOpConvertSToF, this->getType(outputType), result, inputId, out);
2613     } else if (inputType.isUnsigned()) {
2614         this->writeInstruction(SpvOpConvertUToF, this->getType(outputType), result, inputId, out);
2615     } else {
2616         SkDEBUGFAILF("unsupported type for float typecast: %s", inputType.description().c_str());
2617         return NA;
2618     }
2619     return result;
2620 }
2621 
castScalarToSignedInt(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)2622 SpvId SPIRVCodeGenerator::castScalarToSignedInt(SpvId inputId, const Type& inputType,
2623                                                 const Type& outputType, OutputStream& out) {
2624     // Casting a signed int to signed int is a no-op.
2625     if (inputType.isSigned()) {
2626         return inputId;
2627     }
2628 
2629     // Given the input type, generate the appropriate instruction to cast to signed int.
2630     SpvId result = this->nextId(&outputType);
2631     if (inputType.isBoolean()) {
2632         // Use OpSelect to convert the boolean argument to a literal 1 or 0.
2633         const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fInt);
2634         const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt);
2635         this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
2636                                inputId, oneID, zeroID, out);
2637     } else if (inputType.isFloat()) {
2638         this->writeInstruction(SpvOpConvertFToS, this->getType(outputType), result, inputId, out);
2639     } else if (inputType.isUnsigned()) {
2640         this->writeInstruction(SpvOpBitcast, this->getType(outputType), result, inputId, out);
2641     } else {
2642         SkDEBUGFAILF("unsupported type for signed int typecast: %s",
2643                      inputType.description().c_str());
2644         return NA;
2645     }
2646     return result;
2647 }
2648 
castScalarToUnsignedInt(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)2649 SpvId SPIRVCodeGenerator::castScalarToUnsignedInt(SpvId inputId, const Type& inputType,
2650                                                   const Type& outputType, OutputStream& out) {
2651     // Casting an unsigned int to unsigned int is a no-op.
2652     if (inputType.isUnsigned()) {
2653         return inputId;
2654     }
2655 
2656     // Given the input type, generate the appropriate instruction to cast to unsigned int.
2657     SpvId result = this->nextId(&outputType);
2658     if (inputType.isBoolean()) {
2659         // Use OpSelect to convert the boolean argument to a literal 1u or 0u.
2660         const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fUInt);
2661         const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt);
2662         this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
2663                                inputId, oneID, zeroID, out);
2664     } else if (inputType.isFloat()) {
2665         this->writeInstruction(SpvOpConvertFToU, this->getType(outputType), result, inputId, out);
2666     } else if (inputType.isSigned()) {
2667         this->writeInstruction(SpvOpBitcast, this->getType(outputType), result, inputId, out);
2668     } else {
2669         SkDEBUGFAILF("unsupported type for unsigned int typecast: %s",
2670                      inputType.description().c_str());
2671         return NA;
2672     }
2673     return result;
2674 }
2675 
castScalarToBoolean(SpvId inputId,const Type & inputType,const Type & outputType,OutputStream & out)2676 SpvId SPIRVCodeGenerator::castScalarToBoolean(SpvId inputId, const Type& inputType,
2677                                               const Type& outputType, OutputStream& out) {
2678     // Casting a bool to bool is a no-op.
2679     if (inputType.isBoolean()) {
2680         return inputId;
2681     }
2682 
2683     // Given the input type, generate the appropriate instruction to cast to bool.
2684     SpvId result = this->nextId(nullptr);
2685     if (inputType.isSigned()) {
2686         // Synthesize a boolean result by comparing the input against a signed zero literal.
2687         const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt);
2688         this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result,
2689                                inputId, zeroID, out);
2690     } else if (inputType.isUnsigned()) {
2691         // Synthesize a boolean result by comparing the input against an unsigned zero literal.
2692         const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt);
2693         this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result,
2694                                inputId, zeroID, out);
2695     } else if (inputType.isFloat()) {
2696         // Synthesize a boolean result by comparing the input against a floating-point zero literal.
2697         const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
2698         this->writeInstruction(SpvOpFUnordNotEqual, this->getType(outputType), result,
2699                                inputId, zeroID, out);
2700     } else {
2701         SkDEBUGFAILF("unsupported type for boolean typecast: %s", inputType.description().c_str());
2702         return NA;
2703     }
2704     return result;
2705 }
2706 
writeMatrixCopy(SpvId src,const Type & srcType,const Type & dstType,OutputStream & out)2707 SpvId SPIRVCodeGenerator::writeMatrixCopy(SpvId src, const Type& srcType, const Type& dstType,
2708                                           OutputStream& out) {
2709     SkASSERT(srcType.isMatrix());
2710     SkASSERT(dstType.isMatrix());
2711     SkASSERT(srcType.componentType().matches(dstType.componentType()));
2712     const Type& srcColumnType = srcType.componentType().toCompound(fContext, srcType.rows(), 1);
2713     const Type& dstColumnType = dstType.componentType().toCompound(fContext, dstType.rows(), 1);
2714     SkASSERT(dstType.componentType().isFloat());
2715     SpvId dstColumnTypeId = this->getType(dstColumnType);
2716     const SpvId zeroId = this->writeLiteral(0.0, dstType.componentType());
2717     const SpvId oneId = this->writeLiteral(1.0, dstType.componentType());
2718 
2719     STArray<4, SpvId> columns;
2720     for (int i = 0; i < dstType.columns(); i++) {
2721         if (i < srcType.columns()) {
2722             // we're still inside the src matrix, copy the column
2723             SpvId srcColumn = this->writeOpCompositeExtract(srcColumnType, src, i, out);
2724             SpvId dstColumn;
2725             if (srcType.rows() == dstType.rows()) {
2726                 // columns are equal size, don't need to do anything
2727                 dstColumn = srcColumn;
2728             }
2729             else if (dstType.rows() > srcType.rows()) {
2730                 // dst column is bigger, need to zero-pad it
2731                 STArray<4, SpvId> values;
2732                 values.push_back(srcColumn);
2733                 for (int j = srcType.rows(); j < dstType.rows(); ++j) {
2734                     values.push_back((i == j) ? oneId : zeroId);
2735                 }
2736                 dstColumn = this->writeOpCompositeConstruct(dstColumnType, values, out);
2737             }
2738             else {
2739                 // dst column is smaller, need to swizzle the src column
2740                 dstColumn = this->nextId(&dstType);
2741                 this->writeOpCode(SpvOpVectorShuffle, 5 + dstType.rows(), out);
2742                 this->writeWord(dstColumnTypeId, out);
2743                 this->writeWord(dstColumn, out);
2744                 this->writeWord(srcColumn, out);
2745                 this->writeWord(srcColumn, out);
2746                 for (int j = 0; j < dstType.rows(); j++) {
2747                     this->writeWord(j, out);
2748                 }
2749             }
2750             columns.push_back(dstColumn);
2751         } else {
2752             // we're past the end of the src matrix, need to synthesize an identity-matrix column
2753             STArray<4, SpvId> values;
2754             for (int j = 0; j < dstType.rows(); ++j) {
2755                 values.push_back((i == j) ? oneId : zeroId);
2756             }
2757             columns.push_back(this->writeOpCompositeConstruct(dstColumnType, values, out));
2758         }
2759     }
2760 
2761     return this->writeOpCompositeConstruct(dstType, columns, out);
2762 }
2763 
addColumnEntry(const Type & columnType,TArray<SpvId> * currentColumn,TArray<SpvId> * columnIds,int rows,SpvId entry,OutputStream & out)2764 void SPIRVCodeGenerator::addColumnEntry(const Type& columnType,
2765                                         TArray<SpvId>* currentColumn,
2766                                         TArray<SpvId>* columnIds,
2767                                         int rows,
2768                                         SpvId entry,
2769                                         OutputStream& out) {
2770     SkASSERT(currentColumn->size() < rows);
2771     currentColumn->push_back(entry);
2772     if (currentColumn->size() == rows) {
2773         // Synthesize this column into a vector.
2774         SpvId columnId = this->writeOpCompositeConstruct(columnType, *currentColumn, out);
2775         columnIds->push_back(columnId);
2776         currentColumn->clear();
2777     }
2778 }
2779 
writeMatrixConstructor(const ConstructorCompound & c,OutputStream & out)2780 SpvId SPIRVCodeGenerator::writeMatrixConstructor(const ConstructorCompound& c, OutputStream& out) {
2781     const Type& type = c.type();
2782     SkASSERT(type.isMatrix());
2783     SkASSERT(!c.arguments().empty());
2784     const Type& arg0Type = c.arguments()[0]->type();
2785     // go ahead and write the arguments so we don't try to write new instructions in the middle of
2786     // an instruction
2787     STArray<16, SpvId> arguments;
2788     for (const std::unique_ptr<Expression>& arg : c.arguments()) {
2789         arguments.push_back(this->writeExpression(*arg, out));
2790     }
2791 
2792     if (arguments.size() == 1 && arg0Type.isVector()) {
2793         // Special-case handling of float4 -> mat2x2.
2794         SkASSERT(type.rows() == 2 && type.columns() == 2);
2795         SkASSERT(arg0Type.columns() == 4);
2796         SpvId v[4];
2797         for (int i = 0; i < 4; ++i) {
2798             v[i] = this->writeOpCompositeExtract(type.componentType(), arguments[0], i, out);
2799         }
2800         const Type& vecType = type.columnType(fContext);
2801         SpvId v0v1 = this->writeOpCompositeConstruct(vecType, {v[0], v[1]}, out);
2802         SpvId v2v3 = this->writeOpCompositeConstruct(vecType, {v[2], v[3]}, out);
2803         return this->writeOpCompositeConstruct(type, {v0v1, v2v3}, out);
2804     }
2805 
2806     int rows = type.rows();
2807     const Type& columnType = type.columnType(fContext);
2808     // SpvIds of completed columns of the matrix.
2809     STArray<4, SpvId> columnIds;
2810     // SpvIds of scalars we have written to the current column so far.
2811     STArray<4, SpvId> currentColumn;
2812     for (int i = 0; i < arguments.size(); i++) {
2813         const Type& argType = c.arguments()[i]->type();
2814         if (currentColumn.empty() && argType.isVector() && argType.columns() == rows) {
2815             // This vector is a complete matrix column by itself and can be used as-is.
2816             columnIds.push_back(arguments[i]);
2817         } else if (argType.columns() == 1) {
2818             // This argument is a lone scalar and can be added to the current column as-is.
2819             this->addColumnEntry(columnType, &currentColumn, &columnIds, rows, arguments[i], out);
2820         } else {
2821             // This argument needs to be decomposed into its constituent scalars.
2822             for (int j = 0; j < argType.columns(); ++j) {
2823                 SpvId swizzle = this->writeOpCompositeExtract(argType.componentType(),
2824                                                               arguments[i], j, out);
2825                 this->addColumnEntry(columnType, &currentColumn, &columnIds, rows, swizzle, out);
2826             }
2827         }
2828     }
2829     SkASSERT(columnIds.size() == type.columns());
2830     return this->writeOpCompositeConstruct(type, columnIds, out);
2831 }
2832 
writeConstructorCompound(const ConstructorCompound & c,OutputStream & out)2833 SpvId SPIRVCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
2834                                                    OutputStream& out) {
2835     return c.type().isMatrix() ? this->writeMatrixConstructor(c, out)
2836                                : this->writeVectorConstructor(c, out);
2837 }
2838 
writeVectorConstructor(const ConstructorCompound & c,OutputStream & out)2839 SpvId SPIRVCodeGenerator::writeVectorConstructor(const ConstructorCompound& c, OutputStream& out) {
2840     const Type& type = c.type();
2841     const Type& componentType = type.componentType();
2842     SkASSERT(type.isVector());
2843 
2844     STArray<4, SpvId> arguments;
2845     for (int i = 0; i < c.arguments().size(); i++) {
2846         const Type& argType = c.arguments()[i]->type();
2847         SkASSERT(componentType.numberKind() == argType.componentType().numberKind());
2848 
2849         SpvId arg = this->writeExpression(*c.arguments()[i], out);
2850         if (argType.isMatrix()) {
2851             // CompositeConstruct cannot take a 2x2 matrix as an input, so we need to extract out
2852             // each scalar separately.
2853             SkASSERT(argType.rows() == 2);
2854             SkASSERT(argType.columns() == 2);
2855             for (int j = 0; j < 4; ++j) {
2856                 arguments.push_back(this->writeOpCompositeExtract(componentType, arg,
2857                                                                   j / 2, j % 2, out));
2858             }
2859         } else if (argType.isVector()) {
2860             // There's a bug in the Intel Vulkan driver where OpCompositeConstruct doesn't handle
2861             // vector arguments at all, so we always extract each vector component and pass them
2862             // into OpCompositeConstruct individually.
2863             for (int j = 0; j < argType.columns(); j++) {
2864                 arguments.push_back(this->writeOpCompositeExtract(componentType, arg, j, out));
2865             }
2866         } else {
2867             arguments.push_back(arg);
2868         }
2869     }
2870 
2871     return this->writeOpCompositeConstruct(type, arguments, out);
2872 }
2873 
writeConstructorSplat(const ConstructorSplat & c,OutputStream & out)2874 SpvId SPIRVCodeGenerator::writeConstructorSplat(const ConstructorSplat& c, OutputStream& out) {
2875     // Write the splat argument as a scalar, then splat it.
2876     SpvId argument = this->writeExpression(*c.argument(), out);
2877     return this->splat(c.type(), argument, out);
2878 }
2879 
writeCompositeConstructor(const AnyConstructor & c,OutputStream & out)2880 SpvId SPIRVCodeGenerator::writeCompositeConstructor(const AnyConstructor& c, OutputStream& out) {
2881     SkASSERT(c.type().isArray() || c.type().isStruct());
2882     auto ctorArgs = c.argumentSpan();
2883 
2884     STArray<4, SpvId> arguments;
2885     for (const std::unique_ptr<Expression>& arg : ctorArgs) {
2886         arguments.push_back(this->writeExpression(*arg, out));
2887     }
2888 
2889     return this->writeOpCompositeConstruct(c.type(), arguments, out);
2890 }
2891 
writeConstructorScalarCast(const ConstructorScalarCast & c,OutputStream & out)2892 SpvId SPIRVCodeGenerator::writeConstructorScalarCast(const ConstructorScalarCast& c,
2893                                                      OutputStream& out) {
2894     const Type& type = c.type();
2895     if (type.componentType().numberKind() == c.argument()->type().componentType().numberKind()) {
2896         return this->writeExpression(*c.argument(), out);
2897     }
2898 
2899     const Expression& ctorExpr = *c.argument();
2900     SpvId expressionId = this->writeExpression(ctorExpr, out);
2901     return this->castScalarToType(expressionId, ctorExpr.type(), type, out);
2902 }
2903 
writeConstructorCompoundCast(const ConstructorCompoundCast & c,OutputStream & out)2904 SpvId SPIRVCodeGenerator::writeConstructorCompoundCast(const ConstructorCompoundCast& c,
2905                                                        OutputStream& out) {
2906     const Type& ctorType = c.type();
2907     const Type& argType = c.argument()->type();
2908     SkASSERT(ctorType.isVector() || ctorType.isMatrix());
2909 
2910     // Write the composite that we are casting. If the actual type matches, we are done.
2911     SpvId compositeId = this->writeExpression(*c.argument(), out);
2912     if (ctorType.componentType().numberKind() == argType.componentType().numberKind()) {
2913         return compositeId;
2914     }
2915 
2916     // writeMatrixCopy can cast matrices to a different type.
2917     if (ctorType.isMatrix()) {
2918         return this->writeMatrixCopy(compositeId, argType, ctorType, out);
2919     }
2920 
2921     // SPIR-V doesn't support vector(vector-of-different-type) directly, so we need to extract the
2922     // components and convert each one manually.
2923     const Type& srcType = argType.componentType();
2924     const Type& dstType = ctorType.componentType();
2925 
2926     STArray<4, SpvId> arguments;
2927     for (int index = 0; index < argType.columns(); ++index) {
2928         SpvId componentId = this->writeOpCompositeExtract(srcType, compositeId, index, out);
2929         arguments.push_back(this->castScalarToType(componentId, srcType, dstType, out));
2930     }
2931 
2932     return this->writeOpCompositeConstruct(ctorType, arguments, out);
2933 }
2934 
writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix & c,OutputStream & out)2935 SpvId SPIRVCodeGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
2936                                                          OutputStream& out) {
2937     const Type& type = c.type();
2938     SkASSERT(type.isMatrix());
2939     SkASSERT(c.argument()->type().isScalar());
2940 
2941     // Write out the scalar argument.
2942     SpvId diagonal = this->writeExpression(*c.argument(), out);
2943 
2944     // Build the diagonal matrix.
2945     SpvId zeroId = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
2946 
2947     const Type& vecType = type.columnType(fContext);
2948     STArray<4, SpvId> columnIds;
2949     STArray<4, SpvId> arguments;
2950     arguments.resize(type.rows());
2951     for (int column = 0; column < type.columns(); column++) {
2952         for (int row = 0; row < type.rows(); row++) {
2953             arguments[row] = (row == column) ? diagonal : zeroId;
2954         }
2955         columnIds.push_back(this->writeOpCompositeConstruct(vecType, arguments, out));
2956     }
2957     return this->writeOpCompositeConstruct(type, columnIds, out);
2958 }
2959 
writeConstructorMatrixResize(const ConstructorMatrixResize & c,OutputStream & out)2960 SpvId SPIRVCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c,
2961                                                        OutputStream& out) {
2962     // Write the input matrix.
2963     SpvId argument = this->writeExpression(*c.argument(), out);
2964 
2965     // Use matrix-copy to resize the input matrix to its new size.
2966     return this->writeMatrixCopy(argument, c.argument()->type(), c.type(), out);
2967 }
2968 
get_storage_class_for_global_variable(const Variable & var,SpvStorageClass_ fallbackStorageClass)2969 static SpvStorageClass_ get_storage_class_for_global_variable(
2970         const Variable& var, SpvStorageClass_ fallbackStorageClass) {
2971     SkASSERT(var.storage() == Variable::Storage::kGlobal);
2972 
2973     if (var.type().typeKind() == Type::TypeKind::kSampler ||
2974         var.type().typeKind() == Type::TypeKind::kSeparateSampler ||
2975         var.type().typeKind() == Type::TypeKind::kTexture) {
2976         return SpvStorageClassUniformConstant;
2977     }
2978 
2979     const Layout& layout = var.layout();
2980     ModifierFlags flags = var.modifierFlags();
2981     if (flags & ModifierFlag::kIn) {
2982         SkASSERT(!(layout.fFlags & LayoutFlag::kPushConstant));
2983         return SpvStorageClassInput;
2984     }
2985     if (flags & ModifierFlag::kOut) {
2986         SkASSERT(!(layout.fFlags & LayoutFlag::kPushConstant));
2987         return SpvStorageClassOutput;
2988     }
2989     if (flags.isUniform()) {
2990         if (layout.fFlags & LayoutFlag::kPushConstant) {
2991             return SpvStorageClassPushConstant;
2992         }
2993         return SpvStorageClassUniform;
2994     }
2995     if (flags.isBuffer()) {
2996         // Note: In SPIR-V 1.3, a storage buffer can be declared with the "StorageBuffer"
2997         // storage class and the "Block" decoration and the <1.3 approach we use here ("Uniform"
2998         // storage class and the "BufferBlock" decoration) is deprecated. Since we target SPIR-V
2999         // 1.0, we have to use the deprecated approach which is well supported in Vulkan and
3000         // addresses SkSL use cases (notably SkSL currently doesn't support pointer features that
3001         // would benefit from SPV_KHR_variable_pointers capabilities).
3002         return SpvStorageClassUniform;
3003     }
3004     if (flags.isWorkgroup()) {
3005         return SpvStorageClassWorkgroup;
3006     }
3007     return fallbackStorageClass;
3008 }
3009 
get_storage_class(const Expression & expr)3010 static SpvStorageClass_ get_storage_class(const Expression& expr) {
3011     switch (expr.kind()) {
3012         case Expression::Kind::kVariableReference: {
3013             const Variable& var = *expr.as<VariableReference>().variable();
3014             if (var.storage() != Variable::Storage::kGlobal) {
3015                 return SpvStorageClassFunction;
3016             }
3017             return get_storage_class_for_global_variable(var, SpvStorageClassPrivate);
3018         }
3019         case Expression::Kind::kFieldAccess:
3020             return get_storage_class(*expr.as<FieldAccess>().base());
3021         case Expression::Kind::kIndex:
3022             return get_storage_class(*expr.as<IndexExpression>().base());
3023         default:
3024             return SpvStorageClassFunction;
3025     }
3026 }
3027 
getAccessChain(const Expression & expr,OutputStream & out)3028 TArray<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) {
3029     switch (expr.kind()) {
3030         case Expression::Kind::kIndex: {
3031             const IndexExpression& indexExpr = expr.as<IndexExpression>();
3032             if (indexExpr.base()->is<Swizzle>()) {
3033                 // Access chains don't directly support dynamically indexing into a swizzle, but we
3034                 // can rewrite them into a supported form.
3035                 return this->getAccessChain(*Transform::RewriteIndexedSwizzle(fContext, indexExpr),
3036                                             out);
3037             }
3038             // All other index-expressions can be represented as typical access chains.
3039             TArray<SpvId> chain = this->getAccessChain(*indexExpr.base(), out);
3040             chain.push_back(this->writeExpression(*indexExpr.index(), out));
3041             return chain;
3042         }
3043         case Expression::Kind::kFieldAccess: {
3044             const FieldAccess& fieldExpr = expr.as<FieldAccess>();
3045             TArray<SpvId> chain = this->getAccessChain(*fieldExpr.base(), out);
3046             chain.push_back(this->writeLiteral(fieldExpr.fieldIndex(), *fContext.fTypes.fInt));
3047             return chain;
3048         }
3049         default: {
3050             SpvId id = this->getLValue(expr, out)->getPointer();
3051             SkASSERT(id != NA);
3052             return TArray<SpvId>{id};
3053         }
3054     }
3055     SkUNREACHABLE;
3056 }
3057 
3058 class PointerLValue : public SPIRVCodeGenerator::LValue {
3059 public:
PointerLValue(SPIRVCodeGenerator & gen,SpvId pointer,bool isMemoryObject,SpvId type,SPIRVCodeGenerator::Precision precision,SpvStorageClass_ storageClass)3060     PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, bool isMemoryObject, SpvId type,
3061                   SPIRVCodeGenerator::Precision precision, SpvStorageClass_ storageClass)
3062     : fGen(gen)
3063     , fPointer(pointer)
3064     , fIsMemoryObject(isMemoryObject)
3065     , fType(type)
3066     , fPrecision(precision)
3067     , fStorageClass(storageClass) {}
3068 
getPointer()3069     SpvId getPointer() override {
3070         return fPointer;
3071     }
3072 
isMemoryObjectPointer() const3073     bool isMemoryObjectPointer() const override {
3074         return fIsMemoryObject;
3075     }
3076 
storageClass() const3077     SpvStorageClass storageClass() const override {
3078         return fStorageClass;
3079     }
3080 
load(OutputStream & out)3081     SpvId load(OutputStream& out) override {
3082         return fGen.writeOpLoad(fType, fPrecision, fPointer, out);
3083     }
3084 
store(SpvId value,OutputStream & out)3085     void store(SpvId value, OutputStream& out) override {
3086         if (!fIsMemoryObject) {
3087             // We are going to write into an access chain; this could represent one component of a
3088             // vector, or one element of an array. This has the potential to invalidate other,
3089             // *unknown* elements of our store cache. (e.g. if the store cache holds `%50 = myVec4`,
3090             // and we store `%60 = myVec4.z`, this invalidates the cached value for %50.) To avoid
3091             // relying on stale data, reset the store cache entirely when this happens.
3092             fGen.fStoreCache.reset();
3093         }
3094 
3095         fGen.writeOpStore(fStorageClass, fPointer, value, out);
3096     }
3097 
3098 private:
3099     SPIRVCodeGenerator& fGen;
3100     const SpvId fPointer;
3101     const bool fIsMemoryObject;
3102     const SpvId fType;
3103     const SPIRVCodeGenerator::Precision fPrecision;
3104     const SpvStorageClass_ fStorageClass;
3105 };
3106 
3107 class SwizzleLValue : public SPIRVCodeGenerator::LValue {
3108 public:
SwizzleLValue(SPIRVCodeGenerator & gen,SpvId vecPointer,const ComponentArray & components,const Type & baseType,const Type & swizzleType,SpvStorageClass_ storageClass)3109     SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const ComponentArray& components,
3110                   const Type& baseType, const Type& swizzleType, SpvStorageClass_ storageClass)
3111     : fGen(gen)
3112     , fVecPointer(vecPointer)
3113     , fComponents(components)
3114     , fBaseType(&baseType)
3115     , fSwizzleType(&swizzleType)
3116     , fStorageClass(storageClass) {}
3117 
applySwizzle(const ComponentArray & components,const Type & newType)3118     bool applySwizzle(const ComponentArray& components, const Type& newType) override {
3119         ComponentArray updatedSwizzle;
3120         for (int8_t component : components) {
3121             if (component < 0 || component >= fComponents.size()) {
3122                 SkDEBUGFAILF("swizzle accessed nonexistent component %d", (int)component);
3123                 return false;
3124             }
3125             updatedSwizzle.push_back(fComponents[component]);
3126         }
3127         fComponents = updatedSwizzle;
3128         fSwizzleType = &newType;
3129         return true;
3130     }
3131 
storageClass() const3132     SpvStorageClass storageClass() const override {
3133         return fStorageClass;
3134     }
3135 
load(OutputStream & out)3136     SpvId load(OutputStream& out) override {
3137         SpvId base = fGen.nextId(fBaseType);
3138         fGen.writeInstruction(SpvOpLoad, fGen.getType(*fBaseType), base, fVecPointer, out);
3139         SpvId result = fGen.nextId(fBaseType);
3140         fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
3141         fGen.writeWord(fGen.getType(*fSwizzleType), out);
3142         fGen.writeWord(result, out);
3143         fGen.writeWord(base, out);
3144         fGen.writeWord(base, out);
3145         for (int component : fComponents) {
3146             fGen.writeWord(component, out);
3147         }
3148         return result;
3149     }
3150 
store(SpvId value,OutputStream & out)3151     void store(SpvId value, OutputStream& out) override {
3152         // use OpVectorShuffle to mix and match the vector components. We effectively create
3153         // a virtual vector out of the concatenation of the left and right vectors, and then
3154         // select components from this virtual vector to make the result vector. For
3155         // instance, given:
3156         // float3L = ...;
3157         // float3R = ...;
3158         // L.xz = R.xy;
3159         // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
3160         // our result vector to look like (R.x, L.y, R.y), so we need to select indices
3161         // (3, 1, 4).
3162         SpvId base = fGen.nextId(fBaseType);
3163         fGen.writeInstruction(SpvOpLoad, fGen.getType(*fBaseType), base, fVecPointer, out);
3164         SpvId shuffle = fGen.nextId(fBaseType);
3165         fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType->columns(), out);
3166         fGen.writeWord(fGen.getType(*fBaseType), out);
3167         fGen.writeWord(shuffle, out);
3168         fGen.writeWord(base, out);
3169         fGen.writeWord(value, out);
3170         for (int i = 0; i < fBaseType->columns(); i++) {
3171             // current offset into the virtual vector, defaults to pulling the unmodified
3172             // value from the left side
3173             int offset = i;
3174             // check to see if we are writing this component
3175             for (int j = 0; j < fComponents.size(); j++) {
3176                 if (fComponents[j] == i) {
3177                     // we're writing to this component, so adjust the offset to pull from
3178                     // the correct component of the right side instead of preserving the
3179                     // value from the left
3180                     offset = (int) (j + fBaseType->columns());
3181                     break;
3182                 }
3183             }
3184             fGen.writeWord(offset, out);
3185         }
3186         fGen.writeOpStore(fStorageClass, fVecPointer, shuffle, out);
3187     }
3188 
3189 private:
3190     SPIRVCodeGenerator& fGen;
3191     const SpvId fVecPointer;
3192     ComponentArray fComponents;
3193     const Type* fBaseType;
3194     const Type* fSwizzleType;
3195     const SpvStorageClass_ fStorageClass;
3196 };
3197 
findUniformFieldIndex(const Variable & var) const3198 int SPIRVCodeGenerator::findUniformFieldIndex(const Variable& var) const {
3199     int* fieldIndex = fTopLevelUniformMap.find(&var);
3200     return fieldIndex ? *fieldIndex : -1;
3201 }
3202 
// Returns an LValue for `expr` that can be loaded from or stored to. Any instructions needed to
// form the pointer (access chains, temporary initialization) are written to `out`.
//  - Variable references become PointerLValues. Top-level uniforms and the sample-mask builtins
//    are reached through an OpAccessChain; other variables use their fVariableMap entry directly.
//  - Index and field accesses become an OpAccessChain built from getAccessChain().
//  - Swizzles first try to fold into the base LValue via applySwizzle(). Otherwise a
//    single-component swizzle becomes an OpAccessChain and a multi-component one a SwizzleLValue.
//  - Any other expression is evaluated and stored into a fresh Function-storage OpVariable.
std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
                                                                          OutputStream& out) {
    const Type& type = expr.type();
    // Values accessed through the LValue use relaxed precision unless the type is high-precision.
    Precision precision = type.highPrecision() ? Precision::kDefault : Precision::kRelaxed;
    switch (expr.kind()) {
        case Expression::Kind::kVariableReference: {
            const Variable& var = *expr.as<VariableReference>().variable();
            int uniformIdx = this->findUniformFieldIndex(var);
            if (uniformIdx >= 0) {
                // Access uniforms via an AccessChain into the uniform-buffer struct.
                SpvId memberId = this->nextId(nullptr);
                SpvId typeId = this->getPointerType(type, SpvStorageClassUniform);
                SpvId uniformIdxId = this->writeLiteral((double)uniformIdx, *fContext.fTypes.fInt);
                this->writeInstruction(SpvOpAccessChain, typeId, memberId, fUniformBufferId,
                                       uniformIdxId, out);
                return std::make_unique<PointerLValue>(
                        *this,
                        memberId,
                        /*isMemoryObjectPointer=*/true,
                        this->getType(type, kDefaultTypeLayout, this->memoryLayoutForVariable(var)),
                        precision,
                        SpvStorageClassUniform);
            }

            // NOTE(review): only asserted in debug builds; a missing entry would be dereferenced
            // below in release builds.
            SpvId* entry = fVariableMap.find(&var);
            SkASSERTF(entry, "%s", expr.description().c_str());

            if (var.layout().fBuiltin == SK_SAMPLEMASKIN_BUILTIN ||
                var.layout().fBuiltin == SK_SAMPLEMASK_BUILTIN) {
                // Access sk_SampleMask and sk_SampleMaskIn via an array access, since Vulkan
                // represents sample masks as an array of uints.
                SpvStorageClass_ storageClass =
                        get_storage_class_for_global_variable(var, SpvStorageClassPrivate);
                SkASSERT(storageClass != SpvStorageClassPrivate);
                SkASSERT(type.matches(*fContext.fTypes.fUInt));

                // Point at element 0 of the sample-mask array.
                SpvId accessId = this->nextId(nullptr);
                SpvId typeId = this->getPointerType(type, storageClass);
                SpvId indexId = this->writeLiteral(0.0, *fContext.fTypes.fInt);
                this->writeInstruction(SpvOpAccessChain, typeId, accessId, *entry, indexId, out);
                return std::make_unique<PointerLValue>(*this,
                                                       accessId,
                                                       /*isMemoryObjectPointer=*/true,
                                                       this->getType(type),
                                                       precision,
                                                       storageClass);
            }

            // Ordinary variable: its mapped id is already a pointer to the variable.
            SpvId typeId = this->getType(type, var.layout(), this->memoryLayoutForVariable(var));
            return std::make_unique<PointerLValue>(*this, *entry,
                                                   /*isMemoryObjectPointer=*/true,
                                                   typeId, precision, get_storage_class(expr));
        }
        case Expression::Kind::kIndex: // fall through
        case Expression::Kind::kFieldAccess: {
            // Emit `OpAccessChain <ptr type> member <base> <indices...>`, where the base pointer
            // and index ids come from getAccessChain().
            TArray<SpvId> chain = this->getAccessChain(expr, out);
            SpvId member = this->nextId(nullptr);
            SpvStorageClass_ storageClass = get_storage_class(expr);
            this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
            this->writeWord(this->getPointerType(type, storageClass), out);
            this->writeWord(member, out);
            for (SpvId idx : chain) {
                this->writeWord(idx, out);
            }
            return std::make_unique<PointerLValue>(
                    *this,
                    member,
                    /*isMemoryObjectPointer=*/false,
                    this->getType(type,
                                  kDefaultTypeLayout,
                                  this->memoryLayoutForStorageClass(storageClass)),
                    precision,
                    storageClass);
        }
        case Expression::Kind::kSwizzle: {
            const Swizzle& swizzle = expr.as<Swizzle>();
            std::unique_ptr<LValue> lvalue = this->getLValue(*swizzle.base(), out);
            // If the base LValue can absorb the swizzle itself, reuse it directly.
            if (lvalue->applySwizzle(swizzle.components(), type)) {
                return lvalue;
            }
            SpvId base = lvalue->getPointer();
            if (base == NA) {
                // NOTE(review): an error is reported, but codegen continues below using NA as
                // the base pointer; presumably the output is discarded because of the error.
                fContext.fErrors->error(swizzle.fPosition,
                        "unable to retrieve lvalue from swizzle");
            }
            SpvStorageClass_ storageClass = get_storage_class(*swizzle.base());
            if (swizzle.components().size() == 1) {
                // A single component can be addressed directly with an access chain.
                SpvId member = this->nextId(nullptr);
                SpvId typeId = this->getPointerType(type, storageClass);
                SpvId indexId = this->writeLiteral(swizzle.components()[0], *fContext.fTypes.fInt);
                this->writeInstruction(SpvOpAccessChain, typeId, member, base, indexId, out);
                return std::make_unique<PointerLValue>(*this, member,
                                                       /*isMemoryObjectPointer=*/false,
                                                       this->getType(type),
                                                       precision, storageClass);
            } else {
                // Multiple components need SwizzleLValue's shuffle-based load/store.
                return std::make_unique<SwizzleLValue>(*this, base, swizzle.components(),
                                                       swizzle.base()->type(), type, storageClass);
            }
        }
        default: {
            // expr isn't actually an lvalue, create a placeholder variable for it. This case
            // happens due to the need to store values in temporary variables during function
            // calls (see comments in getFunctionParameterType); erroneous uses of rvalues as
            // lvalues should have been caught before code generation.
            //
            // This is with the exception of opaque handle types (textures/samplers) which are
            // always defined as UniformConstant pointers and don't need to be explicitly stored
            // into a temporary (which is handled explicitly in writeFunctionCallArgument).
            //
            // The OpVariable is written to fVariableBuffer; the initializing store goes to `out`.
            SpvId result = this->nextId(nullptr);
            SpvId pointerType = this->getPointerType(type, SpvStorageClassFunction);
            this->writeInstruction(SpvOpVariable, pointerType, result, SpvStorageClassFunction,
                                   fVariableBuffer);
            this->writeOpStore(SpvStorageClassFunction, result, this->writeExpression(expr, out),
                               out);
            return std::make_unique<PointerLValue>(*this, result, /*isMemoryObjectPointer=*/true,
                                                   this->getType(type), precision,
                                                   SpvStorageClassFunction);
        }
    }
}
3324 
identifier(std::string_view name)3325 std::unique_ptr<Expression> SPIRVCodeGenerator::identifier(std::string_view name) {
3326     std::unique_ptr<Expression> expr =
3327             fProgram.fSymbols->instantiateSymbolRef(fContext, name, Position());
3328     return expr ? std::move(expr)
3329                 : Poison::Make(Position(), fContext);
3330 }
3331 
// Emits the instructions that read the value referenced by `ref` and returns the result id.
// Several builtins get special handling:
//  - DEVICE_FRAGCOORDS_BUILTIN / DEVICE_CLOCKWISE_BUILTIN are internal placeholder variables
//    that read sk_FragCoord / sk_Clockwise directly, with no RTFlip adjustment.
//  - sk_SecondaryFragColor is rejected with an error (returns NA).
//  - sk_FragCoord and sk_Clockwise are rewritten in terms of the RTFlip uniform unless
//    fForceNoRTFlip is set.
// Any other variable is constant-folded when possible. A sampler backed by a synthesized
// texture/sampler pair is combined with OpSampledImage. Everything else is loaded via getLValue().
SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) {
    const Variable* variable = ref.variable();
    switch (variable->layout().fBuiltin) {
        case DEVICE_FRAGCOORDS_BUILTIN: {
            // Down below, we rewrite raw references to sk_FragCoord with expressions that reference
            // DEVICE_FRAGCOORDS_BUILTIN. This is a fake variable that means we need to directly
            // access the fragcoord; do so now.
            return this->getLValue(*this->identifier("sk_FragCoord"), out)->load(out);
        }
        case DEVICE_CLOCKWISE_BUILTIN: {
            // Down below, we rewrite raw references to sk_Clockwise with expressions that reference
            // DEVICE_CLOCKWISE_BUILTIN. This is a fake variable that means we need to directly
            // access front facing; do so now.
            return this->getLValue(*this->identifier("sk_Clockwise"), out)->load(out);
        }
        case SK_SECONDARYFRAGCOLOR_BUILTIN: {
            // sk_SecondaryFragColor corresponds to gl_SecondaryFragColorEXT, which isn't supposed
            // to appear in a SPIR-V program (it's only valid in ES2). Report an error.
            fContext.fErrors->error(ref.fPosition,
                                    "sk_SecondaryFragColor is not allowed in SPIR-V");
            return NA;
        }
        case SK_FRAGCOORD_BUILTIN: {
            // With RTFlip disabled, read sk_FragCoord as-is. getLValue() loads the variable
            // directly, so this does not re-enter writeVariableReference.
            if (fProgram.fConfig->fSettings.fForceNoRTFlip) {
                return this->getLValue(*this->identifier("sk_FragCoord"), out)->load(out);
            }

            // Handle inserting use of uniform to flip y when referencing sk_FragCoord.
            this->addRTFlipUniform(ref.fPosition);
            // The flip is computed from the RTFlip uniform (SKSL_RTFLIP_NAME), not sk_RTAdjust.
            // Use a uniform to flip the Y coordinate. The new expression will be written in
            // terms of $device_FragCoords, which is a fake variable that means "access the
            // underlying fragcoords directly without flipping it".
            static constexpr char DEVICE_COORDS_NAME[] = "$device_FragCoords";
            // Lazily declare the placeholder variable the first time it is needed.
            if (!fProgram.fSymbols->find(DEVICE_COORDS_NAME)) {
                AutoAttachPoolToThread attach(fProgram.fPool.get());
                Layout layout;
                layout.fBuiltin = DEVICE_FRAGCOORDS_BUILTIN;
                auto coordsVar = Variable::Make(/*pos=*/Position(),
                                                /*modifiersPosition=*/Position(),
                                                layout,
                                                ModifierFlag::kNone,
                                                fContext.fTypes.fFloat4.get(),
                                                DEVICE_COORDS_NAME,
                                                /*mangledName=*/"",
                                                /*builtin=*/true,
                                                Variable::Storage::kGlobal);
                fProgram.fSymbols->add(fContext, std::move(coordsVar));
            }
            std::unique_ptr<Expression> deviceCoord = this->identifier(DEVICE_COORDS_NAME);
            std::unique_ptr<Expression> rtFlip = this->identifier(SKSL_RTFLIP_NAME);
            SpvId rtFlipX = this->writeSwizzle(*rtFlip, {SwizzleComponent::X}, out);
            SpvId rtFlipY = this->writeSwizzle(*rtFlip, {SwizzleComponent::Y}, out);
            SpvId deviceCoordX  = this->writeSwizzle(*deviceCoord, {SwizzleComponent::X}, out);
            SpvId deviceCoordY  = this->writeSwizzle(*deviceCoord, {SwizzleComponent::Y}, out);
            SpvId deviceCoordZW = this->writeSwizzle(*deviceCoord, {SwizzleComponent::Z,
                                                                    SwizzleComponent::W}, out);
            // Compute `flippedY = u_RTFlip.y * $device_FragCoords.y`.
            SpvId flippedY = this->writeBinaryExpression(
                                     *fContext.fTypes.fFloat, rtFlipY, OperatorKind::STAR,
                                     *fContext.fTypes.fFloat, deviceCoordY,
                                     *fContext.fTypes.fFloat, out);

            // Compute `flippedY = u_RTFlip.x + flippedY`.
            flippedY = this->writeBinaryExpression(
                               *fContext.fTypes.fFloat, rtFlipX, OperatorKind::PLUS,
                               *fContext.fTypes.fFloat, flippedY,
                               *fContext.fTypes.fFloat, out);

            // Return `float4(deviceCoord.x, flippedY, deviceCoord.zw)`.
            return this->writeOpCompositeConstruct(*fContext.fTypes.fFloat4,
                                                   {deviceCoordX, flippedY, deviceCoordZW},
                                                   out);
        }
        case SK_CLOCKWISE_BUILTIN: {
            // With RTFlip disabled, read sk_Clockwise as-is.
            if (fProgram.fConfig->fSettings.fForceNoRTFlip) {
                return this->getLValue(*this->identifier("sk_Clockwise"), out)->load(out);
            }

            // Apply RTFlip to sk_Clockwise.
            this->addRTFlipUniform(ref.fPosition);
            // Use a uniform to flip the Y coordinate. The new expression will be written in
            // terms of $device_Clockwise, which is a fake variable that means "access the
            // underlying FrontFacing directly".
            static constexpr char DEVICE_CLOCKWISE_NAME[] = "$device_Clockwise";
            // Lazily declare the placeholder variable the first time it is needed.
            if (!fProgram.fSymbols->find(DEVICE_CLOCKWISE_NAME)) {
                AutoAttachPoolToThread attach(fProgram.fPool.get());
                Layout layout;
                layout.fBuiltin = DEVICE_CLOCKWISE_BUILTIN;
                auto clockwiseVar = Variable::Make(/*pos=*/Position(),
                                                   /*modifiersPosition=*/Position(),
                                                   layout,
                                                   ModifierFlag::kNone,
                                                   fContext.fTypes.fBool.get(),
                                                   DEVICE_CLOCKWISE_NAME,
                                                   /*mangledName=*/"",
                                                   /*builtin=*/true,
                                                   Variable::Storage::kGlobal);
                fProgram.fSymbols->add(fContext, std::move(clockwiseVar));
            }
            // FrontFacing in Vulkan is defined in terms of a top-down render target. In Skia,
            // we use the default convention of "counter-clockwise face is front".

            // Compute `positiveRTFlip = (rtFlip.y > 0)`.
            std::unique_ptr<Expression> rtFlip = this->identifier(SKSL_RTFLIP_NAME);
            SpvId rtFlipY = this->writeSwizzle(*rtFlip, {SwizzleComponent::Y}, out);
            SpvId zero = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
            SpvId positiveRTFlip = this->writeBinaryExpression(
                                           *fContext.fTypes.fFloat, rtFlipY, OperatorKind::GT,
                                           *fContext.fTypes.fFloat, zero,
                                           *fContext.fTypes.fBool, out);

            // Compute `positiveRTFlip ^^ $device_Clockwise`.
            std::unique_ptr<Expression> deviceClockwise = this->identifier(DEVICE_CLOCKWISE_NAME);
            SpvId deviceClockwiseID = this->writeExpression(*deviceClockwise, out);
            return this->writeBinaryExpression(
                           *fContext.fTypes.fBool, positiveRTFlip, OperatorKind::LOGICALXOR,
                           *fContext.fTypes.fBool, deviceClockwiseID,
                           *fContext.fTypes.fBool, out);
        }
        default: {
            // Constant-propagate variables that have a known compile-time value.
            if (const Expression* expr = ConstantFolder::GetConstantValueOrNull(ref)) {
                return this->writeExpression(*expr, out);
            }

            // A reference to a sampler variable at global scope with synthesized texture/sampler
            // backing should construct a function-scope combined image-sampler from the synthesized
            // constituents. This is the case in which a sample intrinsic was invoked.
            //
            // Variable references to opaque handles (texture/sampler) that appear as the argument
            // of a user-defined function call are explicitly handled in writeFunctionCallArgument.
            if (fUseTextureSamplerPairs && variable->type().isSampler()) {
                if (const auto* p = fSynthesizedSamplerMap.find(variable)) {
                    SpvId* imgPtr = fVariableMap.find((*p)->fTexture.get());
                    SpvId* samplerPtr = fVariableMap.find((*p)->fSampler.get());
                    SkASSERT(imgPtr);
                    SkASSERT(samplerPtr);

                    SpvId img = this->writeOpLoad(this->getType((*p)->fTexture->type()),
                                                  Precision::kDefault, *imgPtr, out);
                    SpvId sampler = this->writeOpLoad(this->getType((*p)->fSampler->type()),
                                                      Precision::kDefault,
                                                      *samplerPtr,
                                                      out);
                    SpvId result = this->nextId(nullptr);
                    this->writeInstruction(SpvOpSampledImage,
                                           this->getType(variable->type()),
                                           result,
                                           img,
                                           sampler,
                                           out);
                    return result;
                }
                // Debug-only failure; release builds fall through to the plain load below.
                SkDEBUGFAIL("sampler missing from fSynthesizedSamplerMap");
            }
            return this->getLValue(ref, out)->load(out);
        }
    }
}
3492 
writeIndexExpression(const IndexExpression & expr,OutputStream & out)3493 SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) {
3494     if (expr.base()->type().isVector()) {
3495         SpvId base = this->writeExpression(*expr.base(), out);
3496         SpvId index = this->writeExpression(*expr.index(), out);
3497         SpvId result = this->nextId(nullptr);
3498         this->writeInstruction(SpvOpVectorExtractDynamic, this->getType(expr.type()), result, base,
3499                                index, out);
3500         return result;
3501     }
3502     return getLValue(expr, out)->load(out);
3503 }
3504 
writeFieldAccess(const FieldAccess & f,OutputStream & out)3505 SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) {
3506     return getLValue(f, out)->load(out);
3507 }
3508 
writeSwizzle(const Expression & baseExpr,const ComponentArray & components,OutputStream & out)3509 SpvId SPIRVCodeGenerator::writeSwizzle(const Expression& baseExpr,
3510                                        const ComponentArray& components,
3511                                        OutputStream& out) {
3512     size_t count = components.size();
3513     const Type& type = baseExpr.type().componentType().toCompound(fContext, count, /*rows=*/1);
3514     SpvId base = this->writeExpression(baseExpr, out);
3515     if (count == 1) {
3516         return this->writeOpCompositeExtract(type, base, components[0], out);
3517     }
3518 
3519     SpvId result = this->nextId(&type);
3520     this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
3521     this->writeWord(this->getType(type), out);
3522     this->writeWord(result, out);
3523     this->writeWord(base, out);
3524     this->writeWord(base, out);
3525     for (int component : components) {
3526         this->writeWord(component, out);
3527     }
3528     return result;
3529 }
3530 
writeSwizzle(const Swizzle & swizzle,OutputStream & out)3531 SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
3532     return this->writeSwizzle(*swizzle.base(), swizzle.components(), out);
3533 }
3534 
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,bool writeComponentwiseIfMatrix,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)3535 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType, const Type& operandType,
3536                                                SpvId lhs, SpvId rhs,
3537                                                bool writeComponentwiseIfMatrix,
3538                                                SpvOp_ ifFloat, SpvOp_ ifInt, SpvOp_ ifUInt,
3539                                                SpvOp_ ifBool, OutputStream& out) {
3540     SpvOp_ op = pick_by_type(operandType, ifFloat, ifInt, ifUInt, ifBool);
3541     if (op == SpvOpUndef) {
3542         fContext.fErrors->error(operandType.fPosition,
3543                 "unsupported operand for binary expression: " + operandType.description());
3544         return NA;
3545     }
3546     if (writeComponentwiseIfMatrix && operandType.isMatrix()) {
3547         return this->writeComponentwiseMatrixBinary(resultType, lhs, rhs, op, out);
3548     }
3549     SpvId result = this->nextId(&resultType);
3550     this->writeInstruction(op, this->getType(resultType), result, lhs, rhs, out);
3551     return result;
3552 }
3553 
writeBinaryOperationComponentwiseIfMatrix(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)3554 SpvId SPIRVCodeGenerator::writeBinaryOperationComponentwiseIfMatrix(const Type& resultType,
3555                                                                     const Type& operandType,
3556                                                                     SpvId lhs, SpvId rhs,
3557                                                                     SpvOp_ ifFloat, SpvOp_ ifInt,
3558                                                                     SpvOp_ ifUInt, SpvOp_ ifBool,
3559                                                                     OutputStream& out) {
3560     return this->writeBinaryOperation(resultType, operandType, lhs, rhs,
3561                                       /*writeComponentwiseIfMatrix=*/true,
3562                                       ifFloat, ifInt, ifUInt, ifBool, out);
3563 }
3564 
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)3565 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType, const Type& operandType,
3566                                                SpvId lhs, SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
3567                                                SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) {
3568     return this->writeBinaryOperation(resultType, operandType, lhs, rhs,
3569                                       /*writeComponentwiseIfMatrix=*/false,
3570                                       ifFloat, ifInt, ifUInt, ifBool, out);
3571 }
3572 
foldToBool(SpvId id,const Type & operandType,SpvOp op,OutputStream & out)3573 SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SpvOp op,
3574                                      OutputStream& out) {
3575     if (operandType.isVector()) {
3576         SpvId result = this->nextId(nullptr);
3577         this->writeInstruction(op, this->getType(*fContext.fTypes.fBool), result, id, out);
3578         return result;
3579     }
3580     return id;
3581 }
3582 
writeMatrixComparison(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ floatOperator,SpvOp_ intOperator,SpvOp_ vectorMergeOperator,SpvOp_ mergeOperator,OutputStream & out)3583 SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
3584                                                 SpvOp_ floatOperator, SpvOp_ intOperator,
3585                                                 SpvOp_ vectorMergeOperator, SpvOp_ mergeOperator,
3586                                                 OutputStream& out) {
3587     SpvOp_ compareOp = is_float(operandType) ? floatOperator : intOperator;
3588     SkASSERT(operandType.isMatrix());
3589     const Type& columnType = operandType.componentType().toCompound(fContext,
3590                                                                     operandType.rows(),
3591                                                                     1);
3592     SpvId bvecType = this->getType(fContext.fTypes.fBool->toCompound(fContext,
3593                                                                      operandType.rows(),
3594                                                                      1));
3595     SpvId boolType = this->getType(*fContext.fTypes.fBool);
3596     SpvId result = 0;
3597     for (int i = 0; i < operandType.columns(); i++) {
3598         SpvId columnL = this->writeOpCompositeExtract(columnType, lhs, i, out);
3599         SpvId columnR = this->writeOpCompositeExtract(columnType, rhs, i, out);
3600         SpvId compare = this->nextId(&operandType);
3601         this->writeInstruction(compareOp, bvecType, compare, columnL, columnR, out);
3602         SpvId merge = this->nextId(nullptr);
3603         this->writeInstruction(vectorMergeOperator, boolType, merge, compare, out);
3604         if (result != 0) {
3605             SpvId next = this->nextId(nullptr);
3606             this->writeInstruction(mergeOperator, boolType, next, result, merge, out);
3607             result = next;
3608         } else {
3609             result = merge;
3610         }
3611     }
3612     return result;
3613 }
3614 
writeComponentwiseMatrixUnary(const Type & operandType,SpvId operand,SpvOp_ op,OutputStream & out)3615 SpvId SPIRVCodeGenerator::writeComponentwiseMatrixUnary(const Type& operandType,
3616                                                         SpvId operand,
3617                                                         SpvOp_ op,
3618                                                         OutputStream& out) {
3619     SkASSERT(operandType.isMatrix());
3620     const Type& columnType = operandType.columnType(fContext);
3621     SpvId columnTypeId = this->getType(columnType);
3622 
3623     STArray<4, SpvId> columns;
3624     for (int i = 0; i < operandType.columns(); i++) {
3625         SpvId srcColumn = this->writeOpCompositeExtract(columnType, operand, i, out);
3626         SpvId dstColumn = this->nextId(&operandType);
3627         this->writeInstruction(op, columnTypeId, dstColumn, srcColumn, out);
3628         columns.push_back(dstColumn);
3629     }
3630 
3631     return this->writeOpCompositeConstruct(operandType, columns, out);
3632 }
3633 
writeComponentwiseMatrixBinary(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ op,OutputStream & out)3634 SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs,
3635                                                          SpvId rhs, SpvOp_ op, OutputStream& out) {
3636     SkASSERT(operandType.isMatrix());
3637     const Type& columnType = operandType.columnType(fContext);
3638     SpvId columnTypeId = this->getType(columnType);
3639 
3640     STArray<4, SpvId> columns;
3641     for (int i = 0; i < operandType.columns(); i++) {
3642         SpvId columnL = this->writeOpCompositeExtract(columnType, lhs, i, out);
3643         SpvId columnR = this->writeOpCompositeExtract(columnType, rhs, i, out);
3644         columns.push_back(this->nextId(&operandType));
3645         this->writeInstruction(op, columnTypeId, columns[i], columnL, columnR, out);
3646     }
3647     return this->writeOpCompositeConstruct(operandType, columns, out);
3648 }
3649 
writeReciprocal(const Type & type,SpvId value,OutputStream & out)3650 SpvId SPIRVCodeGenerator::writeReciprocal(const Type& type, SpvId value, OutputStream& out) {
3651     SkASSERT(type.isFloat());
3652     SpvId one = this->writeLiteral(1.0, type);
3653     SpvId reciprocal = this->nextId(&type);
3654     this->writeInstruction(SpvOpFDiv, this->getType(type), reciprocal, one, value, out);
3655     return reciprocal;
3656 }
3657 
splat(const Type & type,SpvId id,OutputStream & out)3658 SpvId SPIRVCodeGenerator::splat(const Type& type, SpvId id, OutputStream& out) {
3659     if (type.isScalar()) {
3660         // Scalars require no additional work; we can return the passed-in ID as is.
3661     } else {
3662         SkASSERT(type.isVector() || type.isMatrix());
3663         bool isMatrix = type.isMatrix();
3664 
3665         // Splat the input scalar across a vector.
3666         int vectorSize = (isMatrix ? type.rows() : type.columns());
3667         const Type& vectorType = type.componentType().toCompound(fContext, vectorSize, /*rows=*/1);
3668 
3669         STArray<4, SpvId> values;
3670         values.push_back_n(/*n=*/vectorSize, /*t=*/id);
3671         id = this->writeOpCompositeConstruct(vectorType, values, out);
3672 
3673         if (isMatrix) {
3674             // Splat the newly-synthesized vector into a matrix.
3675             STArray<4, SpvId> matArguments;
3676             matArguments.push_back_n(/*n=*/type.columns(), /*t=*/id);
3677             id = this->writeOpCompositeConstruct(type, matArguments, out);
3678         }
3679     }
3680 
3681     return id;
3682 }
3683 
types_match(const Type & a,const Type & b)3684 static bool types_match(const Type& a, const Type& b) {
3685     if (a.matches(b)) {
3686         return true;
3687     }
3688     return (a.typeKind() == b.typeKind()) &&
3689            (a.isScalar() || a.isVector() || a.isMatrix()) &&
3690            (a.columns() == b.columns() && a.rows() == b.rows()) &&
3691            a.componentType().numberKind() == b.componentType().numberKind();
3692 }
3693 
writeDecomposedMatrixVectorMultiply(const Type & leftType,SpvId lhs,const Type & rightType,SpvId rhs,const Type & resultType,OutputStream & out)3694 SpvId SPIRVCodeGenerator::writeDecomposedMatrixVectorMultiply(const Type& leftType,
3695                                                               SpvId lhs,
3696                                                               const Type& rightType,
3697                                                               SpvId rhs,
3698                                                               const Type& resultType,
3699                                                               OutputStream& out) {
3700     SpvId sum = NA;
3701     const Type& columnType = leftType.columnType(fContext);
3702     const Type& scalarType = rightType.componentType();
3703 
3704     for (int n = 0; n < leftType.rows(); ++n) {
3705         // Extract mat[N] from the matrix.
3706         SpvId matN = this->writeOpCompositeExtract(columnType, lhs, n, out);
3707 
3708         // Extract vec[N] from the vector.
3709         SpvId vecN = this->writeOpCompositeExtract(scalarType, rhs, n, out);
3710 
3711         // Multiply them together.
3712         SpvId product = this->writeBinaryExpression(columnType, matN, OperatorKind::STAR,
3713                                                     scalarType, vecN,
3714                                                     columnType, out);
3715 
3716         // Sum all the components together.
3717         if (sum == NA) {
3718             sum = product;
3719         } else {
3720             sum = this->writeBinaryExpression(columnType, sum, OperatorKind::PLUS,
3721                                               columnType, product,
3722                                               columnType, out);
3723         }
3724     }
3725 
3726     return sum;
3727 }
3728 
// Emits SPIR-V for the binary operation `lhs op rhs`.
//
// `lhs` and `rhs` are IDs of already-evaluated operands whose SkSL types are `leftType` and
// `rightType`. `resultType` is the type of the whole expression. Returns the ID of the result,
// or NA after reporting an error ("unsupported mixed-type expression" / "unsupported token").
//
// Operands whose types differ (e.g. float2 * float, float3x3 * float3, float2x2 + float) are
// handled first. They are either emitted directly with a dedicated SPIR-V opcode, or widened
// (vectorized or splatted) so both sides share one operand type. The remaining operators are
// then dispatched on that common `operandType`. Matrix-op-scalar recurses into this function
// as matrix-op-matrix.
SpvId SPIRVCodeGenerator::writeBinaryExpression(const Type& leftType, SpvId lhs, Operator op,
                                                const Type& rightType, SpvId rhs,
                                                const Type& resultType, OutputStream& out) {
    // The comma operator ignores the type of the left-hand side entirely.
    if (op.kind() == Operator::Kind::COMMA) {
        return rhs;
    }
    // overall type we are operating on: float2, int, uint4...
    const Type* operandType;
    if (types_match(leftType, rightType)) {
        operandType = &leftType;
    } else {
        // IR allows mismatched types in expressions (e.g. float2 * float), but they need special
        // handling in SPIR-V
        if (leftType.isVector() && rightType.isNumber()) {
            // vector-op-scalar.
            if (resultType.componentType().isFloat()) {
                switch (op.kind()) {
                    case Operator::Kind::SLASH: {
                        // Division becomes multiplication by the reciprocal, v / s => v * (1/s),
                        // so it can share the OpVectorTimesScalar path below.
                        rhs = this->writeReciprocal(rightType, rhs, out);
                        [[fallthrough]];
                    }
                    case Operator::Kind::STAR: {
                        SpvId result = this->nextId(&resultType);
                        this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
                                               result, lhs, rhs, out);
                        return result;
                    }
                    default:
                        break;
                }
            }
            // Vectorize the right-hand side.
            STArray<4, SpvId> arguments;
            arguments.push_back_n(/*n=*/leftType.columns(), /*t=*/rhs);
            rhs = this->writeOpCompositeConstruct(leftType, arguments, out);
            operandType = &leftType;
        } else if (rightType.isVector() && leftType.isNumber()) {
            // scalar-op-vector.
            if (resultType.componentType().isFloat()) {
                if (op.kind() == Operator::Kind::STAR) {
                    // OpVectorTimesScalar takes the vector operand first, so `rhs` and `lhs`
                    // are passed in swapped order (multiplication commutes).
                    SpvId result = this->nextId(&resultType);
                    this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
                                           result, rhs, lhs, out);
                    return result;
                }
            }
            // Vectorize the left-hand side.
            STArray<4, SpvId> arguments;
            arguments.push_back_n(/*n=*/rightType.columns(), /*t=*/lhs);
            lhs = this->writeOpCompositeConstruct(rightType, arguments, out);
            operandType = &rightType;
        } else if (leftType.isMatrix()) {
            if (op.kind() == Operator::Kind::STAR) {
                // When the rewriteMatrixVectorMultiply bit is set, we rewrite medium-precision
                // matrix * vector multiplication as (mat[0]*vec[0] + ... + mat[N]*vec[N]).
                if (fCaps.fRewriteMatrixVectorMultiply &&
                    rightType.isVector() &&
                    !resultType.highPrecision()) {
                    return this->writeDecomposedMatrixVectorMultiply(leftType, lhs, rightType, rhs,
                                                                     resultType, out);
                }

                // Matrix-times-vector and matrix-times-scalar have dedicated ops in SPIR-V.
                SpvOp_ spvop;
                if (rightType.isMatrix()) {
                    // Reached when the two matrix types did not match above,
                    // e.g. float2x3 * float3x2.
                    spvop = SpvOpMatrixTimesMatrix;
                } else if (rightType.isVector()) {
                    spvop = SpvOpMatrixTimesVector;
                } else {
                    SkASSERT(rightType.isScalar());
                    spvop = SpvOpMatrixTimesScalar;
                }
                SpvId result = this->nextId(&resultType);
                this->writeInstruction(spvop, this->getType(resultType), result, lhs, rhs, out);
                return result;
            } else {
                // Matrix-op-vector is not supported in GLSL/SkSL for non-multiplication ops; we
                // expect to have a scalar here.
                SkASSERT(rightType.isScalar());

                // Splat rhs across an entire matrix so we can reuse the matrix-op-matrix path.
                SpvId rhsMatrix = this->splat(leftType, rhs, out);

                // Perform this operation as matrix-op-matrix.
                return this->writeBinaryExpression(leftType, lhs, op, leftType, rhsMatrix,
                                                   resultType, out);
            }
        } else if (rightType.isMatrix()) {
            if (op.kind() == Operator::Kind::STAR) {
                // Matrix-times-vector and matrix-times-scalar have dedicated ops in SPIR-V.
                SpvId result = this->nextId(&resultType);
                if (leftType.isVector()) {
                    this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(resultType),
                                           result, lhs, rhs, out);
                } else {
                    SkASSERT(leftType.isScalar());
                    // OpMatrixTimesScalar takes the matrix first, so the operands are swapped
                    // (scalar * matrix == matrix * scalar).
                    this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(resultType),
                                           result, rhs, lhs, out);
                }
                return result;
            } else {
                // Vector-op-matrix is not supported in GLSL/SkSL for non-multiplication ops; we
                // expect to have a scalar here.
                SkASSERT(leftType.isScalar());

                // Splat lhs across an entire matrix so we can reuse the matrix-op-matrix path.
                SpvId lhsMatrix = this->splat(rightType, lhs, out);

                // Perform this operation as matrix-op-matrix.
                return this->writeBinaryExpression(rightType, lhsMatrix, op, rightType, rhs,
                                                   resultType, out);
            }
        } else {
            fContext.fErrors->error(leftType.fPosition, "unsupported mixed-type expression");
            return NA;
        }
    }

    // From here on both operands have type *operandType. The four opcodes passed to
    // writeBinaryOperation appear to be selected by operand kind in the order
    // (float, signed int, unsigned int, bool). SpvOpUndef marks a kind the operator does not
    // support.
    switch (op.kind()) {
        case Operator::Kind::EQEQ: {
            // Matrices, structs and arrays are compared piecewise by dedicated helpers.
            if (operandType->isMatrix()) {
                return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
                                                   SpvOpIEqual, SpvOpAll, SpvOpLogicalAnd, out);
            }
            if (operandType->isStruct()) {
                return this->writeStructComparison(*operandType, lhs, op, rhs, out);
            }
            if (operandType->isArray()) {
                return this->writeArrayComparison(*operandType, lhs, op, rhs, out);
            }
            SkASSERT(resultType.isBoolean());
            // A vector comparison yields one bool per component: a bool vector with the
            // operands' shape. That intermediate is passed to foldToBool together with SpvOpAll,
            // presumably to reduce it to a single bool.
            const Type* tmpType;
            if (operandType->isVector()) {
                tmpType = &fContext.fTypes.fBool->toCompound(fContext,
                                                             operandType->columns(),
                                                             operandType->rows());
            } else {
                tmpType = &resultType;
            }
            // Identical IDs mean a value is being compared with itself.
            if (lhs == rhs) {
                // This ignores the effects of NaN.
                return this->writeOpConstantTrue(*fContext.fTypes.fBool);
            }
            return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
                                                               SpvOpFOrdEqual, SpvOpIEqual,
                                                               SpvOpIEqual, SpvOpLogicalEqual, out),
                                    *operandType, SpvOpAll, out);
        }
        case Operator::Kind::NEQ:
            // Floats use the *unordered* not-equal, so a NaN operand compares as not-equal.
            if (operandType->isMatrix()) {
                return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFUnordNotEqual,
                                                   SpvOpINotEqual, SpvOpAny, SpvOpLogicalOr, out);
            }
            if (operandType->isStruct()) {
                return this->writeStructComparison(*operandType, lhs, op, rhs, out);
            }
            if (operandType->isArray()) {
                return this->writeArrayComparison(*operandType, lhs, op, rhs, out);
            }
            [[fallthrough]];
        case Operator::Kind::LOGICALXOR:
            // On booleans ^^ is the same as !=, so it shares the not-equal lowering
            // (SpvOpLogicalNotEqual).
            SkASSERT(resultType.isBoolean());
            const Type* tmpType;
            if (operandType->isVector()) {
                tmpType = &fContext.fTypes.fBool->toCompound(fContext,
                                                             operandType->columns(),
                                                             operandType->rows());
            } else {
                tmpType = &resultType;
            }
            if (lhs == rhs) {
                // This ignores the effects of NaN.
                return this->writeOpConstantFalse(*fContext.fTypes.fBool);
            }
            return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
                                                               SpvOpFUnordNotEqual, SpvOpINotEqual,
                                                               SpvOpINotEqual, SpvOpLogicalNotEqual,
                                                               out),
                                    *operandType, SpvOpAny, out);
        case Operator::Kind::GT:
            SkASSERT(resultType.isBoolean());
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
                                              SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
                                              SpvOpUGreaterThan, SpvOpUndef, out);
        case Operator::Kind::LT:
            SkASSERT(resultType.isBoolean());
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
                                              SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
        case Operator::Kind::GTEQ:
            SkASSERT(resultType.isBoolean());
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
                                              SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
                                              SpvOpUGreaterThanEqual, SpvOpUndef, out);
        case Operator::Kind::LTEQ:
            SkASSERT(resultType.isBoolean());
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
                                              SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
                                              SpvOpULessThanEqual, SpvOpUndef, out);
        // +, - and / are componentwise on matrices as well; the ...ComponentwiseIfMatrix helper
        // is used for those, presumably splitting matrix operands into columns.
        case Operator::Kind::PLUS:
            return this->writeBinaryOperationComponentwiseIfMatrix(resultType, *operandType,
                                                                   lhs, rhs, SpvOpFAdd, SpvOpIAdd,
                                                                   SpvOpIAdd, SpvOpUndef, out);
        case Operator::Kind::MINUS:
            return this->writeBinaryOperationComponentwiseIfMatrix(resultType, *operandType,
                                                                   lhs, rhs, SpvOpFSub, SpvOpISub,
                                                                   SpvOpISub, SpvOpUndef, out);
        case Operator::Kind::STAR:
            // matrix * matrix is a linear-algebra product, not a componentwise one.
            if (leftType.isMatrix() && rightType.isMatrix()) {
                // matrix multiply
                SpvId result = this->nextId(&resultType);
                this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
                                       lhs, rhs, out);
                return result;
            }
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
                                              SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
        case Operator::Kind::SLASH:
            return this->writeBinaryOperationComponentwiseIfMatrix(resultType, *operandType,
                                                                   lhs, rhs, SpvOpFDiv, SpvOpSDiv,
                                                                   SpvOpUDiv, SpvOpUndef, out);
        case Operator::Kind::PERCENT:
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
                                              SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
        // Shifts and bitwise ops exist only for integer operands.
        case Operator::Kind::SHL:
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
                                              SpvOpShiftLeftLogical, SpvOpShiftLeftLogical,
                                              SpvOpUndef, out);
        case Operator::Kind::SHR:
            // Signed operands shift arithmetically (sign-extending); unsigned ones logically.
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
                                              SpvOpShiftRightArithmetic, SpvOpShiftRightLogical,
                                              SpvOpUndef, out);
        case Operator::Kind::BITWISEAND:
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
                                              SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out);
        case Operator::Kind::BITWISEOR:
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
                                              SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out);
        case Operator::Kind::BITWISEXOR:
            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
                                              SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
        default:
            fContext.fErrors->error(Position(), "unsupported token");
            return NA;
    }
}
3973 
writeArrayComparison(const Type & arrayType,SpvId lhs,Operator op,SpvId rhs,OutputStream & out)3974 SpvId SPIRVCodeGenerator::writeArrayComparison(const Type& arrayType, SpvId lhs, Operator op,
3975                                                SpvId rhs, OutputStream& out) {
3976     // The inputs must be arrays, and the op must be == or !=.
3977     SkASSERT(op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ);
3978     SkASSERT(arrayType.isArray());
3979     const Type& componentType = arrayType.componentType();
3980     const int arraySize = arrayType.columns();
3981     SkASSERT(arraySize > 0);
3982 
3983     // Synthesize equality checks for each item in the array.
3984     const Type& boolType = *fContext.fTypes.fBool;
3985     SpvId allComparisons = NA;
3986     for (int index = 0; index < arraySize; ++index) {
3987         // Get the left and right item in the array.
3988         SpvId itemL = this->writeOpCompositeExtract(componentType, lhs, index, out);
3989         SpvId itemR = this->writeOpCompositeExtract(componentType, rhs, index, out);
3990         // Use `writeBinaryExpression` with the requested == or != operator on these items.
3991         SpvId comparison = this->writeBinaryExpression(componentType, itemL, op,
3992                                                        componentType, itemR, boolType, out);
3993         // Merge this comparison result with all the other comparisons we've done.
3994         allComparisons = this->mergeComparisons(comparison, allComparisons, op, out);
3995     }
3996     return allComparisons;
3997 }
3998 
writeStructComparison(const Type & structType,SpvId lhs,Operator op,SpvId rhs,OutputStream & out)3999 SpvId SPIRVCodeGenerator::writeStructComparison(const Type& structType, SpvId lhs, Operator op,
4000                                                 SpvId rhs, OutputStream& out) {
4001     // The inputs must be structs containing fields, and the op must be == or !=.
4002     SkASSERT(op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ);
4003     SkASSERT(structType.isStruct());
4004     SkSpan<const Field> fields = structType.fields();
4005     SkASSERT(!fields.empty());
4006 
4007     // Synthesize equality checks for each field in the struct.
4008     const Type& boolType = *fContext.fTypes.fBool;
4009     SpvId allComparisons = NA;
4010     for (int index = 0; index < (int)fields.size(); ++index) {
4011         // Get the left and right versions of this field.
4012         const Type& fieldType = *fields[index].fType;
4013 
4014         SpvId fieldL = this->writeOpCompositeExtract(fieldType, lhs, index, out);
4015         SpvId fieldR = this->writeOpCompositeExtract(fieldType, rhs, index, out);
4016         // Use `writeBinaryExpression` with the requested == or != operator on these fields.
4017         SpvId comparison = this->writeBinaryExpression(fieldType, fieldL, op, fieldType, fieldR,
4018                                                        boolType, out);
4019         // Merge this comparison result with all the other comparisons we've done.
4020         allComparisons = this->mergeComparisons(comparison, allComparisons, op, out);
4021     }
4022     return allComparisons;
4023 }
4024 
mergeComparisons(SpvId comparison,SpvId allComparisons,Operator op,OutputStream & out)4025 SpvId SPIRVCodeGenerator::mergeComparisons(SpvId comparison, SpvId allComparisons, Operator op,
4026                                            OutputStream& out) {
4027     // If this is the first entry, we don't need to merge comparison results with anything.
4028     if (allComparisons == NA) {
4029         return comparison;
4030     }
4031     // Use LogicalAnd or LogicalOr to combine the comparison with all the other comparisons.
4032     const Type& boolType = *fContext.fTypes.fBool;
4033     SpvId boolTypeId = this->getType(boolType);
4034     SpvId logicalOp = this->nextId(&boolType);
4035     switch (op.kind()) {
4036         case Operator::Kind::EQEQ:
4037             this->writeInstruction(SpvOpLogicalAnd, boolTypeId, logicalOp,
4038                                    comparison, allComparisons, out);
4039             break;
4040         case Operator::Kind::NEQ:
4041             this->writeInstruction(SpvOpLogicalOr, boolTypeId, logicalOp,
4042                                    comparison, allComparisons, out);
4043             break;
4044         default:
4045             SkDEBUGFAILF("mergeComparisons only supports == and !=, not %s", op.operatorName());
4046             return NA;
4047     }
4048     return logicalOp;
4049 }
4050 
writeBinaryExpression(const BinaryExpression & b,OutputStream & out)4051 SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
4052     const Expression* left = b.left().get();
4053     const Expression* right = b.right().get();
4054     Operator op = b.getOperator();
4055 
4056     switch (op.kind()) {
4057         case Operator::Kind::EQ: {
4058             // Handles assignment.
4059             SpvId rhs = this->writeExpression(*right, out);
4060             this->getLValue(*left, out)->store(rhs, out);
4061             return rhs;
4062         }
4063         case Operator::Kind::LOGICALAND:
4064             // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS.
4065             return this->writeLogicalAnd(*b.left(), *b.right(), out);
4066 
4067         case Operator::Kind::LOGICALOR:
4068             // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS.
4069             return this->writeLogicalOr(*b.left(), *b.right(), out);
4070 
4071         default:
4072             break;
4073     }
4074 
4075     std::unique_ptr<LValue> lvalue;
4076     SpvId lhs;
4077     if (op.isAssignment()) {
4078         lvalue = this->getLValue(*left, out);
4079         lhs = lvalue->load(out);
4080     } else {
4081         lvalue = nullptr;
4082         lhs = this->writeExpression(*left, out);
4083     }
4084 
4085     SpvId rhs = this->writeExpression(*right, out);
4086     SpvId result = this->writeBinaryExpression(left->type(), lhs, op.removeAssignment(),
4087                                                right->type(), rhs, b.type(), out);
4088     if (lvalue) {
4089         lvalue->store(result, out);
4090     }
4091     return result;
4092 }
4093 
// Emits short-circuiting `left && right` and returns the ID of the bool result.
// `left` is evaluated in the current block. `right` is evaluated only on the branch taken when
// `left` is true. Both paths rejoin at the merge label `end`, where an OpPhi picks the constant
// `false` (arriving from the LHS block) or the RHS value (arriving from the RHS block).
SpvId SPIRVCodeGenerator::writeLogicalAnd(const Expression& left, const Expression& right,
                                          OutputStream& out) {
    SpvId falseConstant = this->writeLiteral(0.0, *fContext.fTypes.fBool);
    SpvId lhs = this->writeExpression(left, out);

    // Snapshot handed to writeLabel at the merge point. Presumably it lets state recorded inside
    // the conditional region be reconciled there; see getConditionalOpCounts.
    ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();

    SpvId rhsLabel = this->nextId(nullptr);
    SpvId end = this->nextId(nullptr);
    // OpPhi needs the label of each predecessor block; record the one that produced `lhs`.
    SpvId lhsBlock = fCurrentBlock;
    this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
    // lhs true -> evaluate rhs. lhs false -> jump straight to the merge block.
    this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
    this->writeLabel(rhsLabel, kBranchIsOnPreviousLine, out);
    SpvId rhs = this->writeExpression(right, out);
    // Re-read the current block: evaluating `right` can emit new labels (e.g. a nested &&), so
    // the block that branches to `end` may differ from `rhsLabel`.
    SpvId rhsBlock = fCurrentBlock;
    this->writeInstruction(SpvOpBranch, end, out);
    this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
    SpvId result = this->nextId(nullptr);
    this->writeInstruction(SpvOpPhi, this->getType(*fContext.fTypes.fBool), result, falseConstant,
                           lhsBlock, rhs, rhsBlock, out);

    return result;
}
4117 
// Emits short-circuiting `left || right` and returns the ID of the bool result.
// `left` is evaluated in the current block. `right` is evaluated only on the branch taken when
// `left` is false. Both paths rejoin at the merge label `end`, where an OpPhi picks the constant
// `true` (arriving from the LHS block) or the RHS value (arriving from the RHS block).
SpvId SPIRVCodeGenerator::writeLogicalOr(const Expression& left, const Expression& right,
                                         OutputStream& out) {
    SpvId trueConstant = this->writeLiteral(1.0, *fContext.fTypes.fBool);
    SpvId lhs = this->writeExpression(left, out);

    // Snapshot handed to writeLabel at the merge point. Presumably it lets state recorded inside
    // the conditional region be reconciled there; see getConditionalOpCounts.
    ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();

    SpvId rhsLabel = this->nextId(nullptr);
    SpvId end = this->nextId(nullptr);
    // OpPhi needs the label of each predecessor block; record the one that produced `lhs`.
    SpvId lhsBlock = fCurrentBlock;
    this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
    // lhs true -> jump straight to the merge block. lhs false -> evaluate rhs.
    this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
    this->writeLabel(rhsLabel, kBranchIsOnPreviousLine, out);
    SpvId rhs = this->writeExpression(right, out);
    // Re-read the current block: evaluating `right` can emit new labels (e.g. a nested ||), so
    // the block that branches to `end` may differ from `rhsLabel`.
    SpvId rhsBlock = fCurrentBlock;
    this->writeInstruction(SpvOpBranch, end, out);
    this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
    SpvId result = this->nextId(nullptr);
    this->writeInstruction(SpvOpPhi, this->getType(*fContext.fTypes.fBool), result, trueConstant,
                           lhsBlock, rhs, rhsBlock, out);

    return result;
}
4141 
// Emits `test ? ifTrue : ifFalse` and returns the ID of the selected value.
//
// If both branches are compile-time constants and the branch type reports columns() == 1
// (e.g. scalars), both values are materialized and combined with a single OpSelect. Otherwise
// a structured if/else stores the chosen branch into a function-local temporary. That
// temporary is loaded after the branches merge, so only the selected branch is evaluated.
SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) {
    const Type& type = t.type();
    SpvId test = this->writeExpression(*t.test(), out);
    if (t.ifTrue()->type().columns() == 1 &&
        Analysis::IsCompileTimeConstant(*t.ifTrue()) &&
        Analysis::IsCompileTimeConstant(*t.ifFalse())) {
        // both true and false are constants, can just use OpSelect
        SpvId result = this->nextId(nullptr);
        SpvId trueId = this->writeExpression(*t.ifTrue(), out);
        SpvId falseId = this->writeExpression(*t.ifFalse(), out);
        this->writeInstruction(SpvOpSelect, this->getType(type), result, test, trueId, falseId,
                               out);
        return result;
    }

    // Snapshot handed to writeLabel for both the else-label and the merge label. Presumably it
    // lets state recorded inside each branch be reconciled there; see getConditionalOpCounts.
    ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();

    // was originally using OpPhi to choose the result, but for some reason that is crashing on
    // Adreno. Switched to storing the result in a temp variable as glslang does.
    // The OpVariable is written to fVariableBuffer rather than `out`; presumably that buffer
    // collects the function's local variable declarations.
    SpvId var = this->nextId(nullptr);
    this->writeInstruction(SpvOpVariable, this->getPointerType(type, SpvStorageClassFunction),
                           var, SpvStorageClassFunction, fVariableBuffer);
    SpvId trueLabel = this->nextId(nullptr);
    SpvId falseLabel = this->nextId(nullptr);
    SpvId end = this->nextId(nullptr);
    this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
    this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
    // True branch: evaluate ifTrue, store it, jump to the merge block.
    this->writeLabel(trueLabel, kBranchIsOnPreviousLine, out);
    this->writeOpStore(SpvStorageClassFunction, var, this->writeExpression(*t.ifTrue(), out), out);
    this->writeInstruction(SpvOpBranch, end, out);
    // False branch: evaluate ifFalse, store it, jump to the merge block.
    this->writeLabel(falseLabel, kBranchIsAbove, conditionalOps, out);
    this->writeOpStore(SpvStorageClassFunction, var, this->writeExpression(*t.ifFalse(), out), out);
    this->writeInstruction(SpvOpBranch, end, out);
    // Merge block: read back whichever value was stored.
    this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
    SpvId result = this->nextId(&type);
    this->writeInstruction(SpvOpLoad, this->getType(type), result, var, out);

    return result;
}
4181 
writePrefixExpression(const PrefixExpression & p,OutputStream & out)4182 SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
4183     const Type& type = p.type();
4184     if (p.getOperator().kind() == Operator::Kind::MINUS) {
4185         SpvOp_ negateOp = pick_by_type(type, SpvOpFNegate, SpvOpSNegate, SpvOpSNegate, SpvOpUndef);
4186         SkASSERT(negateOp != SpvOpUndef);
4187         SpvId expr = this->writeExpression(*p.operand(), out);
4188         if (type.isMatrix()) {
4189             return this->writeComponentwiseMatrixUnary(type, expr, negateOp, out);
4190         }
4191         SpvId result = this->nextId(&type);
4192         SpvId typeId = this->getType(type);
4193         this->writeInstruction(negateOp, typeId, result, expr, out);
4194         return result;
4195     }
4196     switch (p.getOperator().kind()) {
4197         case Operator::Kind::PLUS:
4198             return this->writeExpression(*p.operand(), out);
4199 
4200         case Operator::Kind::PLUSPLUS: {
4201             std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
4202             SpvId one = this->writeLiteral(1.0, type.componentType());
4203             one = this->splat(type, one, out);
4204             SpvId result = this->writeBinaryOperationComponentwiseIfMatrix(type, type,
4205                                                                            lv->load(out), one,
4206                                                                            SpvOpFAdd, SpvOpIAdd,
4207                                                                            SpvOpIAdd, SpvOpUndef,
4208                                                                            out);
4209             lv->store(result, out);
4210             return result;
4211         }
4212         case Operator::Kind::MINUSMINUS: {
4213             std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
4214             SpvId one = this->writeLiteral(1.0, type.componentType());
4215             one = this->splat(type, one, out);
4216             SpvId result = this->writeBinaryOperationComponentwiseIfMatrix(type, type,
4217                                                                            lv->load(out), one,
4218                                                                            SpvOpFSub, SpvOpISub,
4219                                                                            SpvOpISub, SpvOpUndef,
4220                                                                            out);
4221             lv->store(result, out);
4222             return result;
4223         }
4224         case Operator::Kind::LOGICALNOT: {
4225             SkASSERT(p.operand()->type().isBoolean());
4226             SpvId result = this->nextId(nullptr);
4227             this->writeInstruction(SpvOpLogicalNot, this->getType(type), result,
4228                                    this->writeExpression(*p.operand(), out), out);
4229             return result;
4230         }
4231         case Operator::Kind::BITWISENOT: {
4232             SpvId result = this->nextId(nullptr);
4233             this->writeInstruction(SpvOpNot, this->getType(type), result,
4234                                    this->writeExpression(*p.operand(), out), out);
4235             return result;
4236         }
4237         default:
4238             SkDEBUGFAILF("unsupported prefix expression: %s",
4239                          p.description(OperatorPrecedence::kExpression).c_str());
4240             return NA;
4241     }
4242 }
4243 
writePostfixExpression(const PostfixExpression & p,OutputStream & out)4244 SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) {
4245     const Type& type = p.type();
4246     std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
4247     SpvId result = lv->load(out);
4248     SpvId one = this->writeLiteral(1.0, type.componentType());
4249     one = this->splat(type, one, out);
4250     switch (p.getOperator().kind()) {
4251         case Operator::Kind::PLUSPLUS: {
4252             SpvId temp = this->writeBinaryOperationComponentwiseIfMatrix(type, type, result, one,
4253                                                                          SpvOpFAdd, SpvOpIAdd,
4254                                                                          SpvOpIAdd, SpvOpUndef,
4255                                                                          out);
4256             lv->store(temp, out);
4257             return result;
4258         }
4259         case Operator::Kind::MINUSMINUS: {
4260             SpvId temp = this->writeBinaryOperationComponentwiseIfMatrix(type, type, result, one,
4261                                                                          SpvOpFSub, SpvOpISub,
4262                                                                          SpvOpISub, SpvOpUndef,
4263                                                                          out);
4264             lv->store(temp, out);
4265             return result;
4266         }
4267         default:
4268             SkDEBUGFAILF("unsupported postfix expression %s",
4269                          p.description(OperatorPrecedence::kExpression).c_str());
4270             return NA;
4271     }
4272 }
4273 
writeLiteral(const Literal & l)4274 SpvId SPIRVCodeGenerator::writeLiteral(const Literal& l) {
4275     return this->writeLiteral(l.value(), l.type());
4276 }
4277 
writeLiteral(double value,const Type & type)4278 SpvId SPIRVCodeGenerator::writeLiteral(double value, const Type& type) {
4279     switch (type.numberKind()) {
4280         case Type::NumberKind::kFloat: {
4281             float floatVal = value;
4282             int32_t valueBits;
4283             memcpy(&valueBits, &floatVal, sizeof(valueBits));
4284             return this->writeOpConstant(type, valueBits);
4285         }
4286         case Type::NumberKind::kBoolean: {
4287             return value ? this->writeOpConstantTrue(type)
4288                          : this->writeOpConstantFalse(type);
4289         }
4290         default: {
4291             return this->writeOpConstant(type, (SKSL_INT)value);
4292         }
4293     }
4294 }
4295 
// Emits the OpFunction header for `f`, records the function's mangled name (OpName) in
// fNameBuffer, and emits one OpFunctionParameter per parameter, registering each parameter's id in
// fVariableMap. Returns the function's result id.
// NOTE(review): the result id is read from fFunctionMap, which is assumed to already contain an
// entry for `f` by the time this is called — confirm against the caller.
SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) {
    SpvId result = fFunctionMap[&f];
    SpvId returnTypeId = this->getType(f.returnType());
    SpvId functionTypeId = this->getFunctionType(f);
    this->writeInstruction(SpvOpFunction, returnTypeId, result,
                           SpvFunctionControlMaskNone, functionTypeId, out);
    std::string mangledName = f.mangledName();
    this->writeInstruction(SpvOpName,
                           result,
                           std::string_view(mangledName.c_str(), mangledName.size()),
                           fNameBuffer);
    for (const Variable* parameter : f.parameters()) {
        if (fUseTextureSamplerPairs && parameter->type().isSampler()) {
            // With texture/sampler pairs enabled, one sampler parameter is lowered to two
            // consecutive parameters: a synthesized texture, then a synthesized sampler. Each is
            // mapped to its own id so later references to either can be resolved.
            auto [texture, sampler] = this->synthesizeTextureAndSampler(*parameter);

            SpvId textureId = this->nextId(nullptr);
            SpvId samplerId = this->nextId(nullptr);
            fVariableMap.set(texture, textureId);
            fVariableMap.set(sampler, samplerId);

            SpvId textureType = this->getFunctionParameterType(texture->type(), texture->layout());
            SpvId samplerType = this->getFunctionParameterType(sampler->type(), kDefaultTypeLayout);

            this->writeInstruction(SpvOpFunctionParameter, textureType, textureId, out);
            this->writeInstruction(SpvOpFunctionParameter, samplerType, samplerId, out);
        } else {
            SpvId id = this->nextId(nullptr);
            fVariableMap.set(parameter, id);

            SpvId type = this->getFunctionParameterType(parameter->type(), parameter->layout());
            this->writeInstruction(SpvOpFunctionParameter, type, id, out);
        }
    }
    return result;
}
4331 
// Emits the complete definition of `f`: the function header and parameters, an entry label, the
// body, a terminator if control can fall off the end, and OpFunctionEnd. Returns the function id.
SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) {
    // Snapshot the cached-op counters so that anything cached while writing this function is
    // pruned again once the function ends (see pruneConditionalOps at the bottom).
    ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();

    fVariableBuffer.reset();
    SpvId result = this->writeFunctionStart(f.declaration(), out);
    fCurrentBlock = 0;
    this->writeLabel(this->nextId(nullptr), kBranchlessBlock, out);
    // The body is written to a side buffer first. While that happens, local OpVariables accumulate
    // in fVariableBuffer and are then emitted ahead of the body, because SPIR-V requires
    // function-scope OpVariable instructions to come first in the function's first block.
    StringStream bodyBuffer;
    this->writeBlock(f.body()->as<Block>(), bodyBuffer);
    write_stringstream(fVariableBuffer, out);
    if (f.declaration().isMain()) {
        // Stores that initialize global variables are run at the top of main().
        write_stringstream(fGlobalInitializersBuffer, out);
    }
    write_stringstream(bodyBuffer, out);
    // fCurrentBlock appears to be nonzero while a block is still open (not yet terminated by a
    // return/branch/kill), so the final block still needs a terminator here.
    if (fCurrentBlock) {
        if (f.declaration().returnType().isVoid()) {
            this->writeInstruction(SpvOpReturn, out);
        } else {
            // A non-void function must not fall off the end; mark this path unreachable.
            this->writeInstruction(SpvOpUnreachable, out);
        }
    }
    this->writeInstruction(SpvOpFunctionEnd, out);
    this->pruneConditionalOps(conditionalOps);
    return result;
}
4357 
writeLayout(const Layout & layout,SpvId target,Position pos)4358 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, Position pos) {
4359     bool isPushConstant = SkToBool(layout.fFlags & LayoutFlag::kPushConstant);
4360     if (layout.fLocation >= 0) {
4361         this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
4362                                fDecorationBuffer);
4363     }
4364     if (layout.fBinding >= 0) {
4365         if (isPushConstant) {
4366             fContext.fErrors->error(pos, "Can't apply 'binding' to push constants");
4367         } else {
4368             this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
4369                                    fDecorationBuffer);
4370         }
4371     }
4372     if (layout.fIndex >= 0) {
4373         this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
4374                                fDecorationBuffer);
4375     }
4376     if (layout.fSet >= 0) {
4377         if (isPushConstant) {
4378             fContext.fErrors->error(pos, "Can't apply 'set' to push constants");
4379         } else {
4380             this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
4381                                    fDecorationBuffer);
4382         }
4383     }
4384     if (layout.fInputAttachmentIndex >= 0) {
4385         this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
4386                                layout.fInputAttachmentIndex, fDecorationBuffer);
4387         fCapabilities |= (((uint64_t) 1) << SpvCapabilityInputAttachment);
4388     }
4389     if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN) {
4390         this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
4391                                fDecorationBuffer);
4392     }
4393 }
4394 
writeFieldLayout(const Layout & layout,SpvId target,int member)4395 void SPIRVCodeGenerator::writeFieldLayout(const Layout& layout, SpvId target, int member) {
4396     // 'binding' and 'set' can not be applied to struct members
4397     SkASSERT(layout.fBinding == -1);
4398     SkASSERT(layout.fSet == -1);
4399     if (layout.fLocation >= 0) {
4400         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
4401                                layout.fLocation, fDecorationBuffer);
4402     }
4403     if (layout.fIndex >= 0) {
4404         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
4405                                layout.fIndex, fDecorationBuffer);
4406     }
4407     if (layout.fInputAttachmentIndex >= 0) {
4408         this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
4409                                layout.fInputAttachmentIndex, fDecorationBuffer);
4410     }
4411     if (layout.fBuiltin >= 0) {
4412         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
4413                                layout.fBuiltin, fDecorationBuffer);
4414     }
4415 }
4416 
memoryLayoutForStorageClass(SpvStorageClass_ storageClass)4417 MemoryLayout SPIRVCodeGenerator::memoryLayoutForStorageClass(SpvStorageClass_ storageClass) {
4418     return storageClass == SpvStorageClassPushConstant ? MemoryLayout(MemoryLayout::Standard::k430)
4419                                                        : fDefaultMemoryLayout;
4420 }
4421 
memoryLayoutForVariable(const Variable & v) const4422 MemoryLayout SPIRVCodeGenerator::memoryLayoutForVariable(const Variable& v) const {
4423     bool pushConstant = SkToBool(v.layout().fFlags & LayoutFlag::kPushConstant);
4424     return pushConstant ? MemoryLayout(MemoryLayout::Standard::k430)
4425                         : fDefaultMemoryLayout;
4426 }
4427 
// Declares the interface block `intf` as a global OpVariable: its struct type (decorated Block or
// BufferBlock), a pointer type, the variable itself, and its layout decorations. The variable is
// registered in fVariableMap and its id returned. If the block's type is not supported by its
// memory layout, an error is reported and a fresh, unused id is returned.
//
// When `appendRTFlip` is set, the program needs an RTFlip uniform, it hasn't been written yet,
// and the block type is a struct, a modified copy of the block with an extra float2
// SKSL_RTFLIP_NAME field is written instead (via a recursive call with appendRTFlip=false). The
// original variable is then mapped to the copy's id.
SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf, bool appendRTFlip) {
    MemoryLayout memoryLayout = this->memoryLayoutForVariable(*intf.var());
    // NOTE(review): this id is allocated up front; on the error path and the RTFlip path below it
    // goes unused, because those paths return a different id.
    SpvId result = this->nextId(nullptr);
    const Variable& intfVar = *intf.var();
    const Type& type = intfVar.type();
    if (!memoryLayout.isSupported(type)) {
        fContext.fErrors->error(type.fPosition, "type '" + type.displayName() +
                                                "' is not permitted here");
        return this->nextId(nullptr);
    }
    SpvStorageClass_ storageClass =
            get_storage_class_for_global_variable(intfVar, SpvStorageClassFunction);
    if (fProgram.fInterface.fRTFlipUniform != Program::Interface::kRTFlip_None && appendRTFlip &&
        !fWroteRTFlip && type.isStruct()) {
        // We can only have one interface block (because we use push_constant and that is limited
        // to one per program), so we need to append rtflip to this one rather than synthesize an
        // entirely new block when the variable is referenced. And we can't modify the existing
        // block, so we instead create a modified copy of it and write that.
        SkSpan<const Field> fieldSpan = type.fields();
        TArray<Field> fields(fieldSpan.data(), fieldSpan.size());
        fields.emplace_back(Position(),
                            Layout(LayoutFlag::kNone,
                                   /*location=*/-1,
                                   fProgram.fConfig->fSettings.fRTFlipOffset,
                                   /*binding=*/-1,
                                   /*index=*/-1,
                                   /*set=*/-1,
                                   /*builtin=*/-1,
                                   /*inputAttachmentIndex=*/-1),
                            ModifierFlag::kNone,
                            SKSL_RTFLIP_NAME,
                            fContext.fTypes.fFloat2.get());
        {
            // The synthesized type and variable are allocated from the program's pool and
            // handed to the program's symbol table for ownership.
            AutoAttachPoolToThread attach(fProgram.fPool.get());
            const Type* rtFlipStructType = fProgram.fSymbols->takeOwnershipOfSymbol(
                    Type::MakeStructType(fContext,
                                         type.fPosition,
                                         type.name(),
                                         std::move(fields),
                                         /*interfaceBlock=*/true));
            Variable* modifiedVar = fProgram.fSymbols->takeOwnershipOfSymbol(
                    Variable::Make(intfVar.fPosition,
                                   intfVar.modifiersPosition(),
                                   intfVar.layout(),
                                   intfVar.modifierFlags(),
                                   rtFlipStructType,
                                   intfVar.name(),
                                   /*mangledName=*/"",
                                   intfVar.isBuiltin(),
                                   intfVar.storage()));
            InterfaceBlock modifiedCopy(intf.fPosition, modifiedVar);
            result = this->writeInterfaceBlock(modifiedCopy, /*appendRTFlip=*/false);
            // Add a FieldSymbol for the appended (last) field of the copy; presumably this lets
            // later references to the rtflip name resolve to that field — confirm.
            fProgram.fSymbols->add(fContext, std::make_unique<FieldSymbol>(
                    Position(), modifiedVar, rtFlipStructType->fields().size() - 1));
        }
        fVariableMap.set(&intfVar, result);
        fWroteRTFlip = true;
        return result;
    }
    SpvId typeId = this->getType(type, kDefaultTypeLayout, memoryLayout);
    if (intfVar.layout().fBuiltin == -1) {
        // Note: In SPIR-V 1.3, a storage buffer can be declared with the "StorageBuffer"
        // storage class and the "Block" decoration and the <1.3 approach we use here ("Uniform"
        // storage class and the "BufferBlock" decoration) is deprecated. Since we target SPIR-V
        // 1.0, we have to use the deprecated approach which is well supported in Vulkan and
        // addresses SkSL use cases (notably SkSL currently doesn't support pointer features that
        // would benefit from SPV_KHR_variable_pointers capabilities).
        bool isStorageBuffer = intfVar.modifierFlags().isBuffer();
        this->writeInstruction(SpvOpDecorate,
                               typeId,
                               isStorageBuffer ? SpvDecorationBufferBlock : SpvDecorationBlock,
                               fDecorationBuffer);
    }
    SpvId ptrType = this->nextId(nullptr);
    this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer);
    this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer);
    Layout layout = intfVar.layout();
    // Uniform-storage blocks with no explicit 'set' fall back to the configured default set.
    if (storageClass == SpvStorageClassUniform && layout.fSet < 0) {
        layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
    }
    this->writeLayout(layout, result, intfVar.fPosition);
    fVariableMap.set(&intfVar, result);
    return result;
}
4512 
4513 // This function determines whether to skip an OpVariable (of pointer type) declaration for
4514 // compile-time constant scalars and vectors which we turn into OpConstant/OpConstantComposite and
4515 // always reference by value.
4516 //
4517 // Accessing a matrix or array member with a dynamic index requires the use of OpAccessChain which
4518 // requires a base operand of pointer type. However, a vector can always be accessed by value using
4519 // OpVectorExtractDynamic (see writeIndexExpression).
4520 //
4521 // This is why we always emit an OpVariable for all non-scalar and non-vector types in case they get
4522 // accessed via a dynamic index.
is_vardecl_compile_time_constant(const VarDeclaration & varDecl)4523 static bool is_vardecl_compile_time_constant(const VarDeclaration& varDecl) {
4524     return varDecl.var()->modifierFlags().isConst() &&
4525            (varDecl.var()->type().isScalar() || varDecl.var()->type().isVector()) &&
4526            (ConstantFolder::GetConstantValueOrNull(*varDecl.value()) ||
4527             Analysis::IsCompileTimeConstant(*varDecl.value()));
4528 }
4529 
// Handles one global variable declaration. Returns false after reporting an error if the
// declaration can't be expressed in this backend, and true otherwise. Depending on the variable
// it does one of four things:
//  - emits nothing now (compile-time constants, materialized at each use);
//  - queues it in fTopLevelUniforms (emitted later by writeUniformBuffer);
//  - declares a synthesized texture + sampler pair (when fUseTextureSamplerPairs is set);
//  - declares a single global OpVariable, writing any initializer store into
//    fGlobalInitializersBuffer.
bool SPIRVCodeGenerator::writeGlobalVarDeclaration(ProgramKind kind,
                                                   const VarDeclaration& varDecl) {
    const Variable* var = varDecl.var();
    // Reject backend-specific layout flags that don't target a SPIR-V consumer.
    const LayoutFlags backendFlags = var->layout().fFlags & LayoutFlag::kAllBackends;
    const LayoutFlags kPermittedBackendFlags =
            LayoutFlag::kVulkan | LayoutFlag::kWebGPU | LayoutFlag::kDirect3D;
    if (backendFlags & ~kPermittedBackendFlags) {
        fContext.fErrors->error(var->fPosition, "incompatible backend flag in SPIR-V codegen");
        return false;
    }

    // If this global variable is a compile-time constant then we'll emit OpConstant or
    // OpConstantComposite later when the variable is referenced. Avoid declaring an OpVariable now.
    if (is_vardecl_compile_time_constant(varDecl)) {
        return true;
    }

    SpvStorageClass_ storageClass =
            get_storage_class_for_global_variable(*var, SpvStorageClassPrivate);
    if (storageClass == SpvStorageClassUniform) {
        // Top-level uniforms are emitted in writeUniformBuffer.
        fTopLevelUniforms.push_back(&varDecl);
        return true;
    }

    if (fUseTextureSamplerPairs && var->type().isSampler()) {
        // In this mode a combined sampler is split into a separate texture and sampler, which
        // requires both indices in the layout.
        if (var->layout().fTexture == -1 || var->layout().fSampler == -1) {
            fContext.fErrors->error(var->fPosition, "selected backend requires separate texture "
                                                    "and sampler indices");
            return false;
        }
        SkASSERT(storageClass == SpvStorageClassUniformConstant);

        auto [texture, sampler] = this->synthesizeTextureAndSampler(*var);
        this->writeGlobalVar(kind, storageClass, *texture);
        this->writeGlobalVar(kind, storageClass, *sampler);

        return true;
    }

    // writeGlobalVar returns NA when nothing was declared (e.g. sk_FragColor outside a fragment
    // program); then there is nothing to initialize.
    SpvId id = this->writeGlobalVar(kind, storageClass, *var);
    if (id != NA && varDecl.value()) {
        // Global initializers are written outside any function. fCurrentBlock is temporarily set
        // to NA, presumably so expression codegen treats this as an open block, and is reset to 0
        // afterwards. writeFunction later emits the buffer at the top of main().
        SkASSERT(!fCurrentBlock);
        fCurrentBlock = NA;
        SpvId value = this->writeExpression(*varDecl.value(), fGlobalInitializersBuffer);
        this->writeOpStore(storageClass, id, value, fGlobalInitializersBuffer);
        fCurrentBlock = 0;
    }
    return true;
}
4580 
// Declares `var` as a global OpVariable in `storageClass` and records it in fVariableMap. The
// declaration comprises the OpVariable (in fConstantBuffer), its OpName, layout decorations, and
// flat/noperspective/NonReadable/NonWritable decorations derived from the modifier flags. Returns
// the new id, or NA (declaring nothing) for sk_FragColor in a non-fragment program.
SpvId SPIRVCodeGenerator::writeGlobalVar(ProgramKind kind,
                                         SpvStorageClass_ storageClass,
                                         const Variable& var) {
    Layout layout = var.layout();
    const ModifierFlags flags = var.modifierFlags();
    const Type* type = &var.type();
    switch (layout.fBuiltin) {
        case SK_FRAGCOLOR_BUILTIN:
            if (!ProgramConfig::IsFragment(kind)) {
                SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut);
                return NA;
            }
            break;

        case SK_SAMPLEMASKIN_BUILTIN:
        case SK_SAMPLEMASK_BUILTIN:
            // SkSL exposes this as a scalar `uint`, but SPIR-V requires the SampleMask builtin to
            // be an array of 32-bit integers (GLSL's gl_SampleMask is `int[]`). Wrap the type in a
            // one-element array and decorate it with SpvBuiltInSampleMask.
            type = fSynthetics.addArrayDimension(fContext, type, /*arraySize=*/1);
            layout.fBuiltin = SpvBuiltInSampleMask;
            break;
    }

    // Add this global to the variable map.
    SpvId id = this->nextId(type);
    fVariableMap.set(&var, id);

    // UniformConstant variables (e.g. textures/samplers) with no explicit set use the default set.
    if (layout.fSet < 0 && storageClass == SpvStorageClassUniformConstant) {
        layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
    }

    SpvId typeId = this->getPointerType(*type,
                                        layout,
                                        this->memoryLayoutForStorageClass(storageClass),
                                        storageClass);
    this->writeInstruction(SpvOpVariable, typeId, id, storageClass, fConstantBuffer);
    this->writeInstruction(SpvOpName, id, var.name(), fNameBuffer);
    this->writeLayout(layout, id, var.fPosition);
    if (flags & ModifierFlag::kFlat) {
        this->writeInstruction(SpvOpDecorate, id, SpvDecorationFlat, fDecorationBuffer);
    }
    if (flags & ModifierFlag::kNoPerspective) {
        this->writeInstruction(SpvOpDecorate, id, SpvDecorationNoPerspective,
                               fDecorationBuffer);
    }
    // writeonly maps to NonReadable and readonly to NonWritable; writeonly takes precedence.
    if (flags.isWriteOnly()) {
        this->writeInstruction(SpvOpDecorate, id, SpvDecorationNonReadable, fDecorationBuffer);
    } else if (flags.isReadOnly()) {
        this->writeInstruction(SpvOpDecorate, id, SpvDecorationNonWritable, fDecorationBuffer);
    }

    return id;
}
4634 
writeVarDeclaration(const VarDeclaration & varDecl,OutputStream & out)4635 void SPIRVCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl, OutputStream& out) {
4636     // If this variable is a compile-time constant then we'll emit OpConstant or
4637     // OpConstantComposite later when the variable is referenced. Avoid declaring an OpVariable now.
4638     if (is_vardecl_compile_time_constant(varDecl)) {
4639         return;
4640     }
4641 
4642     const Variable* var = varDecl.var();
4643     SpvId id = this->nextId(&var->type());
4644     fVariableMap.set(var, id);
4645     SpvId type = this->getPointerType(var->type(), SpvStorageClassFunction);
4646     this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
4647     this->writeInstruction(SpvOpName, id, var->name(), fNameBuffer);
4648     if (varDecl.value()) {
4649         SpvId value = this->writeExpression(*varDecl.value(), out);
4650         this->writeOpStore(SpvStorageClassFunction, id, value, out);
4651     }
4652 }
4653 
writeStatement(const Statement & s,OutputStream & out)4654 void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) {
4655     switch (s.kind()) {
4656         case Statement::Kind::kNop:
4657             break;
4658         case Statement::Kind::kBlock:
4659             this->writeBlock(s.as<Block>(), out);
4660             break;
4661         case Statement::Kind::kExpression:
4662             this->writeExpression(*s.as<ExpressionStatement>().expression(), out);
4663             break;
4664         case Statement::Kind::kReturn:
4665             this->writeReturnStatement(s.as<ReturnStatement>(), out);
4666             break;
4667         case Statement::Kind::kVarDeclaration:
4668             this->writeVarDeclaration(s.as<VarDeclaration>(), out);
4669             break;
4670         case Statement::Kind::kIf:
4671             this->writeIfStatement(s.as<IfStatement>(), out);
4672             break;
4673         case Statement::Kind::kFor:
4674             this->writeForStatement(s.as<ForStatement>(), out);
4675             break;
4676         case Statement::Kind::kDo:
4677             this->writeDoStatement(s.as<DoStatement>(), out);
4678             break;
4679         case Statement::Kind::kSwitch:
4680             this->writeSwitchStatement(s.as<SwitchStatement>(), out);
4681             break;
4682         case Statement::Kind::kBreak:
4683             this->writeInstruction(SpvOpBranch, fBreakTarget.back(), out);
4684             break;
4685         case Statement::Kind::kContinue:
4686             this->writeInstruction(SpvOpBranch, fContinueTarget.back(), out);
4687             break;
4688         case Statement::Kind::kDiscard:
4689             this->writeInstruction(SpvOpKill, out);
4690             break;
4691         default:
4692             SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
4693             break;
4694     }
4695 }
4696 
writeBlock(const Block & b,OutputStream & out)4697 void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) {
4698     for (const std::unique_ptr<Statement>& stmt : b.children()) {
4699         this->writeStatement(*stmt, out);
4700     }
4701 }
4702 
getConditionalOpCounts()4703 SPIRVCodeGenerator::ConditionalOpCounts SPIRVCodeGenerator::getConditionalOpCounts() {
4704     return {fReachableOps.size(), fStoreOps.size()};
4705 }
4706 
pruneConditionalOps(ConditionalOpCounts ops)4707 void SPIRVCodeGenerator::pruneConditionalOps(ConditionalOpCounts ops) {
4708     // Remove ops which are no longer reachable.
4709     while (fReachableOps.size() > ops.numReachableOps) {
4710         SpvId prunableSpvId = fReachableOps.back();
4711         const Instruction* prunableOp = fSpvIdCache.find(prunableSpvId);
4712 
4713         if (prunableOp) {
4714             fOpCache.remove(*prunableOp);
4715             fSpvIdCache.remove(prunableSpvId);
4716         } else {
4717             SkDEBUGFAIL("reachable-op list contains unrecognized SpvId");
4718         }
4719 
4720         fReachableOps.pop_back();
4721     }
4722 
4723     // Remove any cached stores that occurred during the conditional block.
4724     while (fStoreOps.size() > ops.numStoreOps) {
4725         if (fStoreCache.find(fStoreOps.back())) {
4726             fStoreCache.remove(fStoreOps.back());
4727         }
4728         fStoreOps.pop_back();
4729     }
4730 }
4731 
// Lowers an `if` statement to a structured selection (OpSelectionMerge + OpBranchConditional).
// With an else-branch, both arms branch to a separate merge label `end`. Without one, the `ifFalse`
// label doubles as the merge block.
void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) {
    SpvId test = this->writeExpression(*stmt.test(), out);
    SpvId ifTrue = this->nextId(nullptr);
    SpvId ifFalse = this->nextId(nullptr);

    // Snapshot of the cached-op counts, handed to every label reached from above. Presumably this
    // lets writeLabel discard ops cached inside a branch that may not have executed — confirm in
    // writeLabel.
    ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();

    if (stmt.ifFalse()) {
        SpvId end = this->nextId(nullptr);
        this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
        this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
        this->writeLabel(ifTrue, kBranchIsOnPreviousLine, out);
        this->writeStatement(*stmt.ifTrue(), out);
        // Only branch to the merge block if the arm didn't already terminate its block
        // (e.g. via return/break/discard).
        if (fCurrentBlock) {
            this->writeInstruction(SpvOpBranch, end, out);
        }
        this->writeLabel(ifFalse, kBranchIsAbove, conditionalOps, out);
        this->writeStatement(*stmt.ifFalse(), out);
        if (fCurrentBlock) {
            this->writeInstruction(SpvOpBranch, end, out);
        }
        this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
    } else {
        this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
        this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
        this->writeLabel(ifTrue, kBranchIsOnPreviousLine, out);
        this->writeStatement(*stmt.ifTrue(), out);
        if (fCurrentBlock) {
            this->writeInstruction(SpvOpBranch, ifFalse, out);
        }
        this->writeLabel(ifFalse, kBranchIsAbove, conditionalOps, out);
    }
}
4765 
// Lowers a `for` loop to a structured SPIR-V loop with these blocks:
//   header: OpLoopMerge(end, next) -> start
//   start:  evaluate the test (if any) -> body or end
//   body:   loop statement -> next
//   next:   (continue target) evaluate the `next` expression -> header
//   end:    merge block; `break` branches here
void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) {
    // The initializer runs once, before the loop header.
    if (f.initializer()) {
        this->writeStatement(*f.initializer(), out);
    }

    ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();

    // The store cache isn't trustworthy in the presence of branches; store caching only makes sense
    // in the context of linear straight-line execution. If we wanted to be more clever, we could
    // only invalidate store cache entries for variables affected by the loop body, but for now we
    // simply clear the entire cache whenever branching occurs.
    // (The snapshot above is passed to writeLabel for each label reached from a branch.)
    SpvId header = this->nextId(nullptr);
    SpvId start = this->nextId(nullptr);
    SpvId body = this->nextId(nullptr);
    SpvId next = this->nextId(nullptr);
    fContinueTarget.push_back(next);
    SpvId end = this->nextId(nullptr);
    fBreakTarget.push_back(end);
    this->writeInstruction(SpvOpBranch, header, out);
    this->writeLabel(header, kBranchIsBelow, conditionalOps, out);
    this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
    this->writeInstruction(SpvOpBranch, start, out);
    this->writeLabel(start, kBranchIsOnPreviousLine, out);
    if (f.test()) {
        SpvId test = this->writeExpression(*f.test(), out);
        this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
    } else {
        // No test: an unconditional loop that can only be left via break/return.
        this->writeInstruction(SpvOpBranch, body, out);
    }
    this->writeLabel(body, kBranchIsOnPreviousLine, out);
    this->writeStatement(*f.statement(), out);
    // Fall through to the continue target only if the body didn't terminate its block.
    if (fCurrentBlock) {
        this->writeInstruction(SpvOpBranch, next, out);
    }
    this->writeLabel(next, kBranchIsAbove, conditionalOps, out);
    if (f.next()) {
        this->writeExpression(*f.next(), out);
    }
    this->writeInstruction(SpvOpBranch, header, out);
    this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
    fBreakTarget.pop_back();
    fContinueTarget.pop_back();
}
4809 
// Lowers a `do { ... } while (test);` loop to a structured SPIR-V loop with these blocks:
//   header:         OpLoopMerge(end, continueTarget) -> start
//   start:          loop statement -> next
//   next:           -> continueTarget (emitted only if the body falls through)
//   continueTarget: evaluate the test -> header (true) or end (false)
//   end:            merge block; `break` branches here, `continue` goes to continueTarget
void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) {
    ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();

    // The store cache isn't trustworthy in the presence of branches; store caching only makes sense
    // in the context of linear straight-line execution. If we wanted to be more clever, we could
    // only invalidate store cache entries for variables affected by the loop body, but for now we
    // simply clear the entire cache whenever branching occurs.
    // (The snapshot above is passed to writeLabel for each label reached from a branch.)
    SpvId header = this->nextId(nullptr);
    SpvId start = this->nextId(nullptr);
    SpvId next = this->nextId(nullptr);
    SpvId continueTarget = this->nextId(nullptr);
    fContinueTarget.push_back(continueTarget);
    SpvId end = this->nextId(nullptr);
    fBreakTarget.push_back(end);
    this->writeInstruction(SpvOpBranch, header, out);
    this->writeLabel(header, kBranchIsBelow, conditionalOps, out);
    this->writeInstruction(SpvOpLoopMerge, end, continueTarget, SpvLoopControlMaskNone, out);
    this->writeInstruction(SpvOpBranch, start, out);
    this->writeLabel(start, kBranchIsOnPreviousLine, out);
    this->writeStatement(*d.statement(), out);
    // If the body terminated its own block (e.g. via return/break/continue), the `next` label is
    // never emitted and its id goes unused.
    if (fCurrentBlock) {
        this->writeInstruction(SpvOpBranch, next, out);
        this->writeLabel(next, kBranchIsOnPreviousLine, out);
        this->writeInstruction(SpvOpBranch, continueTarget, out);
    }
    this->writeLabel(continueTarget, kBranchIsAbove, conditionalOps, out);
    SpvId test = this->writeExpression(*d.test(), out);
    this->writeInstruction(SpvOpBranchConditional, test, header, end, out);
    this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
    fBreakTarget.pop_back();
    fContinueTarget.pop_back();
}
4842 
writeSwitchStatement(const SwitchStatement & s,OutputStream & out)4843 void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) {
4844     SpvId value = this->writeExpression(*s.value(), out);
4845 
4846     ConditionalOpCounts conditionalOps = this->getConditionalOpCounts();
4847 
4848     // The store cache isn't trustworthy in the presence of branches; store caching only makes sense
4849     // in the context of linear straight-line execution. If we wanted to be more clever, we could
4850     // only invalidate store cache entries for variables affected by the switch body, but for now we
4851     // simply clear the entire cache whenever branching occurs.
4852     TArray<SpvId> labels;
4853     SpvId end = this->nextId(nullptr);
4854     SpvId defaultLabel = end;
4855     fBreakTarget.push_back(end);
4856     int size = 3;
4857     const StatementArray& cases = s.cases();
4858     for (const std::unique_ptr<Statement>& stmt : cases) {
4859         const SwitchCase& c = stmt->as<SwitchCase>();
4860         SpvId label = this->nextId(nullptr);
4861         labels.push_back(label);
4862         if (!c.isDefault()) {
4863             size += 2;
4864         } else {
4865             defaultLabel = label;
4866         }
4867     }
4868 
4869     // We should have exactly one label for each case.
4870     SkASSERT(labels.size() == cases.size());
4871 
4872     // Collapse adjacent switch-cases into one; that is, reduce `case 1: case 2: case 3:` into a
4873     // single OpLabel. The Tint SPIR-V reader does not support switch-case fallthrough, but it
4874     // does support multiple switch-cases branching to the same label.
4875     SkBitSet caseIsCollapsed(cases.size());
4876     for (int index = cases.size() - 2; index >= 0; index--) {
4877         if (cases[index]->as<SwitchCase>().statement()->isEmpty()) {
4878             caseIsCollapsed.set(index);
4879             labels[index] = labels[index + 1];
4880         }
4881     }
4882 
4883     labels.push_back(end);
4884 
4885     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
4886     this->writeOpCode(SpvOpSwitch, size, out);
4887     this->writeWord(value, out);
4888     this->writeWord(defaultLabel, out);
4889     for (int i = 0; i < cases.size(); ++i) {
4890         const SwitchCase& c = cases[i]->as<SwitchCase>();
4891         if (c.isDefault()) {
4892             continue;
4893         }
4894         this->writeWord(c.value(), out);
4895         this->writeWord(labels[i], out);
4896     }
4897     for (int i = 0; i < cases.size(); ++i) {
4898         if (caseIsCollapsed.test(i)) {
4899             continue;
4900         }
4901         const SwitchCase& c = cases[i]->as<SwitchCase>();
4902         if (i == 0) {
4903             this->writeLabel(labels[i], kBranchIsOnPreviousLine, out);
4904         } else {
4905             this->writeLabel(labels[i], kBranchIsAbove, conditionalOps, out);
4906         }
4907         this->writeStatement(*c.statement(), out);
4908         if (fCurrentBlock) {
4909             this->writeInstruction(SpvOpBranch, labels[i + 1], out);
4910         }
4911     }
4912     this->writeLabel(end, kBranchIsAbove, conditionalOps, out);
4913     fBreakTarget.pop_back();
4914 }
4915 
writeReturnStatement(const ReturnStatement & r,OutputStream & out)4916 void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) {
4917     if (r.expression()) {
4918         this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.expression(), out),
4919                                out);
4920     } else {
4921         this->writeInstruction(SpvOpReturn, out);
4922     }
4923 }
4924 
4925 // Given any function, returns the top-level symbol table (OUTSIDE of the function's scope).
get_top_level_symbol_table(const FunctionDeclaration & anyFunc)4926 static SymbolTable* get_top_level_symbol_table(const FunctionDeclaration& anyFunc) {
4927     return anyFunc.definition()->body()->as<Block>().symbolTable()->fParent;
4928 }
4929 
writeEntrypointAdapter(const FunctionDeclaration & main)4930 SPIRVCodeGenerator::EntrypointAdapter SPIRVCodeGenerator::writeEntrypointAdapter(
4931         const FunctionDeclaration& main) {
4932     // Our goal is to synthesize a tiny helper function which looks like this:
4933     //     void _entrypoint() { sk_FragColor = main(); }
4934 
4935     // Fish a symbol table out of main().
4936     SymbolTable* symbolTable = get_top_level_symbol_table(main);
4937 
4938     // Get `sk_FragColor` as a writable reference.
4939     const Symbol* skFragColorSymbol = symbolTable->find("sk_FragColor");
4940     SkASSERT(skFragColorSymbol);
4941     const Variable& skFragColorVar = skFragColorSymbol->as<Variable>();
4942     auto skFragColorRef = std::make_unique<VariableReference>(Position(), &skFragColorVar,
4943                                                               VariableReference::RefKind::kWrite);
4944     // Synthesize a call to the `main()` function.
4945     if (!main.returnType().matches(skFragColorRef->type())) {
4946         fContext.fErrors->error(main.fPosition, "SPIR-V does not support returning '" +
4947                 main.returnType().description() + "' from main()");
4948         return {};
4949     }
4950     ExpressionArray args;
4951     if (main.parameters().size() == 1) {
4952         if (!main.parameters()[0]->type().matches(*fContext.fTypes.fFloat2)) {
4953             fContext.fErrors->error(main.fPosition,
4954                     "SPIR-V does not support parameter of type '" +
4955                     main.parameters()[0]->type().description() + "' to main()");
4956             return {};
4957         }
4958         double kZero[2] = {0.0, 0.0};
4959         args.push_back(ConstructorCompound::MakeFromConstants(fContext, Position{},
4960                                                               *fContext.fTypes.fFloat2, kZero));
4961     }
4962     auto callMainFn = std::make_unique<FunctionCall>(Position(), &main.returnType(), &main,
4963                                                      std::move(args));
4964 
4965     // Synthesize `skFragColor = main()` as a BinaryExpression.
4966     auto assignmentStmt = std::make_unique<ExpressionStatement>(std::make_unique<BinaryExpression>(
4967             Position(),
4968             std::move(skFragColorRef),
4969             Operator::Kind::EQ,
4970             std::move(callMainFn),
4971             &main.returnType()));
4972 
4973     // Function bodies are always wrapped in a Block.
4974     StatementArray entrypointStmts;
4975     entrypointStmts.push_back(std::move(assignmentStmt));
4976     auto entrypointBlock = Block::Make(Position(), std::move(entrypointStmts),
4977                                        Block::Kind::kBracedScope, /*symbols=*/nullptr);
4978     // Declare an entrypoint function.
4979     EntrypointAdapter adapter;
4980     adapter.entrypointDecl =
4981             std::make_unique<FunctionDeclaration>(fContext,
4982                                                   Position(),
4983                                                   ModifierFlag::kNone,
4984                                                   "_entrypoint",
4985                                                   /*parameters=*/TArray<Variable*>{},
4986                                                   /*returnType=*/fContext.fTypes.fVoid.get(),
4987                                                   kNotIntrinsic);
4988     // Define it.
4989     adapter.entrypointDef = FunctionDefinition::Convert(fContext,
4990                                                         Position(),
4991                                                         *adapter.entrypointDecl,
4992                                                         std::move(entrypointBlock),
4993                                                         /*builtin=*/false);
4994 
4995     adapter.entrypointDecl->setDefinition(adapter.entrypointDef.get());
4996     return adapter;
4997 }
4998 
writeUniformBuffer(SymbolTable * topLevelSymbolTable)4999 void SPIRVCodeGenerator::writeUniformBuffer(SymbolTable* topLevelSymbolTable) {
5000     SkASSERT(!fTopLevelUniforms.empty());
5001     static constexpr char kUniformBufferName[] = "_UniformBuffer";
5002 
5003     // Convert the list of top-level uniforms into a matching struct named _UniformBuffer, and build
5004     // a lookup table of variables to UniformBuffer field indices.
5005     TArray<Field> fields;
5006     fields.reserve_exact(fTopLevelUniforms.size());
5007     for (const VarDeclaration* topLevelUniform : fTopLevelUniforms) {
5008         const Variable* var = topLevelUniform->var();
5009         fTopLevelUniformMap.set(var, (int)fields.size());
5010         ModifierFlags flags = var->modifierFlags() & ~ModifierFlag::kUniform;
5011         fields.emplace_back(var->fPosition, var->layout(), flags, var->name(), &var->type());
5012     }
5013     fUniformBuffer.fStruct = Type::MakeStructType(fContext,
5014                                                   Position(),
5015                                                   kUniformBufferName,
5016                                                   std::move(fields),
5017                                                   /*interfaceBlock=*/true);
5018 
5019     // Create a global variable to contain this struct.
5020     Layout layout;
5021     layout.fBinding = fProgram.fConfig->fSettings.fDefaultUniformBinding;
5022     layout.fSet     = fProgram.fConfig->fSettings.fDefaultUniformSet;
5023 
5024     fUniformBuffer.fInnerVariable = Variable::Make(/*pos=*/Position(),
5025                                                    /*modifiersPosition=*/Position(),
5026                                                    layout,
5027                                                    ModifierFlag::kUniform,
5028                                                    fUniformBuffer.fStruct.get(),
5029                                                    kUniformBufferName,
5030                                                    /*mangledName=*/"",
5031                                                    /*builtin=*/false,
5032                                                    Variable::Storage::kGlobal);
5033 
5034     // Create an interface block object for this global variable.
5035     fUniformBuffer.fInterfaceBlock =
5036             std::make_unique<InterfaceBlock>(Position(), fUniformBuffer.fInnerVariable.get());
5037 
5038     // Generate an interface block and hold onto its ID.
5039     fUniformBufferId = this->writeInterfaceBlock(*fUniformBuffer.fInterfaceBlock);
5040 }
5041 
addRTFlipUniform(Position pos)5042 void SPIRVCodeGenerator::addRTFlipUniform(Position pos) {
5043     SkASSERT(!fProgram.fConfig->fSettings.fForceNoRTFlip);
5044 
5045     if (fWroteRTFlip) {
5046         return;
5047     }
5048     // Flip variable hasn't been written yet. This means we don't have an existing
5049     // interface block, so we're free to just synthesize one.
5050     fWroteRTFlip = true;
5051     TArray<Field> fields;
5052     if (fProgram.fConfig->fSettings.fRTFlipOffset < 0) {
5053         fContext.fErrors->error(pos, "RTFlipOffset is negative");
5054     }
5055     fields.emplace_back(pos,
5056                         Layout(LayoutFlag::kNone,
5057                                /*location=*/-1,
5058                                fProgram.fConfig->fSettings.fRTFlipOffset,
5059                                /*binding=*/-1,
5060                                /*index=*/-1,
5061                                /*set=*/-1,
5062                                /*builtin=*/-1,
5063                                /*inputAttachmentIndex=*/-1),
5064                         ModifierFlag::kNone,
5065                         SKSL_RTFLIP_NAME,
5066                         fContext.fTypes.fFloat2.get());
5067     std::string_view name = "sksl_synthetic_uniforms";
5068     const Type* intfStruct = fSynthetics.takeOwnershipOfSymbol(Type::MakeStructType(
5069             fContext, Position(), name, std::move(fields), /*interfaceBlock=*/true));
5070     bool usePushConstants = fProgram.fConfig->fSettings.fUsePushConstants;
5071     int binding = -1, set = -1;
5072     if (!usePushConstants) {
5073         binding = fProgram.fConfig->fSettings.fRTFlipBinding;
5074         if (binding == -1) {
5075             fContext.fErrors->error(pos, "layout(binding=...) is required in SPIR-V");
5076         }
5077         set = fProgram.fConfig->fSettings.fRTFlipSet;
5078         if (set == -1) {
5079             fContext.fErrors->error(pos, "layout(set=...) is required in SPIR-V");
5080         }
5081     }
5082     Layout layout(/*flags=*/usePushConstants ? LayoutFlag::kPushConstant : LayoutFlag::kNone,
5083                   /*location=*/-1,
5084                   /*offset=*/-1,
5085                   binding,
5086                   /*index=*/-1,
5087                   set,
5088                   /*builtin=*/-1,
5089                   /*inputAttachmentIndex=*/-1);
5090     Variable* intfVar =
5091             fSynthetics.takeOwnershipOfSymbol(Variable::Make(/*pos=*/Position(),
5092                                                              /*modifiersPosition=*/Position(),
5093                                                              layout,
5094                                                              ModifierFlag::kUniform,
5095                                                              intfStruct,
5096                                                              name,
5097                                                              /*mangledName=*/"",
5098                                                              /*builtin=*/false,
5099                                                              Variable::Storage::kGlobal));
5100     {
5101         AutoAttachPoolToThread attach(fProgram.fPool.get());
5102         fProgram.fSymbols->add(fContext,
5103                                std::make_unique<FieldSymbol>(Position(), intfVar, /*field=*/0));
5104     }
5105     InterfaceBlock intf(Position(), intfVar);
5106     this->writeInterfaceBlock(intf, false);
5107 }
5108 
synthesizeTextureAndSampler(const Variable & combinedSampler)5109 std::tuple<const Variable*, const Variable*> SPIRVCodeGenerator::synthesizeTextureAndSampler(
5110         const Variable& combinedSampler) {
5111     SkASSERT(fUseTextureSamplerPairs);
5112     SkASSERT(combinedSampler.type().typeKind() == Type::TypeKind::kSampler);
5113 
5114     auto data = std::make_unique<SynthesizedTextureSamplerPair>();
5115 
5116     Layout texLayout = combinedSampler.layout();
5117     texLayout.fBinding = texLayout.fTexture;
5118     data->fTextureName = std::string(combinedSampler.name()) + "_texture";
5119 
5120     auto texture = Variable::Make(/*pos=*/Position(),
5121                                   /*modifiersPosition=*/Position(),
5122                                   texLayout,
5123                                   combinedSampler.modifierFlags(),
5124                                   &combinedSampler.type().textureType(),
5125                                   data->fTextureName,
5126                                   /*mangledName=*/"",
5127                                   /*builtin=*/false,
5128                                   Variable::Storage::kGlobal);
5129 
5130     Layout samplerLayout = combinedSampler.layout();
5131     samplerLayout.fBinding = samplerLayout.fSampler;
5132     samplerLayout.fFlags &= ~LayoutFlag::kAllPixelFormats;
5133     data->fSamplerName = std::string(combinedSampler.name()) + "_sampler";
5134 
5135     auto sampler = Variable::Make(/*pos=*/Position(),
5136                                   /*modifiersPosition=*/Position(),
5137                                   samplerLayout,
5138                                   combinedSampler.modifierFlags(),
5139                                   fContext.fTypes.fSampler.get(),
5140                                   data->fSamplerName,
5141                                   /*mangledName=*/"",
5142                                   /*builtin=*/false,
5143                                   Variable::Storage::kGlobal);
5144 
5145     const Variable* t = texture.get();
5146     const Variable* s = sampler.get();
5147     data->fTexture = std::move(texture);
5148     data->fSampler = std::move(sampler);
5149     fSynthesizedSamplerMap.set(&combinedSampler, std::move(data));
5150 
5151     return {t, s};
5152 }
5153 
writeInstructions(const Program & program,OutputStream & out)5154 void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) {
5155     fGLSLExtendedInstructions = this->nextId(nullptr);
5156     StringStream body;
5157 
5158     // Do an initial pass over the program elements to establish some baseline info.
5159     const FunctionDeclaration* main = nullptr;
5160     int localSizeX = 1, localSizeY = 1, localSizeZ = 1;
5161     Position combinedSamplerPos;
5162     Position separateSamplerPos;
5163     for (const ProgramElement* e : program.elements()) {
5164         switch (e->kind()) {
5165             case ProgramElement::Kind::kFunction: {
5166                 // Assign SpvIds to functions.
5167                 const FunctionDefinition& funcDef = e->as<FunctionDefinition>();
5168                 const FunctionDeclaration& funcDecl = funcDef.declaration();
5169                 fFunctionMap.set(&funcDecl, this->nextId(nullptr));
5170                 if (funcDecl.isMain()) {
5171                     main = &funcDecl;
5172                 }
5173                 break;
5174             }
5175             case ProgramElement::Kind::kGlobalVar: {
5176                 // Look for sampler variables and determine whether or not this program uses
5177                 // combined samplers or separate samplers. The layout backend will be marked as
5178                 // WebGPU for separate samplers, or Vulkan for combined samplers.
5179                 const GlobalVarDeclaration& decl = e->as<GlobalVarDeclaration>();
5180                 const Variable& var = *decl.varDeclaration().var();
5181                 if (var.type().isSampler()) {
5182                     if (var.layout().fFlags & LayoutFlag::kVulkan) {
5183                         combinedSamplerPos = decl.position();
5184                     }
5185                     if (var.layout().fFlags & (LayoutFlag::kWebGPU | LayoutFlag::kDirect3D)) {
5186                         separateSamplerPos = decl.position();
5187                     }
5188                 }
5189                 break;
5190             }
5191             case ProgramElement::Kind::kModifiers: {
5192                 // If this is a compute program, collect the local-size values. Dimensions that are
5193                 // not present will be assigned a value of 1.
5194                 if (ProgramConfig::IsCompute(program.fConfig->fKind)) {
5195                     const ModifiersDeclaration& modifiers = e->as<ModifiersDeclaration>();
5196                     if (modifiers.layout().fLocalSizeX >= 0) {
5197                         localSizeX = modifiers.layout().fLocalSizeX;
5198                     }
5199                     if (modifiers.layout().fLocalSizeY >= 0) {
5200                         localSizeY = modifiers.layout().fLocalSizeY;
5201                     }
5202                     if (modifiers.layout().fLocalSizeZ >= 0) {
5203                         localSizeZ = modifiers.layout().fLocalSizeZ;
5204                     }
5205                 }
5206                 break;
5207             }
5208             default:
5209                 break;
5210         }
5211     }
5212 
5213     // Make sure we have a main() function.
5214     if (!main) {
5215         fContext.fErrors->error(Position(), "program does not contain a main() function");
5216         return;
5217     }
5218     // Make sure our program's sampler usage is consistent.
5219     if (combinedSamplerPos.valid() && separateSamplerPos.valid()) {
5220         fContext.fErrors->error(Position(), "programs cannot contain a mixture of sampler types");
5221         fContext.fErrors->error(combinedSamplerPos, "combined sampler found here:");
5222         fContext.fErrors->error(separateSamplerPos, "separate sampler found here:");
5223         return;
5224     }
5225     fUseTextureSamplerPairs = separateSamplerPos.valid();
5226 
5227     // Emit interface blocks.
5228     std::set<SpvId> interfaceVars;
5229     for (const ProgramElement* e : program.elements()) {
5230         if (e->is<InterfaceBlock>()) {
5231             const InterfaceBlock& intf = e->as<InterfaceBlock>();
5232             SpvId id = this->writeInterfaceBlock(intf);
5233 
5234             if ((intf.var()->modifierFlags() & (ModifierFlag::kIn | ModifierFlag::kOut)) &&
5235                 intf.var()->layout().fBuiltin == -1) {
5236                 interfaceVars.insert(id);
5237             }
5238         }
5239     }
5240     // If MustDeclareFragmentFrontFacing is set, the front-facing flag (sk_Clockwise) needs to be
5241     // explicitly declared in the output, whether or not the program explicitly references it.
5242     // However, if the program naturally declares it, we don't want to include it a second time;
5243     // we keep track of the real global variable declarations to see if sk_Clockwise is emitted.
5244     const VarDeclaration* missingClockwiseDecl = nullptr;
5245     if (fCaps.fMustDeclareFragmentFrontFacing) {
5246         if (const Symbol* clockwise = program.fSymbols->findBuiltinSymbol("sk_Clockwise")) {
5247             missingClockwiseDecl = clockwise->as<Variable>().varDeclaration();
5248         }
5249     }
5250     // Emit global variable declarations.
5251     for (const ProgramElement* e : program.elements()) {
5252         if (e->is<GlobalVarDeclaration>()) {
5253             const VarDeclaration& decl = e->as<GlobalVarDeclaration>().varDeclaration();
5254             if (!this->writeGlobalVarDeclaration(program.fConfig->fKind, decl)) {
5255                 return;
5256             }
5257             if (missingClockwiseDecl == &decl) {
5258                 // We emitted an sk_Clockwise declaration naturally, so we don't need a workaround.
5259                 missingClockwiseDecl = nullptr;
5260             }
5261         }
5262     }
5263     // All the global variables have been declared. If sk_Clockwise was not naturally included in
5264     // the output, but MustDeclareFragmentFrontFacing was set, we need to bodge it in ourselves.
5265     if (missingClockwiseDecl) {
5266         if (!this->writeGlobalVarDeclaration(program.fConfig->fKind, *missingClockwiseDecl)) {
5267             return;
5268         }
5269         missingClockwiseDecl = nullptr;
5270     }
5271     // Emit top-level uniforms into a dedicated uniform buffer.
5272     if (!fTopLevelUniforms.empty()) {
5273         this->writeUniformBuffer(get_top_level_symbol_table(*main));
5274     }
5275     // If main() returns a half4, synthesize a tiny entrypoint function which invokes the real
5276     // main() and stores the result into sk_FragColor.
5277     EntrypointAdapter adapter;
5278     if (main->returnType().matches(*fContext.fTypes.fHalf4)) {
5279         adapter = this->writeEntrypointAdapter(*main);
5280         if (adapter.entrypointDecl) {
5281             fFunctionMap.set(adapter.entrypointDecl.get(), this->nextId(nullptr));
5282             this->writeFunction(*adapter.entrypointDef, body);
5283             main = adapter.entrypointDecl.get();
5284         }
5285     }
5286     // Emit all the functions.
5287     for (const ProgramElement* e : program.elements()) {
5288         if (e->is<FunctionDefinition>()) {
5289             this->writeFunction(e->as<FunctionDefinition>(), body);
5290         }
5291     }
5292     // Add global in/out variables to the list of interface variables.
5293     for (const auto& [var, spvId] : fVariableMap) {
5294         if (var->storage() == Variable::Storage::kGlobal &&
5295             (var->modifierFlags() & (ModifierFlag::kIn | ModifierFlag::kOut))) {
5296             interfaceVars.insert(spvId);
5297         }
5298     }
5299     this->writeCapabilities(out);
5300     this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
5301     this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
5302     this->writeOpCode(SpvOpEntryPoint,
5303                       (SpvId)(3 + (main->name().length() + 4) / 4) + (int32_t)interfaceVars.size(),
5304                       out);
5305     if (ProgramConfig::IsVertex(program.fConfig->fKind)) {
5306         this->writeWord(SpvExecutionModelVertex, out);
5307     } else if (ProgramConfig::IsFragment(program.fConfig->fKind)) {
5308         this->writeWord(SpvExecutionModelFragment, out);
5309     } else if (ProgramConfig::IsCompute(program.fConfig->fKind)) {
5310         this->writeWord(SpvExecutionModelGLCompute, out);
5311     } else {
5312         SK_ABORT("cannot write this kind of program to SPIR-V\n");
5313     }
5314     SpvId entryPoint = fFunctionMap[main];
5315     this->writeWord(entryPoint, out);
5316     this->writeString(main->name(), out);
5317     for (int var : interfaceVars) {
5318         this->writeWord(var, out);
5319     }
5320     if (ProgramConfig::IsFragment(program.fConfig->fKind)) {
5321         this->writeInstruction(SpvOpExecutionMode,
5322                                fFunctionMap[main],
5323                                SpvExecutionModeOriginUpperLeft,
5324                                out);
5325     } else if (ProgramConfig::IsCompute(program.fConfig->fKind)) {
5326         this->writeInstruction(SpvOpExecutionMode,
5327                                fFunctionMap[main],
5328                                SpvExecutionModeLocalSize,
5329                                localSizeX, localSizeY, localSizeZ,
5330                                out);
5331     }
5332     for (const ProgramElement* e : program.elements()) {
5333         if (e->is<Extension>()) {
5334             this->writeInstruction(SpvOpSourceExtension, e->as<Extension>().name(), out);
5335         }
5336     }
5337 
5338     write_stringstream(fNameBuffer, out);
5339     write_stringstream(fDecorationBuffer, out);
5340     write_stringstream(fConstantBuffer, out);
5341     write_stringstream(body, out);
5342 }
5343 
generateCode()5344 bool SPIRVCodeGenerator::generateCode() {
5345     SkASSERT(!fContext.fErrors->errorCount());
5346     this->writeWord(SpvMagicNumber, *fOut);
5347     this->writeWord(SpvVersion, *fOut);
5348     this->writeWord(SKSL_MAGIC, *fOut);
5349     StringStream buffer;
5350     this->writeInstructions(fProgram, buffer);
5351     this->writeWord(fIdCount, *fOut);
5352     this->writeWord(0, *fOut); // reserved, always zero
5353     write_stringstream(buffer, *fOut);
5354     return fContext.fErrors->errorCount() == 0;
5355 }
5356 
#if defined(SK_ENABLE_SPIRV_VALIDATION)
// Runs the SPIRV-Tools validator (Vulkan 1.0 environment) over `program`, which must contain a
// whole number of 32-bit words. Returns true if the module validates. On failure, standalone
// builds (skslc) report the validator messages followed by a full disassembly through `reporter`;
// other builds pass the messages to SkDEBUGFAILF instead.
static bool validate_spirv(ErrorReporter& reporter, std::string_view program) {
    SkASSERT(0 == program.size() % 4);
    const uint32_t* words = reinterpret_cast<const uint32_t*>(program.data());
    const size_t wordCount = program.size() / 4;

    // Every validator diagnostic is accumulated into `errors`, one line per message.
    std::string errors;
    spvtools::SpirvTools tools(SPV_ENV_VULKAN_1_0);
    tools.SetMessageConsumer([&errors](spv_message_level_t,
                                       const char*,
                                       const spv_position_t&,
                                       const char* message) {
        errors += "SPIR-V validation error: ";
        errors += message;
        errors += '\n';
    });

    if (tools.Validate(words, wordCount)) {
        return true;
    }

#if defined(SKSL_STANDALONE)
    // Append the disassembled module so the failure can be read in context.
    std::string disassembly;
    uint32_t disassembleOptions = spvtools::SpirvTools::kDefaultDisassembleOption;
    disassembleOptions |= SPV_BINARY_TO_TEXT_OPTION_INDENT;
    if (tools.Disassemble(words, wordCount, &disassembly, disassembleOptions)) {
        errors.append(disassembly);
    }
    reporter.error(Position(), errors);
#else
    SkDEBUGFAILF("%s", errors.c_str());
#endif
    return false;
}
#endif
5393 
ToSPIRV(Program & program,const ShaderCaps * caps,OutputStream & out)5394 bool ToSPIRV(Program& program, const ShaderCaps* caps, OutputStream& out) {
5395     TRACE_EVENT0("skia.shaders", "SkSL::ToSPIRV");
5396     SkASSERT(caps != nullptr);
5397 
5398     program.fContext->fErrors->setSource(*program.fSource);
5399 #ifdef SK_ENABLE_SPIRV_VALIDATION
5400     StringStream buffer;
5401     SPIRVCodeGenerator cg(program.fContext.get(), caps, &program, &buffer);
5402     bool result = cg.generateCode();
5403 
5404     if (result && program.fConfig->fSettings.fValidateSPIRV) {
5405         std::string_view binary = buffer.str();
5406         result = validate_spirv(*program.fContext->fErrors, binary);
5407         out.write(binary.data(), binary.size());
5408     }
5409 #else
5410     SPIRVCodeGenerator cg(program.fContext.get(), caps, &program, &out);
5411     bool result = cg.generateCode();
5412 #endif
5413     program.fContext->fErrors->setSource(std::string_view());
5414 
5415     return result;
5416 }
5417 
ToSPIRV(Program & program,const ShaderCaps * caps,std::string * out)5418 bool ToSPIRV(Program& program, const ShaderCaps* caps, std::string* out) {
5419     StringStream buffer;
5420     if (!ToSPIRV(program, caps, buffer)) {
5421         return false;
5422     }
5423     *out = buffer.str();
5424     return true;
5425 }
5426 
5427 }  // namespace SkSL
5428