1 /* 2 * Copyright 2016 Google Inc. 3 * 4 * Use of this source code is governed by a BSD-style license that can be 5 * found in the LICENSE file. 6 */ 7 8 #ifndef SKSL_SPIRVCODEGENERATOR 9 #define SKSL_SPIRVCODEGENERATOR 10 11 #include <stack> 12 #include <tuple> 13 #include <unordered_map> 14 15 #include "include/private/SkSLModifiers.h" 16 #include "include/private/SkSLProgramElement.h" 17 #include "include/private/SkSLStatement.h" 18 #include "src/core/SkOpts.h" 19 #include "src/sksl/SkSLMemoryLayout.h" 20 #include "src/sksl/SkSLStringStream.h" 21 #include "src/sksl/codegen/SkSLCodeGenerator.h" 22 #include "src/sksl/ir/SkSLBinaryExpression.h" 23 #include "src/sksl/ir/SkSLBoolLiteral.h" 24 #include "src/sksl/ir/SkSLConstructor.h" 25 #include "src/sksl/ir/SkSLConstructorArray.h" 26 #include "src/sksl/ir/SkSLConstructorCompound.h" 27 #include "src/sksl/ir/SkSLConstructorCompoundCast.h" 28 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h" 29 #include "src/sksl/ir/SkSLConstructorMatrixResize.h" 30 #include "src/sksl/ir/SkSLConstructorScalarCast.h" 31 #include "src/sksl/ir/SkSLConstructorSplat.h" 32 #include "src/sksl/ir/SkSLConstructorStruct.h" 33 #include "src/sksl/ir/SkSLDoStatement.h" 34 #include "src/sksl/ir/SkSLFieldAccess.h" 35 #include "src/sksl/ir/SkSLFloatLiteral.h" 36 #include "src/sksl/ir/SkSLForStatement.h" 37 #include "src/sksl/ir/SkSLFunctionCall.h" 38 #include "src/sksl/ir/SkSLFunctionDeclaration.h" 39 #include "src/sksl/ir/SkSLFunctionDefinition.h" 40 #include "src/sksl/ir/SkSLIfStatement.h" 41 #include "src/sksl/ir/SkSLIndexExpression.h" 42 #include "src/sksl/ir/SkSLIntLiteral.h" 43 #include "src/sksl/ir/SkSLInterfaceBlock.h" 44 #include "src/sksl/ir/SkSLPostfixExpression.h" 45 #include "src/sksl/ir/SkSLPrefixExpression.h" 46 #include "src/sksl/ir/SkSLReturnStatement.h" 47 #include "src/sksl/ir/SkSLSwitchStatement.h" 48 #include "src/sksl/ir/SkSLSwizzle.h" 49 #include "src/sksl/ir/SkSLTernaryExpression.h" 50 #include "src/sksl/ir/SkSLVarDeclarations.h" 51 #include "src/sksl/ir/SkSLVariableReference.h" 52 #include "src/sksl/spirv.h" 53 54 namespace SkSL { 55 56 struct SPIRVNumberConstant { 57 bool operator==(const SPIRVNumberConstant& that) const { 58 return fValueBits == that.fValueBits && 59 fKind == that.fKind; 60 } 61 int64_t fValueBits; // contains either an SKSL_INT or zero-padded bits from an SKSL_FLOAT 62 SkSL::Type::NumberKind fKind; 63 }; 64 65 struct SPIRVVectorConstant { 66 bool operator==(const SPIRVVectorConstant& that) const { 67 return fTypeId == that.fTypeId && 68 fValueId[0] == that.fValueId[0] && 69 fValueId[1] == that.fValueId[1] && 70 fValueId[2] == that.fValueId[2] && 71 fValueId[3] == that.fValueId[3]; 72 } 73 SpvId fTypeId; 74 SpvId fValueId[4]; 75 }; 76 77 } // namespace SkSL 78 79 namespace std { 80 81 template <> 82 struct hash<SkSL::SPIRVNumberConstant> { 83 size_t operator()(const SkSL::SPIRVNumberConstant& key) const { 84 return key.fValueBits ^ (int)key.fKind; 85 } 86 }; 87 88 template <> 89 struct hash<SkSL::SPIRVVectorConstant> { 90 size_t operator()(const SkSL::SPIRVVectorConstant& key) const { 91 return SkOpts::hash(&key, sizeof(key)); 92 } 93 }; 94 95 } // namespace std 96 97 namespace SkSL { 98 99 /** 100 * Converts a Program into a SPIR-V binary. 101 */ 102 class SPIRVCodeGenerator : public CodeGenerator { 103 public: 104 class LValue { 105 public: 106 virtual ~LValue() {} 107 108 // returns a pointer to the lvalue, if possible. If the lvalue cannot be directly referenced 109 // by a pointer (e.g. vector swizzles), returns -1. 110 virtual SpvId getPointer() { return -1; } 111 112 // Returns true if a valid pointer returned by getPointer represents a memory object 113 // (see https://github.com/KhronosGroup/SPIRV-Tools/issues/2892). Has no meaning if 114 // getPointer() returns -1. 115 virtual bool isMemoryObjectPointer() const { return true; } 116 117 // Applies a swizzle to the components of the LValue, if possible. This is used to create 118 // LValues that are swizzes-of-swizzles. Non-swizzle LValues can just return false. 119 virtual bool applySwizzle(const ComponentArray& components, const Type& newType) { 120 return false; 121 } 122 123 virtual SpvId load(OutputStream& out) = 0; 124 125 virtual void store(SpvId value, OutputStream& out) = 0; 126 }; 127 128 SPIRVCodeGenerator(const Context* context, 129 const Program* program, 130 ErrorReporter* errors, 131 OutputStream* out) 132 : INHERITED(program, errors, out) 133 , fContext(*context) 134 , fDefaultLayout(MemoryLayout::k140_Standard) 135 , fCapabilities(0) 136 , fIdCount(1) 137 , fBoolTrue(0) 138 , fBoolFalse(0) 139 , fSetupFragPosition(false) 140 , fCurrentBlock(0) 141 , fSynthetics(errors, /*builtin=*/true) { 142 this->setupIntrinsics(); 143 } 144 145 bool generateCode() override; 146 147 private: 148 enum IntrinsicOpcodeKind { 149 kGLSL_STD_450_IntrinsicOpcodeKind, 150 kSPIRV_IntrinsicOpcodeKind, 151 kSpecial_IntrinsicOpcodeKind 152 }; 153 154 enum SpecialIntrinsic { 155 kAtan_SpecialIntrinsic, 156 kClamp_SpecialIntrinsic, 157 kMatrixCompMult_SpecialIntrinsic, 158 kMax_SpecialIntrinsic, 159 kMin_SpecialIntrinsic, 160 kMix_SpecialIntrinsic, 161 kMod_SpecialIntrinsic, 162 kDFdy_SpecialIntrinsic, 163 kSaturate_SpecialIntrinsic, 164 kSampledImage_SpecialIntrinsic, 165 kSmoothStep_SpecialIntrinsic, 166 kStep_SpecialIntrinsic, 167 kSubpassLoad_SpecialIntrinsic, 168 kTexture_SpecialIntrinsic, 169 }; 170 171 enum class Precision { 172 kDefault, 173 kRelaxed, 174 }; 175 176 void setupIntrinsics(); 177 178 /** 179 * Pass in the type to automatically add a RelaxedPrecision decoration for the id when 180 * appropriate, or null to never add one. 181 */ 182 SpvId nextId(const Type* type); 183 184 SpvId nextId(Precision precision); 185 186 const Type& getActualType(const Type& type); 187 188 SpvId getType(const Type& type); 189 190 SpvId getType(const Type& type, const MemoryLayout& layout); 191 192 SpvId getImageType(const Type& type); 193 194 SpvId getFunctionType(const FunctionDeclaration& function); 195 196 SpvId getPointerType(const Type& type, SpvStorageClass_ storageClass); 197 198 SpvId getPointerType(const Type& type, const MemoryLayout& layout, 199 SpvStorageClass_ storageClass); 200 201 std::vector<SpvId> getAccessChain(const Expression& expr, OutputStream& out); 202 203 void writeLayout(const Layout& layout, SpvId target); 204 205 void writeLayout(const Layout& layout, SpvId target, int member); 206 207 void writeStruct(const Type& type, const MemoryLayout& layout, SpvId resultId); 208 209 void writeProgramElement(const ProgramElement& pe, OutputStream& out); 210 211 SpvId writeInterfaceBlock(const InterfaceBlock& intf, bool appendRTHeight = true); 212 213 SpvId writeFunctionStart(const FunctionDeclaration& f, OutputStream& out); 214 215 SpvId writeFunctionDeclaration(const FunctionDeclaration& f, OutputStream& out); 216 217 SpvId writeFunction(const FunctionDefinition& f, OutputStream& out); 218 219 void writeGlobalVar(ProgramKind kind, const VarDeclaration& v); 220 221 void writeVarDeclaration(const VarDeclaration& var, OutputStream& out); 222 223 SpvId writeVariableReference(const VariableReference& ref, OutputStream& out); 224 225 int findUniformFieldIndex(const Variable& var) const; 226 227 std::unique_ptr<LValue> getLValue(const Expression& value, OutputStream& out); 228 229 SpvId writeExpression(const Expression& expr, OutputStream& out); 230 231 SpvId writeIntrinsicCall(const FunctionCall& c, OutputStream& out); 232 233 SpvId writeFunctionCall(const FunctionCall& c, OutputStream& out); 234 235 236 void writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst, 237 SpvId signedInst, SpvId unsignedInst, 238 const std::vector<SpvId>& args, OutputStream& out); 239 240 /** 241 * Given a list of potentially mixed scalars and vectors, promotes the scalars to match the 242 * size of the vectors and returns the ids of the written expressions. e.g. given (float, vec2), 243 * returns (vec2(float), vec2). It is an error to use mismatched vector sizes, e.g. (float, 244 * vec2, vec3). 245 */ 246 std::vector<SpvId> vectorize(const ExpressionArray& args, OutputStream& out); 247 248 SpvId writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind, OutputStream& out); 249 250 SpvId writeConstantVector(const AnyConstructor& c); 251 252 SpvId writeFloatConstructor(const AnyConstructor& c, OutputStream& out); 253 254 SpvId castScalarToFloat(SpvId inputId, const Type& inputType, const Type& outputType, 255 OutputStream& out); 256 257 SpvId writeIntConstructor(const AnyConstructor& c, OutputStream& out); 258 259 SpvId castScalarToSignedInt(SpvId inputId, const Type& inputType, const Type& outputType, 260 OutputStream& out); 261 262 SpvId writeUIntConstructor(const AnyConstructor& c, OutputStream& out); 263 264 SpvId castScalarToUnsignedInt(SpvId inputId, const Type& inputType, const Type& outputType, 265 OutputStream& out); 266 267 SpvId writeBooleanConstructor(const AnyConstructor& c, OutputStream& out); 268 269 SpvId castScalarToBoolean(SpvId inputId, const Type& inputType, const Type& outputType, 270 OutputStream& out); 271 272 SpvId castScalarToType(SpvId inputExprId, const Type& inputType, const Type& outputType, 273 OutputStream& out); 274 275 /** 276 * Writes a matrix with the diagonal entries all equal to the provided expression, and all other 277 * entries equal to zero. 278 */ 279 void writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type, OutputStream& out); 280 281 /** 282 * Writes a potentially-different-sized copy of a matrix. Entries which do not exist in the 283 * source matrix are filled with zero; entries which do not exist in the destination matrix are 284 * ignored. 285 */ 286 SpvId writeMatrixCopy(SpvId src, const Type& srcType, const Type& dstType, OutputStream& out); 287 288 void addColumnEntry(SpvId columnType, Precision precision, std::vector<SpvId>* currentColumn, 289 std::vector<SpvId>* columnIds, int* currentCount, int rows, SpvId entry, 290 OutputStream& out); 291 292 SpvId writeConstructorCompound(const ConstructorCompound& c, OutputStream& out); 293 294 SpvId writeMatrixConstructor(const ConstructorCompound& c, OutputStream& out); 295 296 SpvId writeVectorConstructor(const ConstructorCompound& c, OutputStream& out); 297 298 SpvId writeCompositeConstructor(const AnyConstructor& c, OutputStream& out); 299 300 SpvId writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c, OutputStream& out); 301 302 SpvId writeConstructorMatrixResize(const ConstructorMatrixResize& c, OutputStream& out); 303 304 SpvId writeConstructorScalarCast(const ConstructorScalarCast& c, OutputStream& out); 305 306 SpvId writeConstructorSplat(const ConstructorSplat& c, OutputStream& out); 307 308 SpvId writeConstructorCompoundCast(const ConstructorCompoundCast& c, OutputStream& out); 309 310 SpvId writeComposite(const std::vector<SpvId>& arguments, const Type& type, OutputStream& out); 311 312 SpvId writeFieldAccess(const FieldAccess& f, OutputStream& out); 313 314 SpvId writeSwizzle(const Swizzle& swizzle, OutputStream& out); 315 316 /** 317 * Folds the potentially-vector result of a logical operation down to a single bool. If 318 * operandType is a vector type, assumes that the intermediate result in id is a bvec of the 319 * same dimensions, and applys all() to it to fold it down to a single bool value. Otherwise, 320 * returns the original id value. 321 */ 322 SpvId foldToBool(SpvId id, const Type& operandType, SpvOp op, OutputStream& out); 323 324 SpvId writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs, SpvOp_ floatOperator, 325 SpvOp_ intOperator, SpvOp_ vectorMergeOperator, 326 SpvOp_ mergeOperator, OutputStream& out); 327 328 SpvId writeStructComparison(const Type& structType, SpvId lhs, Operator op, SpvId rhs, 329 OutputStream& out); 330 331 SpvId writeArrayComparison(const Type& structType, SpvId lhs, Operator op, SpvId rhs, 332 OutputStream& out); 333 334 // Used by writeStructComparison and writeArrayComparison to logically combine field-by-field 335 // comparisons into an overall comparison result. 336 // - `a.x == b.x` merged with `a.y == b.y` generates `(a.x == b.x) && (a.y == b.y)` 337 // - `a.x != b.x` merged with `a.y != b.y` generates `(a.x != b.x) || (a.y != b.y)` 338 SpvId mergeComparisons(SpvId comparison, SpvId allComparisons, Operator op, OutputStream& out); 339 340 SpvId writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs, SpvId rhs, 341 SpvOp_ floatOperator, SpvOp_ intOperator, 342 OutputStream& out); 343 344 SpvId writeBinaryOperation(const Type& resultType, const Type& operandType, SpvId lhs, 345 SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt, SpvOp_ ifUInt, 346 SpvOp_ ifBool, OutputStream& out); 347 348 SpvId writeBinaryOperation(const BinaryExpression& expr, SpvOp_ ifFloat, SpvOp_ ifInt, 349 SpvOp_ ifUInt, OutputStream& out); 350 351 SpvId writeReciprocal(const Type& type, SpvId value, OutputStream& out); 352 353 SpvId writeBinaryExpression(const Type& leftType, SpvId lhs, Operator op, 354 const Type& rightType, SpvId rhs, const Type& resultType, 355 OutputStream& out); 356 357 SpvId writeBinaryExpression(const BinaryExpression& b, OutputStream& out); 358 359 SpvId writeTernaryExpression(const TernaryExpression& t, OutputStream& out); 360 361 SpvId writeIndexExpression(const IndexExpression& expr, OutputStream& out); 362 363 SpvId writeLogicalAnd(const Expression& left, const Expression& right, OutputStream& out); 364 365 SpvId writeLogicalOr(const Expression& left, const Expression& right, OutputStream& out); 366 367 SpvId writePrefixExpression(const PrefixExpression& p, OutputStream& out); 368 369 SpvId writePostfixExpression(const PostfixExpression& p, OutputStream& out); 370 371 SpvId writeBoolLiteral(const BoolLiteral& b); 372 373 SpvId writeIntLiteral(const IntLiteral& i); 374 375 SpvId writeFloatLiteral(const FloatLiteral& f); 376 377 void writeStatement(const Statement& s, OutputStream& out); 378 379 void writeBlock(const Block& b, OutputStream& out); 380 381 void writeIfStatement(const IfStatement& stmt, OutputStream& out); 382 383 void writeForStatement(const ForStatement& f, OutputStream& out); 384 385 void writeDoStatement(const DoStatement& d, OutputStream& out); 386 387 void writeSwitchStatement(const SwitchStatement& s, OutputStream& out); 388 389 void writeReturnStatement(const ReturnStatement& r, OutputStream& out); 390 391 void writeCapabilities(OutputStream& out); 392 393 void writeInstructions(const Program& program, OutputStream& out); 394 395 void writeOpCode(SpvOp_ opCode, int length, OutputStream& out); 396 397 void writeWord(int32_t word, OutputStream& out); 398 399 void writeString(const char* string, size_t length, OutputStream& out); 400 401 void writeLabel(SpvId id, OutputStream& out); 402 403 void writeInstruction(SpvOp_ opCode, OutputStream& out); 404 405 void writeInstruction(SpvOp_ opCode, StringFragment string, OutputStream& out); 406 407 void writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out); 408 409 void writeInstruction(SpvOp_ opCode, int32_t word1, StringFragment string, OutputStream& out); 410 411 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, StringFragment string, 412 OutputStream& out); 413 414 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, OutputStream& out); 415 416 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, 417 OutputStream& out); 418 419 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4, 420 OutputStream& out); 421 422 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4, 423 int32_t word5, OutputStream& out); 424 425 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4, 426 int32_t word5, int32_t word6, OutputStream& out); 427 428 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4, 429 int32_t word5, int32_t word6, int32_t word7, OutputStream& out); 430 431 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4, 432 int32_t word5, int32_t word6, int32_t word7, int32_t word8, 433 OutputStream& out); 434 435 void writeGeometryShaderExecutionMode(SpvId entryPoint, OutputStream& out); 436 437 MemoryLayout memoryLayoutForVariable(const Variable&) const; 438 439 struct EntrypointAdapter { 440 std::unique_ptr<FunctionDefinition> entrypointDef; 441 std::unique_ptr<FunctionDeclaration> entrypointDecl; 442 Layout fLayout; 443 Modifiers fModifiers; 444 }; 445 446 EntrypointAdapter writeEntrypointAdapter(const FunctionDeclaration& main); 447 448 struct UniformBuffer { 449 std::unique_ptr<InterfaceBlock> fInterfaceBlock; 450 std::unique_ptr<Variable> fInnerVariable; 451 std::unique_ptr<Type> fStruct; 452 }; 453 454 void writeUniformBuffer(std::shared_ptr<SymbolTable> topLevelSymbolTable); 455 456 const Context& fContext; 457 const MemoryLayout fDefaultLayout; 458 459 uint64_t fCapabilities; 460 SpvId fIdCount; 461 SpvId fGLSLExtendedInstructions; 462 typedef std::tuple<IntrinsicOpcodeKind, int32_t, int32_t, int32_t, int32_t> Intrinsic; 463 std::unordered_map<IntrinsicKind, Intrinsic> fIntrinsicMap; 464 std::unordered_map<const FunctionDeclaration*, SpvId> fFunctionMap; 465 std::unordered_map<const Variable*, SpvId> fVariableMap; 466 std::unordered_map<const Variable*, int32_t> fInterfaceBlockMap; 467 std::unordered_map<String, SpvId> fImageTypeMap; 468 std::unordered_map<String, SpvId> fTypeMap; 469 StringStream fCapabilitiesBuffer; 470 StringStream fGlobalInitializersBuffer; 471 StringStream fConstantBuffer; 472 StringStream fExtraGlobalsBuffer; 473 StringStream fExternalFunctionsBuffer; 474 StringStream fVariableBuffer; 475 StringStream fNameBuffer; 476 StringStream fDecorationBuffer; 477 478 SpvId fBoolTrue; 479 SpvId fBoolFalse; 480 std::unordered_map<SPIRVNumberConstant, SpvId> fNumberConstants; 481 std::unordered_map<SPIRVVectorConstant, SpvId> fVectorConstants; 482 bool fSetupFragPosition; 483 // label of the current block, or 0 if we are not in a block 484 SpvId fCurrentBlock; 485 std::stack<SpvId> fBreakTarget; 486 std::stack<SpvId> fContinueTarget; 487 SpvId fRTHeightStructId = (SpvId) -1; 488 SpvId fRTHeightFieldIndex = (SpvId) -1; 489 SpvStorageClass_ fRTHeightStorageClass; 490 // holds variables synthesized during output, for lifetime purposes 491 SymbolTable fSynthetics; 492 int fSkInCount = 1; 493 // Holds a list of uniforms that were declared as globals at the top-level instead of in an 494 // interface block. 495 UniformBuffer fUniformBuffer; 496 std::vector<const VarDeclaration*> fTopLevelUniforms; 497 std::unordered_map<const Variable*, int> fTopLevelUniformMap; //<var, UniformBuffer field index> 498 SpvId fUniformBufferId = -1; 499 500 friend class PointerLValue; 501 friend class SwizzleLValue; 502 503 using INHERITED = CodeGenerator; 504 }; 505 506 } // namespace SkSL 507 508 #endif 509