• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #ifndef SKSL_SPIRVCODEGENERATOR
9 #define SKSL_SPIRVCODEGENERATOR
10 
11 #include <stack>
12 #include <tuple>
13 #include <unordered_map>
14 #include <unordered_set>
15 
16 #include "src/core/SkOpts.h"
17 #include "src/sksl/SkSLMemoryLayout.h"
18 #include "src/sksl/SkSLStringStream.h"
19 #include "src/sksl/codegen/SkSLCodeGenerator.h"
20 
21 namespace SkSL {
22 
23 class BinaryExpression;
24 class Block;
25 class ConstructorCompound;
26 class ConstructorCompoundCast;
27 class ConstructorDiagonalMatrix;
28 class ConstructorMatrixResize;
29 class ConstructorScalarCast;
30 class ConstructorSplat;
31 class DoStatement;
32 class FieldAccess;
33 class ForStatement;
34 class FunctionCall;
35 class FunctionDeclaration;
36 class FunctionDefinition;
37 class FunctionPrototype;
38 class IfStatement;
39 struct IndexExpression;
40 class InterfaceBlock;
41 enum IntrinsicKind : int8_t;
42 class Literal;
43 class Operator;
44 class PostfixExpression;
45 class PrefixExpression;
46 class ReturnStatement;
47 class Setting;
48 class StructDefinition;
49 class SwitchStatement;
50 struct Swizzle;
51 class TernaryExpression;
52 class VarDeclaration;
53 class VariableReference;
54 
55 struct SPIRVNumberConstant {
56     bool operator==(const SPIRVNumberConstant& that) const {
57         return fValueBits == that.fValueBits &&
58                fKind      == that.fKind;
59     }
60     int32_t                fValueBits;
61     SkSL::Type::NumberKind fKind;
62 };
63 
64 struct SPIRVVectorConstant {
65     bool operator==(const SPIRVVectorConstant& that) const {
66         return fTypeId     == that.fTypeId &&
67                fValueId[0] == that.fValueId[0] &&
68                fValueId[1] == that.fValueId[1] &&
69                fValueId[2] == that.fValueId[2] &&
70                fValueId[3] == that.fValueId[3];
71     }
72     SpvId fTypeId;
73     SpvId fValueId[4];
74 };
75 
76 }  // namespace SkSL
77 
78 namespace std {
79 
80 template <>
81 struct hash<SkSL::SPIRVNumberConstant> {
82     size_t operator()(const SkSL::SPIRVNumberConstant& key) const {
83         return key.fValueBits ^ (int)key.fKind;
84     }
85 };
86 
87 template <>
88 struct hash<SkSL::SPIRVVectorConstant> {
89     size_t operator()(const SkSL::SPIRVVectorConstant& key) const {
90         return SkOpts::hash(&key, sizeof(key));
91     }
92 };
93 
94 }  // namespace std
95 
96 namespace SkSL {
97 
98 /**
99  * Converts a Program into a SPIR-V binary.
100  */
101 class SPIRVCodeGenerator : public CodeGenerator {
102 public:
103     class LValue {
104     public:
105         virtual ~LValue() {}
106 
107         // returns a pointer to the lvalue, if possible. If the lvalue cannot be directly referenced
108         // by a pointer (e.g. vector swizzles), returns -1.
109         virtual SpvId getPointer() { return -1; }
110 
111         // Returns true if a valid pointer returned by getPointer represents a memory object
112         // (see https://github.com/KhronosGroup/SPIRV-Tools/issues/2892). Has no meaning if
113         // getPointer() returns -1.
114         virtual bool isMemoryObjectPointer() const { return true; }
115 
116         // Applies a swizzle to the components of the LValue, if possible. This is used to create
117         // LValues that are swizzes-of-swizzles. Non-swizzle LValues can just return false.
118         virtual bool applySwizzle(const ComponentArray& components, const Type& newType) {
119             return false;
120         }
121 
122         virtual SpvId load(OutputStream& out) = 0;
123 
124         virtual void store(SpvId value, OutputStream& out) = 0;
125     };
126 
127     SPIRVCodeGenerator(const Context* context,
128                        const Program* program,
129                        OutputStream* out)
130             : INHERITED(context, program, out)
131             , fDefaultLayout(MemoryLayout::k140_Standard)
132             , fCapabilities(0)
133             , fIdCount(1)
134             , fSetupFragPosition(false)
135             , fCurrentBlock(0)
136             , fSynthetics(fContext, /*builtin=*/true) {
137         this->setupIntrinsics();
138     }
139 
140     bool generateCode() override;
141 
142 private:
143     enum IntrinsicOpcodeKind {
144         kGLSL_STD_450_IntrinsicOpcodeKind,
145         kSPIRV_IntrinsicOpcodeKind,
146         kSpecial_IntrinsicOpcodeKind
147     };
148 
149     enum SpecialIntrinsic {
150         kAtan_SpecialIntrinsic,
151         kClamp_SpecialIntrinsic,
152         kMatrixCompMult_SpecialIntrinsic,
153         kMax_SpecialIntrinsic,
154         kMin_SpecialIntrinsic,
155         kMix_SpecialIntrinsic,
156         kMod_SpecialIntrinsic,
157         kDFdy_SpecialIntrinsic,
158         kSaturate_SpecialIntrinsic,
159         kSampledImage_SpecialIntrinsic,
160         kSmoothStep_SpecialIntrinsic,
161         kStep_SpecialIntrinsic,
162         kSubpassLoad_SpecialIntrinsic,
163         kTexture_SpecialIntrinsic,
164     };
165 
166     enum class Precision {
167         kDefault,
168         kRelaxed,
169     };
170 
171     struct TempVar {
172         SpvId spvId;
173         const Type* type;
174         std::unique_ptr<SPIRVCodeGenerator::LValue> lvalue;
175     };
176 
177     void setupIntrinsics();
178 
179     /**
180      * Pass in the type to automatically add a RelaxedPrecision decoration for the id when
181      * appropriate, or null to never add one.
182      */
183     SpvId nextId(const Type* type);
184 
185     SpvId nextId(Precision precision);
186 
187     const Type& getActualType(const Type& type);
188 
189     SpvId getType(const Type& type);
190 
191     SpvId getType(const Type& type, const MemoryLayout& layout);
192 
193     SpvId getImageType(const Type& type);
194 
195     SpvId getFunctionType(const FunctionDeclaration& function);
196 
197     SpvId getPointerType(const Type& type, SpvStorageClass_ storageClass);
198 
199     SpvId getPointerType(const Type& type, const MemoryLayout& layout,
200                          SpvStorageClass_ storageClass);
201 
202     std::vector<SpvId> getAccessChain(const Expression& expr, OutputStream& out);
203 
204     void writeLayout(const Layout& layout, SpvId target);
205 
206     void writeLayout(const Layout& layout, SpvId target, int member);
207 
208     void writeStruct(const Type& type, const MemoryLayout& layout, SpvId resultId);
209 
210     void writeProgramElement(const ProgramElement& pe, OutputStream& out);
211 
212     SpvId writeInterfaceBlock(const InterfaceBlock& intf, bool appendRTFlip = true);
213 
214     SpvId writeFunctionStart(const FunctionDeclaration& f, OutputStream& out);
215 
216     SpvId writeFunctionDeclaration(const FunctionDeclaration& f, OutputStream& out);
217 
218     SpvId writeFunction(const FunctionDefinition& f, OutputStream& out);
219 
220     void writeGlobalVar(ProgramKind kind, const VarDeclaration& v);
221 
222     void writeVarDeclaration(const VarDeclaration& var, OutputStream& out);
223 
224     SpvId writeVariableReference(const VariableReference& ref, OutputStream& out);
225 
226     int findUniformFieldIndex(const Variable& var) const;
227 
228     std::unique_ptr<LValue> getLValue(const Expression& value, OutputStream& out);
229 
230     SpvId writeExpression(const Expression& expr, OutputStream& out);
231 
232     SpvId writeIntrinsicCall(const FunctionCall& c, OutputStream& out);
233 
234     SpvId writeFunctionCallArgument(const Expression& arg,
235                                     const Modifiers& paramModifiers,
236                                     std::vector<TempVar>* tempVars,
237                                     OutputStream& out);
238 
239     void copyBackTempVars(const std::vector<TempVar>& tempVars, OutputStream& out);
240 
241     SpvId writeFunctionCall(const FunctionCall& c, OutputStream& out);
242 
243 
244     void writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
245                                       SpvId signedInst, SpvId unsignedInst,
246                                       const std::vector<SpvId>& args, OutputStream& out);
247 
248     /**
249      * Promotes an expression to a vector. If the expression is already a vector with vectorSize
250      * columns, returns it unmodified. If the expression is a scalar, either promotes it to a
251      * vector (if vectorSize > 1) or returns it unmodified (if vectorSize == 1). Asserts if the
252      * expression is already a vector and it does not have vectorSize columns.
253      */
254     SpvId vectorize(const Expression& expr, int vectorSize, OutputStream& out);
255 
256     /**
257      * Given a list of potentially mixed scalars and vectors, promotes the scalars to match the
258      * size of the vectors and returns the ids of the written expressions. e.g. given (float, vec2),
259      * returns (vec2(float), vec2). It is an error to use mismatched vector sizes, e.g. (float,
260      * vec2, vec3).
261      */
262     std::vector<SpvId> vectorize(const ExpressionArray& args, OutputStream& out);
263 
264     SpvId writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind, OutputStream& out);
265 
266     SpvId writeConstantVector(const AnyConstructor& c);
267 
268     SpvId writeScalarToMatrixSplat(const Type& matrixType, SpvId scalarId, OutputStream& out);
269 
270     SpvId writeFloatConstructor(const AnyConstructor& c, OutputStream& out);
271 
272     SpvId castScalarToFloat(SpvId inputId, const Type& inputType, const Type& outputType,
273                             OutputStream& out);
274 
275     SpvId writeIntConstructor(const AnyConstructor& c, OutputStream& out);
276 
277     SpvId castScalarToSignedInt(SpvId inputId, const Type& inputType, const Type& outputType,
278                                 OutputStream& out);
279 
280     SpvId writeUIntConstructor(const AnyConstructor& c, OutputStream& out);
281 
282     SpvId castScalarToUnsignedInt(SpvId inputId, const Type& inputType, const Type& outputType,
283                                   OutputStream& out);
284 
285     SpvId writeBooleanConstructor(const AnyConstructor& c, OutputStream& out);
286 
287     SpvId castScalarToBoolean(SpvId inputId, const Type& inputType, const Type& outputType,
288                               OutputStream& out);
289 
290     SpvId castScalarToType(SpvId inputExprId, const Type& inputType, const Type& outputType,
291                            OutputStream& out);
292 
293     /**
294      * Writes a matrix with the diagonal entries all equal to the provided expression, and all other
295      * entries equal to zero.
296      */
297     void writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type, OutputStream& out);
298 
299     /**
300      * Writes a potentially-different-sized copy of a matrix. Entries which do not exist in the
301      * source matrix are filled with zero; entries which do not exist in the destination matrix are
302      * ignored.
303      */
304     SpvId writeMatrixCopy(SpvId src, const Type& srcType, const Type& dstType, OutputStream& out);
305 
306     void addColumnEntry(const Type& columnType, std::vector<SpvId>* currentColumn,
307                         std::vector<SpvId>* columnIds, int rows, SpvId entry, OutputStream& out);
308 
309     SpvId writeConstructorCompound(const ConstructorCompound& c, OutputStream& out);
310 
311     SpvId writeMatrixConstructor(const ConstructorCompound& c, OutputStream& out);
312 
313     SpvId writeVectorConstructor(const ConstructorCompound& c, OutputStream& out);
314 
315     SpvId writeCompositeConstructor(const AnyConstructor& c, OutputStream& out);
316 
317     SpvId writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c, OutputStream& out);
318 
319     SpvId writeConstructorMatrixResize(const ConstructorMatrixResize& c, OutputStream& out);
320 
321     SpvId writeConstructorScalarCast(const ConstructorScalarCast& c, OutputStream& out);
322 
323     SpvId writeConstructorSplat(const ConstructorSplat& c, OutputStream& out);
324 
325     SpvId writeConstructorCompoundCast(const ConstructorCompoundCast& c, OutputStream& out);
326 
327     SpvId writeComposite(const std::vector<SpvId>& arguments, const Type& type, OutputStream& out);
328 
329     SpvId writeFieldAccess(const FieldAccess& f, OutputStream& out);
330 
331     SpvId writeSwizzle(const Swizzle& swizzle, OutputStream& out);
332 
333     /**
334      * Folds the potentially-vector result of a logical operation down to a single bool. If
335      * operandType is a vector type, assumes that the intermediate result in id is a bvec of the
336      * same dimensions, and applys all() to it to fold it down to a single bool value. Otherwise,
337      * returns the original id value.
338      */
339     SpvId foldToBool(SpvId id, const Type& operandType, SpvOp op, OutputStream& out);
340 
341     SpvId writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs, SpvOp_ floatOperator,
342                                 SpvOp_ intOperator, SpvOp_ vectorMergeOperator,
343                                 SpvOp_ mergeOperator, OutputStream& out);
344 
345     SpvId writeStructComparison(const Type& structType, SpvId lhs, Operator op, SpvId rhs,
346                                 OutputStream& out);
347 
348     SpvId writeArrayComparison(const Type& structType, SpvId lhs, Operator op, SpvId rhs,
349                                OutputStream& out);
350 
351     // Used by writeStructComparison and writeArrayComparison to logically combine field-by-field
352     // comparisons into an overall comparison result.
353     // - `a.x == b.x` merged with `a.y == b.y` generates `(a.x == b.x) && (a.y == b.y)`
354     // - `a.x != b.x` merged with `a.y != b.y` generates `(a.x != b.x) || (a.y != b.y)`
355     SpvId mergeComparisons(SpvId comparison, SpvId allComparisons, Operator op, OutputStream& out);
356 
357     SpvId writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs, SpvId rhs,
358                                          SpvOp_ op, OutputStream& out);
359 
360     SpvId writeBinaryOperation(const Type& resultType, const Type& operandType, SpvId lhs,
361                                SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt, SpvOp_ ifUInt,
362                                SpvOp_ ifBool, OutputStream& out);
363 
364     SpvId writeBinaryOperation(const BinaryExpression& expr, SpvOp_ ifFloat, SpvOp_ ifInt,
365                                SpvOp_ ifUInt, OutputStream& out);
366 
367     SpvId writeReciprocal(const Type& type, SpvId value, OutputStream& out);
368 
369     SpvId writeBinaryExpression(const Type& leftType, SpvId lhs, Operator op,
370                                 const Type& rightType, SpvId rhs, const Type& resultType,
371                                 OutputStream& out);
372 
373     SpvId writeBinaryExpression(const BinaryExpression& b, OutputStream& out);
374 
375     SpvId writeTernaryExpression(const TernaryExpression& t, OutputStream& out);
376 
377     SpvId writeIndexExpression(const IndexExpression& expr, OutputStream& out);
378 
379     SpvId writeLogicalAnd(const Expression& left, const Expression& right, OutputStream& out);
380 
381     SpvId writeLogicalOr(const Expression& left, const Expression& right, OutputStream& out);
382 
383     SpvId writePrefixExpression(const PrefixExpression& p, OutputStream& out);
384 
385     SpvId writePostfixExpression(const PostfixExpression& p, OutputStream& out);
386 
387     SpvId writeLiteral(const Literal& f);
388 
389     SpvId writeLiteral(double value, const Type& type);
390 
391     void writeStatement(const Statement& s, OutputStream& out);
392 
393     void writeBlock(const Block& b, OutputStream& out);
394 
395     void writeIfStatement(const IfStatement& stmt, OutputStream& out);
396 
397     void writeForStatement(const ForStatement& f, OutputStream& out);
398 
399     void writeDoStatement(const DoStatement& d, OutputStream& out);
400 
401     void writeSwitchStatement(const SwitchStatement& s, OutputStream& out);
402 
403     void writeReturnStatement(const ReturnStatement& r, OutputStream& out);
404 
405     void writeCapabilities(OutputStream& out);
406 
407     void writeInstructions(const Program& program, OutputStream& out);
408 
409     void writeOpCode(SpvOp_ opCode, int length, OutputStream& out);
410 
411     void writeWord(int32_t word, OutputStream& out);
412 
413     void writeString(skstd::string_view s, OutputStream& out);
414 
415     void writeLabel(SpvId id, OutputStream& out);
416 
417     void writeInstruction(SpvOp_ opCode, OutputStream& out);
418 
419     void writeInstruction(SpvOp_ opCode, skstd::string_view string, OutputStream& out);
420 
421     void writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out);
422 
423     void writeInstruction(SpvOp_ opCode, int32_t word1, skstd::string_view string,
424                           OutputStream& out);
425 
426     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, skstd::string_view string,
427                           OutputStream& out);
428 
429     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, OutputStream& out);
430 
431     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3,
432                           OutputStream& out);
433 
434     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
435                           OutputStream& out);
436 
437     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
438                           int32_t word5, OutputStream& out);
439 
440     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
441                           int32_t word5, int32_t word6, OutputStream& out);
442 
443     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
444                           int32_t word5, int32_t word6, int32_t word7, OutputStream& out);
445 
446     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
447                           int32_t word5, int32_t word6, int32_t word7, int32_t word8,
448                           OutputStream& out);
449 
450     bool isDead(const Variable& var) const;
451 
452     MemoryLayout memoryLayoutForVariable(const Variable&) const;
453 
454     struct EntrypointAdapter {
455         std::unique_ptr<FunctionDefinition> entrypointDef;
456         std::unique_ptr<FunctionDeclaration> entrypointDecl;
457         Layout fLayout;
458         Modifiers fModifiers;
459     };
460 
461     EntrypointAdapter writeEntrypointAdapter(const FunctionDeclaration& main);
462 
463     struct UniformBuffer {
464         std::unique_ptr<InterfaceBlock> fInterfaceBlock;
465         std::unique_ptr<Variable> fInnerVariable;
466         std::unique_ptr<Type> fStruct;
467     };
468 
469     void writeUniformBuffer(std::shared_ptr<SymbolTable> topLevelSymbolTable);
470 
471     void addRTFlipUniform(int line);
472 
473     const MemoryLayout fDefaultLayout;
474 
475     uint64_t fCapabilities;
476     SpvId fIdCount;
477     SpvId fGLSLExtendedInstructions;
478     typedef std::tuple<IntrinsicOpcodeKind, int32_t, int32_t, int32_t, int32_t> Intrinsic;
479     std::unordered_map<IntrinsicKind, Intrinsic> fIntrinsicMap;
480     std::unordered_map<const FunctionDeclaration*, SpvId> fFunctionMap;
481     std::unordered_map<const Variable*, SpvId> fVariableMap;
482     std::unordered_map<const Variable*, int32_t> fInterfaceBlockMap;
483     std::unordered_map<String, SpvId> fImageTypeMap;
484     std::unordered_map<String, SpvId> fTypeMap;
485     StringStream fCapabilitiesBuffer;
486     StringStream fGlobalInitializersBuffer;
487     StringStream fConstantBuffer;
488     StringStream fExtraGlobalsBuffer;
489     StringStream fVariableBuffer;
490     StringStream fNameBuffer;
491     StringStream fDecorationBuffer;
492 
493     std::unordered_map<SPIRVNumberConstant, SpvId> fNumberConstants;
494     std::unordered_map<SPIRVVectorConstant, SpvId> fVectorConstants;
495     bool fSetupFragPosition;
496     // label of the current block, or 0 if we are not in a block
497     SpvId fCurrentBlock;
498     std::stack<SpvId> fBreakTarget;
499     std::stack<SpvId> fContinueTarget;
500     bool fWroteRTFlip = false;
501     // holds variables synthesized during output, for lifetime purposes
502     SymbolTable fSynthetics;
503     int fSkInCount = 1;
504     // Holds a list of uniforms that were declared as globals at the top-level instead of in an
505     // interface block.
506     UniformBuffer fUniformBuffer;
507     std::vector<const VarDeclaration*> fTopLevelUniforms;
508     std::unordered_map<const Variable*, int> fTopLevelUniformMap; //<var, UniformBuffer field index>
509     std::unordered_set<const Variable*> fSPIRVBonusVariables;
510     SpvId fUniformBufferId = -1;
511 
512     friend class PointerLValue;
513     friend class SwizzleLValue;
514 
515     using INHERITED = CodeGenerator;
516 };
517 
518 }  // namespace SkSL
519 
520 #endif
521