• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #ifndef SKSL_SPIRVCODEGENERATOR
9 #define SKSL_SPIRVCODEGENERATOR
10 
11 #include <stack>
12 #include <unordered_map>
13 #include <unordered_set>
14 
15 #include "src/core/SkOpts.h"
16 #include "src/sksl/SkSLMemoryLayout.h"
17 #include "src/sksl/SkSLStringStream.h"
18 #include "src/sksl/codegen/SkSLCodeGenerator.h"
19 
20 namespace SkSL {
21 
22 class BinaryExpression;
23 class Block;
24 class ConstructorCompound;
25 class ConstructorCompoundCast;
26 class ConstructorDiagonalMatrix;
27 class ConstructorMatrixResize;
28 class ConstructorScalarCast;
29 class ConstructorSplat;
30 class DoStatement;
31 class FieldAccess;
32 class ForStatement;
33 class FunctionCall;
34 class FunctionDeclaration;
35 class FunctionDefinition;
36 class FunctionPrototype;
37 class IfStatement;
38 struct IndexExpression;
39 class InterfaceBlock;
40 enum IntrinsicKind : int8_t;
41 class Literal;
42 class Operator;
43 class PostfixExpression;
44 class PrefixExpression;
45 class ReturnStatement;
46 class Setting;
47 class StructDefinition;
48 class SwitchStatement;
49 struct Swizzle;
50 class TernaryExpression;
51 class VarDeclaration;
52 class VariableReference;
53 
54 struct SPIRVNumberConstant {
55     bool operator==(const SPIRVNumberConstant& that) const {
56         return fValueBits == that.fValueBits &&
57                fKind      == that.fKind;
58     }
59     int32_t                fValueBits;
60     SkSL::Type::NumberKind fKind;
61 };
62 
63 struct SPIRVVectorConstant {
64     bool operator==(const SPIRVVectorConstant& that) const {
65         return fTypeId     == that.fTypeId &&
66                fValueId[0] == that.fValueId[0] &&
67                fValueId[1] == that.fValueId[1] &&
68                fValueId[2] == that.fValueId[2] &&
69                fValueId[3] == that.fValueId[3];
70     }
71     SpvId fTypeId;
72     SpvId fValueId[4];
73 };
74 
75 }  // namespace SkSL
76 
77 namespace std {
78 
79 template <>
80 struct hash<SkSL::SPIRVNumberConstant> {
81     size_t operator()(const SkSL::SPIRVNumberConstant& key) const {
82         return key.fValueBits ^ (int)key.fKind;
83     }
84 };
85 
86 template <>
87 struct hash<SkSL::SPIRVVectorConstant> {
88     size_t operator()(const SkSL::SPIRVVectorConstant& key) const {
89         return SkOpts::hash(&key, sizeof(key));
90     }
91 };
92 
93 }  // namespace std
94 
95 namespace SkSL {
96 
97 /**
98  * Converts a Program into a SPIR-V binary.
99  */
100 class SPIRVCodeGenerator : public CodeGenerator {
101 public:
102     class LValue {
103     public:
104         virtual ~LValue() {}
105 
106         // returns a pointer to the lvalue, if possible. If the lvalue cannot be directly referenced
107         // by a pointer (e.g. vector swizzles), returns -1.
108         virtual SpvId getPointer() { return -1; }
109 
110         // Returns true if a valid pointer returned by getPointer represents a memory object
111         // (see https://github.com/KhronosGroup/SPIRV-Tools/issues/2892). Has no meaning if
112         // getPointer() returns -1.
113         virtual bool isMemoryObjectPointer() const { return true; }
114 
115         // Applies a swizzle to the components of the LValue, if possible. This is used to create
116         // LValues that are swizzes-of-swizzles. Non-swizzle LValues can just return false.
117         virtual bool applySwizzle(const ComponentArray& components, const Type& newType) {
118             return false;
119         }
120 
121         virtual SpvId load(OutputStream& out) = 0;
122 
123         virtual void store(SpvId value, OutputStream& out) = 0;
124     };
125 
126     SPIRVCodeGenerator(const Context* context,
127                        const Program* program,
128                        OutputStream* out)
129             : INHERITED(context, program, out)
130             , fDefaultLayout(MemoryLayout::k140_Standard)
131             , fCapabilities(0)
132             , fIdCount(1)
133             , fSetupFragPosition(false)
134             , fCurrentBlock(0)
135             , fSynthetics(fContext, /*builtin=*/true) {
136         this->setupIntrinsics();
137     }
138 
139     bool generateCode() override;
140 
141 private:
142     enum IntrinsicOpcodeKind {
143         kGLSL_STD_450_IntrinsicOpcodeKind,
144         kSPIRV_IntrinsicOpcodeKind,
145         kSpecial_IntrinsicOpcodeKind
146     };
147 
148     enum SpecialIntrinsic {
149         kAtan_SpecialIntrinsic,
150         kClamp_SpecialIntrinsic,
151         kMatrixCompMult_SpecialIntrinsic,
152         kMax_SpecialIntrinsic,
153         kMin_SpecialIntrinsic,
154         kMix_SpecialIntrinsic,
155         kMod_SpecialIntrinsic,
156         kDFdy_SpecialIntrinsic,
157         kSaturate_SpecialIntrinsic,
158         kSampledImage_SpecialIntrinsic,
159         kSmoothStep_SpecialIntrinsic,
160         kStep_SpecialIntrinsic,
161         kSubpassLoad_SpecialIntrinsic,
162         kTexture_SpecialIntrinsic,
163     };
164 
165     enum class Precision {
166         kDefault,
167         kRelaxed,
168     };
169 
170     struct TempVar {
171         SpvId spvId;
172         const Type* type;
173         std::unique_ptr<SPIRVCodeGenerator::LValue> lvalue;
174     };
175 
176     void setupIntrinsics();
177 
178     /**
179      * Pass in the type to automatically add a RelaxedPrecision decoration for the id when
180      * appropriate, or null to never add one.
181      */
182     SpvId nextId(const Type* type);
183 
184     SpvId nextId(Precision precision);
185 
186     const Type& getActualType(const Type& type);
187 
188     SpvId getType(const Type& type);
189 
190     SpvId getType(const Type& type, const MemoryLayout& layout);
191 
192     SpvId getImageType(const Type& type);
193 
194     SpvId getFunctionType(const FunctionDeclaration& function);
195 
196     SpvId getPointerType(const Type& type, SpvStorageClass_ storageClass);
197 
198     SpvId getPointerType(const Type& type, const MemoryLayout& layout,
199                          SpvStorageClass_ storageClass);
200 
201     std::vector<SpvId> getAccessChain(const Expression& expr, OutputStream& out);
202 
203     void writeLayout(const Layout& layout, SpvId target, int line);
204 
205     void writeFieldLayout(const Layout& layout, SpvId target, int member);
206 
207     void writeStruct(const Type& type, const MemoryLayout& layout, SpvId resultId);
208 
209     void writeProgramElement(const ProgramElement& pe, OutputStream& out);
210 
211     SpvId writeInterfaceBlock(const InterfaceBlock& intf, bool appendRTFlip = true);
212 
213     SpvId writeFunctionStart(const FunctionDeclaration& f, OutputStream& out);
214 
215     SpvId writeFunctionDeclaration(const FunctionDeclaration& f, OutputStream& out);
216 
217     SpvId writeFunction(const FunctionDefinition& f, OutputStream& out);
218 
219     void writeGlobalVar(ProgramKind kind, const VarDeclaration& v);
220 
221     void writeVarDeclaration(const VarDeclaration& var, OutputStream& out);
222 
223     SpvId writeVariableReference(const VariableReference& ref, OutputStream& out);
224 
225     int findUniformFieldIndex(const Variable& var) const;
226 
227     std::unique_ptr<LValue> getLValue(const Expression& value, OutputStream& out);
228 
229     SpvId writeExpression(const Expression& expr, OutputStream& out);
230 
231     SpvId writeIntrinsicCall(const FunctionCall& c, OutputStream& out);
232 
233     SpvId writeFunctionCallArgument(const FunctionCall& call,
234                                     int argIndex,
235                                     std::vector<TempVar>* tempVars,
236                                     OutputStream& out);
237 
238     void copyBackTempVars(const std::vector<TempVar>& tempVars, OutputStream& out);
239 
240     SpvId writeFunctionCall(const FunctionCall& c, OutputStream& out);
241 
242 
243     void writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
244                                       SpvId signedInst, SpvId unsignedInst,
245                                       const std::vector<SpvId>& args, OutputStream& out);
246 
247     /**
248      * Promotes an expression to a vector. If the expression is already a vector with vectorSize
249      * columns, returns it unmodified. If the expression is a scalar, either promotes it to a
250      * vector (if vectorSize > 1) or returns it unmodified (if vectorSize == 1). Asserts if the
251      * expression is already a vector and it does not have vectorSize columns.
252      */
253     SpvId vectorize(const Expression& expr, int vectorSize, OutputStream& out);
254 
255     /**
256      * Given a list of potentially mixed scalars and vectors, promotes the scalars to match the
257      * size of the vectors and returns the ids of the written expressions. e.g. given (float, vec2),
258      * returns (vec2(float), vec2). It is an error to use mismatched vector sizes, e.g. (float,
259      * vec2, vec3).
260      */
261     std::vector<SpvId> vectorize(const ExpressionArray& args, OutputStream& out);
262 
263     SpvId writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind, OutputStream& out);
264 
265     SpvId writeConstantVector(const AnyConstructor& c);
266 
267     SpvId writeScalarToMatrixSplat(const Type& matrixType, SpvId scalarId, OutputStream& out);
268 
269     SpvId writeFloatConstructor(const AnyConstructor& c, OutputStream& out);
270 
271     SpvId castScalarToFloat(SpvId inputId, const Type& inputType, const Type& outputType,
272                             OutputStream& out);
273 
274     SpvId writeIntConstructor(const AnyConstructor& c, OutputStream& out);
275 
276     SpvId castScalarToSignedInt(SpvId inputId, const Type& inputType, const Type& outputType,
277                                 OutputStream& out);
278 
279     SpvId writeUIntConstructor(const AnyConstructor& c, OutputStream& out);
280 
281     SpvId castScalarToUnsignedInt(SpvId inputId, const Type& inputType, const Type& outputType,
282                                   OutputStream& out);
283 
284     SpvId writeBooleanConstructor(const AnyConstructor& c, OutputStream& out);
285 
286     SpvId castScalarToBoolean(SpvId inputId, const Type& inputType, const Type& outputType,
287                               OutputStream& out);
288 
289     SpvId castScalarToType(SpvId inputExprId, const Type& inputType, const Type& outputType,
290                            OutputStream& out);
291 
292     /**
293      * Writes a matrix with the diagonal entries all equal to the provided expression, and all other
294      * entries equal to zero.
295      */
296     void writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type, OutputStream& out);
297 
298     /**
299      * Writes a potentially-different-sized copy of a matrix. Entries which do not exist in the
300      * source matrix are filled with zero; entries which do not exist in the destination matrix are
301      * ignored.
302      */
303     SpvId writeMatrixCopy(SpvId src, const Type& srcType, const Type& dstType, OutputStream& out);
304 
305     void addColumnEntry(const Type& columnType, std::vector<SpvId>* currentColumn,
306                         std::vector<SpvId>* columnIds, int rows, SpvId entry, OutputStream& out);
307 
308     SpvId writeConstructorCompound(const ConstructorCompound& c, OutputStream& out);
309 
310     SpvId writeMatrixConstructor(const ConstructorCompound& c, OutputStream& out);
311 
312     SpvId writeVectorConstructor(const ConstructorCompound& c, OutputStream& out);
313 
314     SpvId writeCompositeConstructor(const AnyConstructor& c, OutputStream& out);
315 
316     SpvId writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c, OutputStream& out);
317 
318     SpvId writeConstructorMatrixResize(const ConstructorMatrixResize& c, OutputStream& out);
319 
320     SpvId writeConstructorScalarCast(const ConstructorScalarCast& c, OutputStream& out);
321 
322     SpvId writeConstructorSplat(const ConstructorSplat& c, OutputStream& out);
323 
324     SpvId writeConstructorCompoundCast(const ConstructorCompoundCast& c, OutputStream& out);
325 
326     SpvId writeComposite(const std::vector<SpvId>& arguments, const Type& type, OutputStream& out);
327 
328     SpvId writeFieldAccess(const FieldAccess& f, OutputStream& out);
329 
330     SpvId writeSwizzle(const Swizzle& swizzle, OutputStream& out);
331 
332     /**
333      * Folds the potentially-vector result of a logical operation down to a single bool. If
334      * operandType is a vector type, assumes that the intermediate result in id is a bvec of the
335      * same dimensions, and applys all() to it to fold it down to a single bool value. Otherwise,
336      * returns the original id value.
337      */
338     SpvId foldToBool(SpvId id, const Type& operandType, SpvOp op, OutputStream& out);
339 
340     SpvId writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs, SpvOp_ floatOperator,
341                                 SpvOp_ intOperator, SpvOp_ vectorMergeOperator,
342                                 SpvOp_ mergeOperator, OutputStream& out);
343 
344     SpvId writeStructComparison(const Type& structType, SpvId lhs, Operator op, SpvId rhs,
345                                 OutputStream& out);
346 
347     SpvId writeArrayComparison(const Type& structType, SpvId lhs, Operator op, SpvId rhs,
348                                OutputStream& out);
349 
350     // Used by writeStructComparison and writeArrayComparison to logically combine field-by-field
351     // comparisons into an overall comparison result.
352     // - `a.x == b.x` merged with `a.y == b.y` generates `(a.x == b.x) && (a.y == b.y)`
353     // - `a.x != b.x` merged with `a.y != b.y` generates `(a.x != b.x) || (a.y != b.y)`
354     SpvId mergeComparisons(SpvId comparison, SpvId allComparisons, Operator op, OutputStream& out);
355 
356     SpvId writeComponentwiseMatrixUnary(const Type& operandType,
357                                         SpvId operand,
358                                         SpvOp_ op,
359                                         OutputStream& out);
360 
361     SpvId writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs, SpvId rhs,
362                                          SpvOp_ op, OutputStream& out);
363 
364     SpvId writeBinaryOperation(const Type& resultType, const Type& operandType, SpvId lhs,
365                                SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt, SpvOp_ ifUInt,
366                                SpvOp_ ifBool, OutputStream& out);
367 
368     SpvId writeBinaryOperation(const BinaryExpression& expr, SpvOp_ ifFloat, SpvOp_ ifInt,
369                                SpvOp_ ifUInt, OutputStream& out);
370 
371     SpvId writeReciprocal(const Type& type, SpvId value, OutputStream& out);
372 
373     SpvId writeBinaryExpression(const Type& leftType, SpvId lhs, Operator op,
374                                 const Type& rightType, SpvId rhs, const Type& resultType,
375                                 OutputStream& out);
376 
377     SpvId writeBinaryExpression(const BinaryExpression& b, OutputStream& out);
378 
379     SpvId writeTernaryExpression(const TernaryExpression& t, OutputStream& out);
380 
381     SpvId writeIndexExpression(const IndexExpression& expr, OutputStream& out);
382 
383     SpvId writeLogicalAnd(const Expression& left, const Expression& right, OutputStream& out);
384 
385     SpvId writeLogicalOr(const Expression& left, const Expression& right, OutputStream& out);
386 
387     SpvId writePrefixExpression(const PrefixExpression& p, OutputStream& out);
388 
389     SpvId writePostfixExpression(const PostfixExpression& p, OutputStream& out);
390 
391     SpvId writeLiteral(const Literal& f);
392 
393     SpvId writeLiteral(double value, const Type& type);
394 
395     void writeStatement(const Statement& s, OutputStream& out);
396 
397     void writeBlock(const Block& b, OutputStream& out);
398 
399     void writeIfStatement(const IfStatement& stmt, OutputStream& out);
400 
401     void writeForStatement(const ForStatement& f, OutputStream& out);
402 
403     void writeDoStatement(const DoStatement& d, OutputStream& out);
404 
405     void writeSwitchStatement(const SwitchStatement& s, OutputStream& out);
406 
407     void writeReturnStatement(const ReturnStatement& r, OutputStream& out);
408 
409     void writeCapabilities(OutputStream& out);
410 
411     void writeInstructions(const Program& program, OutputStream& out);
412 
413     void writeOpCode(SpvOp_ opCode, int length, OutputStream& out);
414 
415     void writeWord(int32_t word, OutputStream& out);
416 
417     void writeString(std::string_view s, OutputStream& out);
418 
419     void writeLabel(SpvId id, OutputStream& out);
420 
421     void writeInstruction(SpvOp_ opCode, OutputStream& out);
422 
423     void writeInstruction(SpvOp_ opCode, std::string_view string, OutputStream& out);
424 
425     void writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out);
426 
427     void writeInstruction(SpvOp_ opCode, int32_t word1, std::string_view string,
428                           OutputStream& out);
429 
430     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, std::string_view string,
431                           OutputStream& out);
432 
433     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, OutputStream& out);
434 
435     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3,
436                           OutputStream& out);
437 
438     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
439                           OutputStream& out);
440 
441     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
442                           int32_t word5, OutputStream& out);
443 
444     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
445                           int32_t word5, int32_t word6, OutputStream& out);
446 
447     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
448                           int32_t word5, int32_t word6, int32_t word7, OutputStream& out);
449 
450     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
451                           int32_t word5, int32_t word6, int32_t word7, int32_t word8,
452                           OutputStream& out);
453 
454     bool isDead(const Variable& var) const;
455 
456     MemoryLayout memoryLayoutForVariable(const Variable&) const;
457 
458     struct EntrypointAdapter {
459         std::unique_ptr<FunctionDefinition> entrypointDef;
460         std::unique_ptr<FunctionDeclaration> entrypointDecl;
461         Layout fLayout;
462         Modifiers fModifiers;
463     };
464 
465     EntrypointAdapter writeEntrypointAdapter(const FunctionDeclaration& main);
466 
467     struct UniformBuffer {
468         std::unique_ptr<InterfaceBlock> fInterfaceBlock;
469         std::unique_ptr<Variable> fInnerVariable;
470         std::unique_ptr<Type> fStruct;
471     };
472 
473     void writeUniformBuffer(std::shared_ptr<SymbolTable> topLevelSymbolTable);
474 
475     void addRTFlipUniform(int line);
476 
477     const MemoryLayout fDefaultLayout;
478 
479     uint64_t fCapabilities;
480     SpvId fIdCount;
481     SpvId fGLSLExtendedInstructions;
482     struct Intrinsic {
483         IntrinsicOpcodeKind opKind;
484         int32_t floatOp;
485         int32_t signedOp;
486         int32_t unsignedOp;
487         int32_t boolOp;
488     };
489     std::unordered_map<IntrinsicKind, Intrinsic> fIntrinsicMap;
490     std::unordered_map<const FunctionDeclaration*, SpvId> fFunctionMap;
491     std::unordered_map<const Variable*, SpvId> fVariableMap;
492     std::unordered_map<const Variable*, int32_t> fInterfaceBlockMap;
493     std::unordered_map<std::string, SpvId> fImageTypeMap;
494     std::unordered_map<std::string, SpvId> fTypeMap;
495     StringStream fCapabilitiesBuffer;
496     StringStream fGlobalInitializersBuffer;
497     StringStream fConstantBuffer;
498     StringStream fExtraGlobalsBuffer;
499     StringStream fVariableBuffer;
500     StringStream fNameBuffer;
501     StringStream fDecorationBuffer;
502 
503     std::unordered_map<SPIRVNumberConstant, SpvId> fNumberConstants;
504     std::unordered_map<SPIRVVectorConstant, SpvId> fVectorConstants;
505     bool fSetupFragPosition;
506     // label of the current block, or 0 if we are not in a block
507     SpvId fCurrentBlock;
508     std::stack<SpvId> fBreakTarget;
509     std::stack<SpvId> fContinueTarget;
510     bool fWroteRTFlip = false;
511     // holds variables synthesized during output, for lifetime purposes
512     SymbolTable fSynthetics;
513     int fSkInCount = 1;
514     // Holds a list of uniforms that were declared as globals at the top-level instead of in an
515     // interface block.
516     UniformBuffer fUniformBuffer;
517     std::vector<const VarDeclaration*> fTopLevelUniforms;
518     std::unordered_map<const Variable*, int> fTopLevelUniformMap; //<var, UniformBuffer field index>
519     std::unordered_set<const Variable*> fSPIRVBonusVariables;
520     SpvId fUniformBufferId = -1;
521 
522     friend class PointerLValue;
523     friend class SwizzleLValue;
524 
525     using INHERITED = CodeGenerator;
526 };
527 
528 }  // namespace SkSL
529 
530 #endif
531