//===- Deserializer.cpp - MLIR SPIR-V Deserialization ---------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines the SPIR-V binary to MLIR SPIR-V module deserialization. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/Serialization.h" #include "mlir/Dialect/SPIRV/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h" #include "mlir/Dialect/SPIRV/SPIRVModule.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Location.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/bit.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; #define DEBUG_TYPE "spirv-deserialization" /// Decodes a string literal in `words` starting at `wordIndex`. Update the /// latter to point to the position in words after the string literal. static inline StringRef decodeStringLiteral(ArrayRef words, unsigned &wordIndex) { StringRef str(reinterpret_cast(words.data() + wordIndex)); wordIndex += str.size() / 4 + 1; return str; } /// Extracts the opcode from the given first word of a SPIR-V instruction. static inline spirv::Opcode extractOpcode(uint32_t word) { return static_cast(word & 0xffff); } /// Returns true if the given `block` is a function entry block. 
static inline bool isFnEntryBlock(Block *block) { return block->isEntryBlock() && isa_and_nonnull(block->getParentOp()); } namespace { /// A struct for containing a header block's merge and continue targets. /// /// This struct is used to track original structured control flow info from /// SPIR-V blob. This info will be used to create spv.selection/spv.loop /// later. struct BlockMergeInfo { Block *mergeBlock; Block *continueBlock; // nullptr for spv.selection Location loc; uint32_t control; BlockMergeInfo(Location location, uint32_t control) : mergeBlock(nullptr), continueBlock(nullptr), loc(location), control(control) {} BlockMergeInfo(Location location, uint32_t control, Block *m, Block *c = nullptr) : mergeBlock(m), continueBlock(c), loc(location), control(control) {} }; /// A struct for containing OpLine instruction information. struct DebugLine { uint32_t fileID; uint32_t line; uint32_t col; DebugLine(uint32_t fileIDNum, uint32_t lineNum, uint32_t colNum) : fileID(fileIDNum), line(lineNum), col(colNum) {} }; /// Map from a selection/loop's header block to its merge (and continue) target. using BlockMergeInfoMap = DenseMap; /// A "deferred struct type" is a struct type with one or more member types not /// known when the Deserializer first encounters the struct. This happens, for /// example, with recursive structs where a pointer to the struct type is /// forward declared through OpTypeForwardPointer in the SPIR-V module before /// the struct declaration; the actual pointer to struct type should be defined /// later through an OpTypePointer. For example, the following C struct: /// /// struct A { /// A* next; /// }; /// /// would be represented in the SPIR-V module as: /// /// OpName %A "A" /// OpTypeForwardPointer %APtr Generic /// %A = OpTypeStruct %APtr /// %APtr = OpTypePointer Generic %A /// /// This means that the spirv::StructType cannot be fully constructed directly /// when the Deserializer encounters it. 
Instead we create a /// DeferredStructTypeInfo that contains all the information we know about the /// spirv::StructType. Once all forward references for the struct are resolved, /// the struct's body is set with all member info. struct DeferredStructTypeInfo { spirv::StructType deferredStructType; // A list of all unresolved member types for the struct. First element of each // item is operand ID, second element is member index in the struct. SmallVector, 0> unresolvedMemberTypes; // The list of member types. For unresolved members, this list contains // place-holder empty types that will be updated later. SmallVector memberTypes; SmallVector offsetInfo; SmallVector memberDecorationsInfo; }; /// A SPIR-V module serializer. /// /// A SPIR-V binary module is a single linear stream of instructions; each /// instruction is composed of 32-bit words. The first word of an instruction /// records the total number of words of that instruction using the 16 /// higher-order bits. So this deserializer uses that to get instruction /// boundary and parse instructions and build a SPIR-V ModuleOp gradually. /// // TODO: clean up created ops on errors class Deserializer { public: /// Creates a deserializer for the given SPIR-V `binary` module. /// The SPIR-V ModuleOp will be created into `context. explicit Deserializer(ArrayRef binary, MLIRContext *context); /// Deserializes the remembered SPIR-V binary module. LogicalResult deserialize(); /// Collects the final SPIR-V ModuleOp. spirv::OwningSPIRVModuleRef collect(); private: //===--------------------------------------------------------------------===// // Module structure //===--------------------------------------------------------------------===// /// Initializes the `module` ModuleOp in this deserializer instance. spirv::OwningSPIRVModuleRef createModuleOp(); /// Processes SPIR-V module header in `binary`. 
LogicalResult processHeader(); /// Processes the SPIR-V OpCapability with `operands` and updates bookkeeping /// in the deserializer. LogicalResult processCapability(ArrayRef operands); /// Processes the SPIR-V OpExtension with `operands` and updates bookkeeping /// in the deserializer. LogicalResult processExtension(ArrayRef words); /// Processes the SPIR-V OpExtInstImport with `operands` and updates /// bookkeeping in the deserializer. LogicalResult processExtInstImport(ArrayRef words); /// Attaches (version, capabilities, extensions) triple to `module` as an /// attribute. void attachVCETriple(); /// Processes the SPIR-V OpMemoryModel with `operands` and updates `module`. LogicalResult processMemoryModel(ArrayRef operands); /// Process SPIR-V OpName with `operands`. LogicalResult processName(ArrayRef operands); /// Processes an OpDecorate instruction. LogicalResult processDecoration(ArrayRef words); // Processes an OpMemberDecorate instruction. LogicalResult processMemberDecoration(ArrayRef words); /// Processes an OpMemberName instruction. LogicalResult processMemberName(ArrayRef words); /// Gets the function op associated with a result of OpFunction. spirv::FuncOp getFunction(uint32_t id) { return funcMap.lookup(id); } /// Processes the SPIR-V function at the current `offset` into `binary`. /// The operands to the OpFunction instruction is passed in as ``operands`. /// This method processes each instruction inside the function and dispatches /// them to their handler method accordingly. LogicalResult processFunction(ArrayRef operands); /// Processes OpFunctionEnd and finalizes function. This wires up block /// argument created from OpPhi instructions and also structurizes control /// flow. LogicalResult processFunctionEnd(ArrayRef operands); /// Gets the constant's attribute and type associated with the given . Optional> getConstant(uint32_t id); /// Gets the constant's integer attribute with the given . 
Returns a null /// IntegerAttr if the given is not registered or does not correspond to an /// integer constant. IntegerAttr getConstantInt(uint32_t id); /// Returns a symbol to be used for the function name with the given /// result . This tries to use the function's OpName if /// exists; otherwise creates one based on the . std::string getFunctionSymbol(uint32_t id); /// Returns a symbol to be used for the specialization constant with the given /// result . This tries to use the specialization constant's OpName if /// exists; otherwise creates one based on the . std::string getSpecConstantSymbol(uint32_t id); /// Gets the specialization constant with the given result . spirv::SpecConstantOp getSpecConstant(uint32_t id) { return specConstMap.lookup(id); } /// Gets the composite specialization constant with the given result . spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id) { return specConstCompositeMap.lookup(id); } /// Creates a spirv::SpecConstantOp. spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID, Attribute defaultValue); /// Processes the OpVariable instructions at current `offset` into `binary`. /// It is expected that this method is used for variables that are to be /// defined at module scope and will be deserialized into a spv.globalVariable /// instruction. LogicalResult processGlobalVariable(ArrayRef operands); /// Gets the global variable associated with a result of OpVariable. spirv::GlobalVariableOp getGlobalVariable(uint32_t id) { return globalVariableMap.lookup(id); } //===--------------------------------------------------------------------===// // Type //===--------------------------------------------------------------------===// /// Gets type for a given result . Type getType(uint32_t id) { return typeMap.lookup(id); } /// Get the type associated with the result of an OpUndef. Type getUndefType(uint32_t id) { return undefMap.lookup(id); } /// Returns true if the given `type` is for SPIR-V void type. 
bool isVoidType(Type type) const { return type.isa(); } /// Processes a SPIR-V type instruction with given `opcode` and `operands` and /// registers the type into `module`. LogicalResult processType(spirv::Opcode opcode, ArrayRef operands); LogicalResult processOpTypePointer(ArrayRef operands); LogicalResult processArrayType(ArrayRef operands); LogicalResult processCooperativeMatrixType(ArrayRef operands); LogicalResult processFunctionType(ArrayRef operands); LogicalResult processRuntimeArrayType(ArrayRef operands); LogicalResult processStructType(ArrayRef operands); LogicalResult processMatrixType(ArrayRef operands); //===--------------------------------------------------------------------===// // Constant //===--------------------------------------------------------------------===// /// Processes a SPIR-V Op{|Spec}Constant instruction with the given /// `operands`. `isSpec` indicates whether this is a specialization constant. LogicalResult processConstant(ArrayRef operands, bool isSpec); /// Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the /// given `operands`. `isSpec` indicates whether this is a specialization /// constant. LogicalResult processConstantBool(bool isTrue, ArrayRef operands, bool isSpec); /// Processes a SPIR-V OpConstantComposite instruction with the given /// `operands`. LogicalResult processConstantComposite(ArrayRef operands); LogicalResult processSpecConstantComposite(ArrayRef operands); /// Processes a SPIR-V OpConstantNull instruction with the given `operands`. LogicalResult processConstantNull(ArrayRef operands); //===--------------------------------------------------------------------===// // Debug //===--------------------------------------------------------------------===// /// Discontinues any source-level location information that might be active /// from a previous OpLine instruction. LogicalResult clearDebugLine(); /// Creates a FileLineColLoc with the OpLine location information. 
Location createFileLineColLoc(OpBuilder opBuilder); /// Processes a SPIR-V OpLine instruction with the given `operands`. LogicalResult processDebugLine(ArrayRef operands); /// Processes a SPIR-V OpString instruction with the given `operands`. LogicalResult processDebugString(ArrayRef operands); //===--------------------------------------------------------------------===// // Control flow //===--------------------------------------------------------------------===// /// Returns the block for the given label . Block *getBlock(uint32_t id) const { return blockMap.lookup(id); } // In SPIR-V, structured control flow is explicitly declared using merge // instructions (OpSelectionMerge and OpLoopMerge). In the SPIR-V dialect, // we use spv.selection and spv.loop to group structured control flow. // The deserializer need to turn structured control flow marked with merge // instructions into using spv.selection/spv.loop ops. // // Because structured control flow can nest and the basic block order have // flexibility, we cannot isolate a structured selection/loop without // deserializing all the blocks. So we use the following approach: // // 1. Deserialize all basic blocks in a function and create MLIR blocks for // them into the function's region. In the meanwhile, keep a map between // selection/loop header blocks to their corresponding merge (and continue) // target blocks. // 2. For each selection/loop header block, recursively get all basic blocks // reachable (except the merge block) and put them in a newly created // spv.selection/spv.loop's region. Structured control flow guarantees // that we enter and exit in structured ways and the construct is nestable. // 3. Put the new spv.selection/spv.loop op at the beginning of the old merge // block and redirect all branches to the old header block to the old // merge block (which contains the spv.selection/spv.loop op now). /// For OpPhi instructions, we use block arguments to represent them. 
OpPhi /// encodes a list of (value, predecessor) pairs. At the time of handling the /// block containing an OpPhi instruction, the predecessor block might not be /// processed yet, also the value sent by it. So we need to defer handling /// the block argument from the predecessors. We use the following approach: /// /// 1. For each OpPhi instruction, add a block argument to the current block /// in construction. Record the block argument in `valueMap` so its uses /// can be resolved. For the list of (value, predecessor) pairs, update /// `blockPhiInfo` for bookkeeping. /// 2. After processing all blocks, loop over `blockPhiInfo` to fix up each /// block recorded there to create the proper block arguments on their /// terminators. /// A data structure for containing a SPIR-V block's phi info. It will be /// represented as block argument in SPIR-V dialect. using BlockPhiInfo = SmallVector; // The result of the values sent /// Gets or creates the block corresponding to the given label . The newly /// created block will always be placed at the end of the current function. Block *getOrCreateBlock(uint32_t id); LogicalResult processBranch(ArrayRef operands); LogicalResult processBranchConditional(ArrayRef operands); /// Processes a SPIR-V OpLabel instruction with the given `operands`. LogicalResult processLabel(ArrayRef operands); /// Processes a SPIR-V OpSelectionMerge instruction with the given `operands`. LogicalResult processSelectionMerge(ArrayRef operands); /// Processes a SPIR-V OpLoopMerge instruction with the given `operands`. LogicalResult processLoopMerge(ArrayRef operands); /// Processes a SPIR-V OpPhi instruction with the given `operands`. LogicalResult processPhi(ArrayRef operands); /// Creates block arguments on predecessors previously recorded when handling /// OpPhi instructions. LogicalResult wireUpBlockArgument(); /// Extracts blocks belonging to a structured selection/loop into a /// spv.selection/spv.loop op. 
This method iterates until all blocks /// declared as selection/loop headers are handled. LogicalResult structurizeControlFlow(); //===--------------------------------------------------------------------===// // Instruction //===--------------------------------------------------------------------===// /// Get the Value associated with a result . /// /// This method materializes normal constants and inserts "casting" ops /// (`spv.mlir.addressof` and `spv.mlir.referenceof`) to turn an symbol into a /// SSA value for handling uses of module scope constants/variables in /// functions. Value getValue(uint32_t id); /// Slices the first instruction out of `binary` and returns its opcode and /// operands via `opcode` and `operands` respectively. Returns failure if /// there is no more remaining instructions (`expectedOpcode` will be used to /// compose the error message) or the next instruction is malformed. LogicalResult sliceInstruction(spirv::Opcode &opcode, ArrayRef &operands, Optional expectedOpcode = llvm::None); /// Processes a SPIR-V instruction with the given `opcode` and `operands`. /// This method is the main entrance for handling SPIR-V instruction; it /// checks the instruction opcode and dispatches to the corresponding handler. /// Processing of Some instructions (like OpEntryPoint and OpExecutionMode) /// might need to be deferred, since they contain forward references to s /// in the deserialized binary, but module in SPIR-V dialect expects these to /// be ssa-uses. LogicalResult processInstruction(spirv::Opcode opcode, ArrayRef operands, bool deferInstructions = true); /// Processes a OpUndef instruction. Adds a spv.Undef operation at the current /// insertion point. LogicalResult processUndef(ArrayRef operands); LogicalResult processTypeForwardPointer(ArrayRef operands); /// Method to dispatch to the specialized deserialization function for an /// operation in SPIR-V dialect that is a mirror of an instruction in the /// SPIR-V spec. 
This is auto-generated from ODS. Dispatch is handled for /// all operations in SPIR-V dialect that have hasOpcode == 1. LogicalResult dispatchToAutogenDeserialization(spirv::Opcode opcode, ArrayRef words); /// Processes a SPIR-V OpExtInst with given `operands`. This slices the /// entries of `operands` that specify the extended instruction set and /// the instruction opcode. The op deserializer is then invoked using the /// other entries. LogicalResult processExtInst(ArrayRef operands); /// Dispatches the deserialization of extended instruction set operation based /// on the extended instruction set name, and instruction opcode. This is /// autogenerated from ODS. LogicalResult dispatchToExtensionSetAutogenDeserialization(StringRef extensionSetName, uint32_t instructionID, ArrayRef words); /// Method to deserialize an operation in the SPIR-V dialect that is a mirror /// of an instruction in the SPIR-V spec. This is auto generated if hasOpcode /// == 1 and autogenSerialization == 1 in ODS. template LogicalResult processOp(ArrayRef words) { return emitError(unknownLoc, "unsupported deserialization for ") << OpTy::getOperationName() << " op"; } private: /// The SPIR-V binary module. ArrayRef binary; /// Contains the data of the OpLine instruction which precedes the current /// processing instruction. llvm::Optional debugLine; /// The current word offset into the binary module. unsigned curOffset = 0; /// MLIRContext to create SPIR-V ModuleOp into. MLIRContext *context; // TODO: create Location subclass for binary blob Location unknownLoc; /// The SPIR-V ModuleOp. spirv::OwningSPIRVModuleRef module; /// The current function under construction. Optional curFunction; /// The current block under construction. Block *curBlock = nullptr; OpBuilder opBuilder; spirv::Version version; /// The list of capabilities used by the module. llvm::SmallSetVector capabilities; /// The list of extensions used by the module. llvm::SmallSetVector extensions; // Result to type mapping. 
DenseMap typeMap; // Result to constant attribute and type mapping. /// /// In the SPIR-V binary format, all constants are placed in the module and /// shared by instructions at module level and in subsequent functions. But in /// the SPIR-V dialect, we materialize the constant to where it's used in the /// function. So when seeing a constant instruction in the binary format, we /// don't immediately emit a constant op into the module, we keep its value /// (and type) here. Later when it's used, we materialize the constant. DenseMap> constantMap; // Result to spec constant mapping. DenseMap specConstMap; // Result to composite spec constant mapping. DenseMap specConstCompositeMap; // Result to variable mapping. DenseMap globalVariableMap; // Result to function mapping. DenseMap funcMap; // Result to block mapping. DenseMap blockMap; // Header block to its merge (and continue) target mapping. BlockMergeInfoMap blockMergeInfo; // Block to its phi (block argument) mapping. DenseMap blockPhiInfo; // Result to value mapping. DenseMap valueMap; // Mapping from result to undef value of a type. DenseMap undefMap; // Result to name mapping. DenseMap nameMap; // Result to debug info mapping. DenseMap debugInfoMap; // Result to decorations mapping. DenseMap decorations; // Result to type decorations. DenseMap typeDecorations; // Result to member decorations. // decorated-struct-type- -> // (struct-member-index -> (decoration -> decoration-operands)) DenseMap>>> memberDecorationMap; // Result to member name. // struct-type- -> (struct-member-index -> name) DenseMap> memberNameMap; // Result to extended instruction set name. DenseMap extendedInstSets; // List of instructions that are processed in a deferred fashion (after an // initial processing of the entire binary). Some operations like // OpEntryPoint, and OpExecutionMode use forward references to function // s. 
In SPIR-V dialect the corresponding operations (spv.EntryPoint and // spv.ExecutionMode) need these references resolved. So these instructions // are deserialized and stored for processing once the entire binary is // processed. SmallVector>, 4> deferredInstructions; /// A list of IDs for all types forward-declared through OpTypeForwardPointer /// instructions. llvm::SetVector typeForwardPointerIDs; /// A list of all structs which have unresolved member types. SmallVector deferredStructTypesInfos; }; } // namespace Deserializer::Deserializer(ArrayRef binary, MLIRContext *context) : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)), module(createModuleOp()), opBuilder(module->body()) {} LogicalResult Deserializer::deserialize() { LLVM_DEBUG(llvm::dbgs() << "+++ starting deserialization +++\n"); if (failed(processHeader())) return failure(); spirv::Opcode opcode = spirv::Opcode::OpNop; ArrayRef operands; auto binarySize = binary.size(); while (curOffset < binarySize) { // Slice the next instruction out and populate `opcode` and `operands`. // Internally this also updates `curOffset`. 
if (failed(sliceInstruction(opcode, operands))) return failure(); if (failed(processInstruction(opcode, operands))) return failure(); } assert(curOffset == binarySize && "deserializer should never index beyond the binary end"); for (auto &deferred : deferredInstructions) { if (failed(processInstruction(deferred.first, deferred.second, false))) { return failure(); } } attachVCETriple(); LLVM_DEBUG(llvm::dbgs() << "+++ completed deserialization +++\n"); return success(); } spirv::OwningSPIRVModuleRef Deserializer::collect() { return std::move(module); } //===----------------------------------------------------------------------===// // Module structure //===----------------------------------------------------------------------===// spirv::OwningSPIRVModuleRef Deserializer::createModuleOp() { OpBuilder builder(context); OperationState state(unknownLoc, spirv::ModuleOp::getOperationName()); spirv::ModuleOp::build(builder, state); return cast(Operation::create(state)); } LogicalResult Deserializer::processHeader() { if (binary.size() < spirv::kHeaderWordCount) return emitError(unknownLoc, "SPIR-V binary module must have a 5-word header"); if (binary[0] != spirv::kMagicNumber) return emitError(unknownLoc, "incorrect magic number"); // Version number bytes: 0 | major number | minor number | 0 uint32_t majorVersion = (binary[1] << 8) >> 24; uint32_t minorVersion = (binary[1] << 16) >> 24; if (majorVersion == 1) { switch (minorVersion) { #define MIN_VERSION_CASE(v) \ case v: \ version = spirv::Version::V_1_##v; \ break MIN_VERSION_CASE(0); MIN_VERSION_CASE(1); MIN_VERSION_CASE(2); MIN_VERSION_CASE(3); MIN_VERSION_CASE(4); MIN_VERSION_CASE(5); #undef MIN_VERSION_CASE default: return emitError(unknownLoc, "unsupported SPIR-V minor version: ") << minorVersion; } } else { return emitError(unknownLoc, "unsupported SPIR-V major version: ") << majorVersion; } // TODO: generator number, bound, schema curOffset = spirv::kHeaderWordCount; return success(); } LogicalResult 
Deserializer::processCapability(ArrayRef operands) { if (operands.size() != 1) return emitError(unknownLoc, "OpMemoryModel must have one parameter"); auto cap = spirv::symbolizeCapability(operands[0]); if (!cap) return emitError(unknownLoc, "unknown capability: ") << operands[0]; capabilities.insert(*cap); return success(); } LogicalResult Deserializer::processExtension(ArrayRef words) { if (words.empty()) { return emitError( unknownLoc, "OpExtension must have a literal string for the extension name"); } unsigned wordIndex = 0; StringRef extName = decodeStringLiteral(words, wordIndex); if (wordIndex != words.size()) return emitError(unknownLoc, "unexpected trailing words in OpExtension instruction"); auto ext = spirv::symbolizeExtension(extName); if (!ext) return emitError(unknownLoc, "unknown extension: ") << extName; extensions.insert(*ext); return success(); } LogicalResult Deserializer::processExtInstImport(ArrayRef words) { if (words.size() < 2) { return emitError(unknownLoc, "OpExtInstImport must have a result and a literal " "string for the extended instruction set name"); } unsigned wordIndex = 1; extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex); if (wordIndex != words.size()) { return emitError(unknownLoc, "unexpected trailing words in OpExtInstImport"); } return success(); } void Deserializer::attachVCETriple() { module->setAttr(spirv::ModuleOp::getVCETripleAttrName(), spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(), extensions.getArrayRef(), context)); } LogicalResult Deserializer::processMemoryModel(ArrayRef operands) { if (operands.size() != 2) return emitError(unknownLoc, "OpMemoryModel must have two operands"); module->setAttr( "addressing_model", opBuilder.getI32IntegerAttr(llvm::bit_cast(operands.front()))); module->setAttr( "memory_model", opBuilder.getI32IntegerAttr(llvm::bit_cast(operands.back()))); return success(); } LogicalResult Deserializer::processDecoration(ArrayRef words) { // TODO: This function should 
also be auto-generated. For now, since only a // few decorations are processed/handled in a meaningful manner, going with a // manual implementation. if (words.size() < 2) { return emitError( unknownLoc, "OpDecorate must have at least result and Decoration"); } auto decorationName = stringifyDecoration(static_cast(words[1])); if (decorationName.empty()) { return emitError(unknownLoc, "invalid Decoration code : ") << words[1]; } auto attrName = llvm::convertToSnakeFromCamelCase(decorationName); auto symbol = opBuilder.getIdentifier(attrName); switch (static_cast(words[1])) { case spirv::Decoration::DescriptorSet: case spirv::Decoration::Binding: if (words.size() != 3) { return emitError(unknownLoc, "OpDecorate with ") << decorationName << " needs a single integer literal"; } decorations[words[0]].set( symbol, opBuilder.getI32IntegerAttr(static_cast(words[2]))); break; case spirv::Decoration::BuiltIn: if (words.size() != 3) { return emitError(unknownLoc, "OpDecorate with ") << decorationName << " needs a single integer literal"; } decorations[words[0]].set( symbol, opBuilder.getStringAttr( stringifyBuiltIn(static_cast(words[2])))); break; case spirv::Decoration::ArrayStride: if (words.size() != 3) { return emitError(unknownLoc, "OpDecorate with ") << decorationName << " needs a single integer literal"; } typeDecorations[words[0]] = words[2]; break; case spirv::Decoration::Aliased: case spirv::Decoration::Block: case spirv::Decoration::BufferBlock: case spirv::Decoration::Flat: case spirv::Decoration::NonReadable: case spirv::Decoration::NonWritable: case spirv::Decoration::NoPerspective: case spirv::Decoration::Restrict: if (words.size() != 2) { return emitError(unknownLoc, "OpDecoration with ") << decorationName << "needs a single target "; } // Block decoration does not affect spv.struct type, but is still stored for // verification. // TODO: Update StructType to contain this information since // it is needed for many validation rules. 
decorations[words[0]].set(symbol, opBuilder.getUnitAttr()); break; case spirv::Decoration::Location: case spirv::Decoration::SpecId: if (words.size() != 3) { return emitError(unknownLoc, "OpDecoration with ") << decorationName << "needs a single integer literal"; } decorations[words[0]].set( symbol, opBuilder.getI32IntegerAttr(static_cast(words[2]))); break; default: return emitError(unknownLoc, "unhandled Decoration : '") << decorationName; } return success(); } LogicalResult Deserializer::processMemberDecoration(ArrayRef words) { // The binary layout of OpMemberDecorate is different comparing to OpDecorate if (words.size() < 3) { return emitError(unknownLoc, "OpMemberDecorate must have at least 3 operands"); } auto decoration = static_cast(words[2]); if (decoration == spirv::Decoration::Offset && words.size() != 4) { return emitError(unknownLoc, " missing offset specification in OpMemberDecorate with " "Offset decoration"); } ArrayRef decorationOperands; if (words.size() > 3) { decorationOperands = words.slice(3); } memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands; return success(); } LogicalResult Deserializer::processMemberName(ArrayRef words) { if (words.size() < 3) { return emitError(unknownLoc, "OpMemberName must have at least 3 operands"); } unsigned wordIndex = 2; auto name = decodeStringLiteral(words, wordIndex); if (wordIndex != words.size()) { return emitError(unknownLoc, "unexpected trailing words in OpMemberName instruction"); } memberNameMap[words[0]][words[1]] = name; return success(); } LogicalResult Deserializer::processFunction(ArrayRef operands) { if (curFunction) { return emitError(unknownLoc, "found function inside function"); } // Get the result type if (operands.size() != 4) { return emitError(unknownLoc, "OpFunction must have 4 parameters"); } Type resultType = getType(operands[0]); if (!resultType) { return emitError(unknownLoc, "undefined result type from ") << operands[0]; } if (funcMap.count(operands[1])) { 
return emitError(unknownLoc, "duplicate function definition/declaration"); } auto fnControl = spirv::symbolizeFunctionControl(operands[2]); if (!fnControl) { return emitError(unknownLoc, "unknown Function Control: ") << operands[2]; } Type fnType = getType(operands[3]); if (!fnType || !fnType.isa()) { return emitError(unknownLoc, "unknown function type from ") << operands[3]; } auto functionType = fnType.cast(); if ((isVoidType(resultType) && functionType.getNumResults() != 0) || (functionType.getNumResults() == 1 && functionType.getResult(0) != resultType)) { return emitError(unknownLoc, "mismatch in function type ") << functionType << " and return type " << resultType << " specified"; } std::string fnName = getFunctionSymbol(operands[1]); auto funcOp = opBuilder.create( unknownLoc, fnName, functionType, fnControl.getValue()); curFunction = funcMap[operands[1]] = funcOp; LLVM_DEBUG(llvm::dbgs() << "-- start function " << fnName << " (type = " << fnType << ", id = " << operands[1] << ") --\n"); auto *entryBlock = funcOp.addEntryBlock(); LLVM_DEBUG(llvm::dbgs() << "[block] created entry block " << entryBlock << "\n"); // Parse the op argument instructions if (functionType.getNumInputs()) { for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) { auto argType = functionType.getInput(i); spirv::Opcode opcode = spirv::Opcode::OpNop; ArrayRef operands; if (failed(sliceInstruction(opcode, operands, spirv::Opcode::OpFunctionParameter))) { return failure(); } if (opcode != spirv::Opcode::OpFunctionParameter) { return emitError( unknownLoc, "missing OpFunctionParameter instruction for argument ") << i; } if (operands.size() != 2) { return emitError( unknownLoc, "expected result type and result for OpFunctionParameter"); } auto argDefinedType = getType(operands[0]); if (!argDefinedType || argDefinedType != argType) { return emitError(unknownLoc, "mismatch in argument type between function type " "definition ") << functionType << " and argument type definition " << 
argDefinedType << " at argument " << i; } if (getValue(operands[1])) { return emitError(unknownLoc, "duplicate definition of result '") << operands[1]; } auto argValue = funcOp.getArgument(i); valueMap[operands[1]] = argValue; } } // RAII guard to reset the insertion point to the module's region after // deserializing the body of this function. OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder); spirv::Opcode opcode = spirv::Opcode::OpNop; ArrayRef instOperands; // Special handling for the entry block. We need to make sure it starts with // an OpLabel instruction. The entry block takes the same parameters as the // function. All other blocks do not take any parameter. We have already // created the entry block, here we need to register it to the correct label // . if (failed(sliceInstruction(opcode, instOperands, spirv::Opcode::OpFunctionEnd))) { return failure(); } if (opcode == spirv::Opcode::OpFunctionEnd) { LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << fnName << "' (type = " << fnType << ", id = " << operands[1] << ") --\n"); return processFunctionEnd(instOperands); } if (opcode != spirv::Opcode::OpLabel) { return emitError(unknownLoc, "a basic block must start with OpLabel"); } if (instOperands.size() != 1) { return emitError(unknownLoc, "OpLabel should only have result "); } blockMap[instOperands[0]] = entryBlock; if (failed(processLabel(instOperands))) { return failure(); } // Then process all the other instructions in the function until we hit // OpFunctionEnd. 
while (succeeded(sliceInstruction(opcode, instOperands, spirv::Opcode::OpFunctionEnd)) && opcode != spirv::Opcode::OpFunctionEnd) { if (failed(processInstruction(opcode, instOperands))) { return failure(); } } if (opcode != spirv::Opcode::OpFunctionEnd) { return failure(); } LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << fnName << "' (type = " << fnType << ", id = " << operands[1] << ") --\n"); return processFunctionEnd(instOperands); } LogicalResult Deserializer::processFunctionEnd(ArrayRef operands) { // Process OpFunctionEnd. if (!operands.empty()) { return emitError(unknownLoc, "unexpected operands for OpFunctionEnd"); } // Wire up block arguments from OpPhi instructions. // Put all structured control flow in spv.selection/spv.loop ops. if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) { return failure(); } curBlock = nullptr; curFunction = llvm::None; return success(); } Optional> Deserializer::getConstant(uint32_t id) { auto constIt = constantMap.find(id); if (constIt == constantMap.end()) return llvm::None; return constIt->getSecond(); } std::string Deserializer::getFunctionSymbol(uint32_t id) { auto funcName = nameMap.lookup(id).str(); if (funcName.empty()) { funcName = "spirv_fn_" + std::to_string(id); } return funcName; } std::string Deserializer::getSpecConstantSymbol(uint32_t id) { auto constName = nameMap.lookup(id).str(); if (constName.empty()) { constName = "spirv_spec_const_" + std::to_string(id); } return constName; } spirv::SpecConstantOp Deserializer::createSpecConstant(Location loc, uint32_t resultID, Attribute defaultValue) { auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID)); auto op = opBuilder.create(unknownLoc, symName, defaultValue); if (decorations.count(resultID)) { for (auto attr : decorations[resultID].getAttrs()) op.setAttr(attr.first, attr.second); } specConstMap[resultID] = op; return op; } LogicalResult Deserializer::processGlobalVariable(ArrayRef operands) { unsigned wordIndex = 
0; if (operands.size() < 3) { return emitError( unknownLoc, "OpVariable needs at least 3 operands, type, and storage class"); } // Result Type. auto type = getType(operands[wordIndex]); if (!type) { return emitError(unknownLoc, "unknown result type : ") << operands[wordIndex]; } auto ptrType = type.dyn_cast(); if (!ptrType) { return emitError(unknownLoc, "expected a result type to be a spv.ptr, found : ") << type; } wordIndex++; // Result . auto variableID = operands[wordIndex]; auto variableName = nameMap.lookup(variableID).str(); if (variableName.empty()) { variableName = "spirv_var_" + std::to_string(variableID); } wordIndex++; // Storage class. auto storageClass = static_cast(operands[wordIndex]); if (ptrType.getStorageClass() != storageClass) { return emitError(unknownLoc, "mismatch in storage class of pointer type ") << type << " and that specified in OpVariable instruction : " << stringifyStorageClass(storageClass); } wordIndex++; // Initializer. FlatSymbolRefAttr initializer = nullptr; if (wordIndex < operands.size()) { auto initializerOp = getGlobalVariable(operands[wordIndex]); if (!initializerOp) { return emitError(unknownLoc, "unknown ") << operands[wordIndex] << "used as initializer"; } wordIndex++; initializer = opBuilder.getSymbolRefAttr(initializerOp.getOperation()); } if (wordIndex != operands.size()) { return emitError(unknownLoc, "found more operands than expected when deserializing " "OpVariable instruction, only ") << wordIndex << " of " << operands.size() << " processed"; } auto loc = createFileLineColLoc(opBuilder); auto varOp = opBuilder.create( loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName), initializer); // Decorations. 
if (decorations.count(variableID)) { for (auto attr : decorations[variableID].getAttrs()) { varOp.setAttr(attr.first, attr.second); } } globalVariableMap[variableID] = varOp; return success(); } IntegerAttr Deserializer::getConstantInt(uint32_t id) { auto constInfo = getConstant(id); if (!constInfo) { return nullptr; } return constInfo->first.dyn_cast(); } LogicalResult Deserializer::processName(ArrayRef operands) { if (operands.size() < 2) { return emitError(unknownLoc, "OpName needs at least 2 operands"); } if (!nameMap.lookup(operands[0]).empty()) { return emitError(unknownLoc, "duplicate name found for result ") << operands[0]; } unsigned wordIndex = 1; StringRef name = decodeStringLiteral(operands, wordIndex); if (wordIndex != operands.size()) { return emitError(unknownLoc, "unexpected trailing words in OpName instruction"); } nameMap[operands[0]] = name; return success(); } //===----------------------------------------------------------------------===// // Type //===----------------------------------------------------------------------===// LogicalResult Deserializer::processType(spirv::Opcode opcode, ArrayRef operands) { if (operands.empty()) { return emitError(unknownLoc, "type instruction with opcode ") << spirv::stringifyOpcode(opcode) << " needs at least one "; } /// TODO: Types might be forward declared in some instructions and need to be /// handled appropriately. 
if (typeMap.count(operands[0])) { return emitError(unknownLoc, "duplicate definition for result ") << operands[0]; } switch (opcode) { case spirv::Opcode::OpTypeVoid: if (operands.size() != 1) return emitError(unknownLoc, "OpTypeVoid must have no parameters"); typeMap[operands[0]] = opBuilder.getNoneType(); break; case spirv::Opcode::OpTypeBool: if (operands.size() != 1) return emitError(unknownLoc, "OpTypeBool must have no parameters"); typeMap[operands[0]] = opBuilder.getI1Type(); break; case spirv::Opcode::OpTypeInt: { if (operands.size() != 3) return emitError( unknownLoc, "OpTypeInt must have bitwidth and signedness parameters"); // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics // to preserve or validate. // 0 indicates unsigned, or no signedness semantics // 1 indicates signed semantics." // // So we cannot differentiate signless and unsigned integers; always use // signless semantics for such cases. auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed : IntegerType::SignednessSemantics::Signless; typeMap[operands[0]] = IntegerType::get(operands[1], sign, context); } break; case spirv::Opcode::OpTypeFloat: { if (operands.size() != 2) return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter"); Type floatTy; switch (operands[1]) { case 16: floatTy = opBuilder.getF16Type(); break; case 32: floatTy = opBuilder.getF32Type(); break; case 64: floatTy = opBuilder.getF64Type(); break; default: return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ") << operands[1]; } typeMap[operands[0]] = floatTy; } break; case spirv::Opcode::OpTypeVector: { if (operands.size() != 3) { return emitError( unknownLoc, "OpTypeVector must have element type and count parameters"); } Type elementTy = getType(operands[1]); if (!elementTy) { return emitError(unknownLoc, "OpTypeVector references undefined ") << operands[1]; } typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy); } break; case 
spirv::Opcode::OpTypePointer: { return processOpTypePointer(operands); } break; case spirv::Opcode::OpTypeArray: return processArrayType(operands); case spirv::Opcode::OpTypeCooperativeMatrixNV: return processCooperativeMatrixType(operands); case spirv::Opcode::OpTypeFunction: return processFunctionType(operands); case spirv::Opcode::OpTypeRuntimeArray: return processRuntimeArrayType(operands); case spirv::Opcode::OpTypeStruct: return processStructType(operands); case spirv::Opcode::OpTypeMatrix: return processMatrixType(operands); default: return emitError(unknownLoc, "unhandled type instruction"); } return success(); } LogicalResult Deserializer::processOpTypePointer(ArrayRef operands) { if (operands.size() != 3) return emitError(unknownLoc, "OpTypePointer must have two parameters"); auto pointeeType = getType(operands[2]); if (!pointeeType) return emitError(unknownLoc, "unknown OpTypePointer pointee type ") << operands[2]; uint32_t typePointerID = operands[0]; auto storageClass = static_cast(operands[1]); typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass); for (auto *deferredStructIt = std::begin(deferredStructTypesInfos); deferredStructIt != std::end(deferredStructTypesInfos);) { for (auto *unresolvedMemberIt = std::begin(deferredStructIt->unresolvedMemberTypes); unresolvedMemberIt != std::end(deferredStructIt->unresolvedMemberTypes);) { if (unresolvedMemberIt->first == typePointerID) { // The newly constructed pointer type can resolve one of the // deferred struct type members; update the memberTypes list and // clean the unresolvedMemberTypes list accordingly. deferredStructIt->memberTypes[unresolvedMemberIt->second] = typeMap[typePointerID]; unresolvedMemberIt = deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt); } else { ++unresolvedMemberIt; } } if (deferredStructIt->unresolvedMemberTypes.empty()) { // All deferred struct type members are now resolved, set the struct body. 
auto structType = deferredStructIt->deferredStructType; assert(structType && "expected a spirv::StructType"); assert(structType.isIdentified() && "expected an indentified struct"); if (failed(structType.trySetBody( deferredStructIt->memberTypes, deferredStructIt->offsetInfo, deferredStructIt->memberDecorationsInfo))) return failure(); deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt); } else { ++deferredStructIt; } } return success(); } LogicalResult Deserializer::processArrayType(ArrayRef operands) { if (operands.size() != 3) { return emitError(unknownLoc, "OpTypeArray must have element type and count parameters"); } Type elementTy = getType(operands[1]); if (!elementTy) { return emitError(unknownLoc, "OpTypeArray references undefined ") << operands[1]; } unsigned count = 0; // TODO: The count can also come frome a specialization constant. auto countInfo = getConstant(operands[2]); if (!countInfo) { return emitError(unknownLoc, "OpTypeArray count ") << operands[2] << "can only come from normal constant right now"; } if (auto intVal = countInfo->first.dyn_cast()) { count = intVal.getValue().getZExtValue(); } else { return emitError(unknownLoc, "OpTypeArray count must come from a " "scalar integer constant instruction"); } typeMap[operands[0]] = spirv::ArrayType::get( elementTy, count, typeDecorations.lookup(operands[0])); return success(); } LogicalResult Deserializer::processFunctionType(ArrayRef operands) { assert(!operands.empty() && "No operands for processing function type"); if (operands.size() == 1) { return emitError(unknownLoc, "missing return type for OpTypeFunction"); } auto returnType = getType(operands[1]); if (!returnType) { return emitError(unknownLoc, "unknown return type in OpTypeFunction"); } SmallVector argTypes; for (size_t i = 2, e = operands.size(); i < e; ++i) { auto ty = getType(operands[i]); if (!ty) { return emitError(unknownLoc, "unknown argument type in OpTypeFunction"); } argTypes.push_back(ty); } ArrayRef returnTypes; 
if (!isVoidType(returnType)) { returnTypes = llvm::makeArrayRef(returnType); } typeMap[operands[0]] = FunctionType::get(argTypes, returnTypes, context); return success(); } LogicalResult Deserializer::processCooperativeMatrixType(ArrayRef operands) { if (operands.size() != 5) { return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element " "type and row x column parameters"); } Type elementTy = getType(operands[1]); if (!elementTy) { return emitError(unknownLoc, "OpTypeCooperativeMatrix references undefined ") << operands[1]; } auto scope = spirv::symbolizeScope(getConstantInt(operands[2]).getInt()); if (!scope) { return emitError(unknownLoc, "OpTypeCooperativeMatrix references undefined scope ") << operands[2]; } unsigned rows = getConstantInt(operands[3]).getInt(); unsigned columns = getConstantInt(operands[4]).getInt(); typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get( elementTy, scope.getValue(), rows, columns); return success(); } LogicalResult Deserializer::processRuntimeArrayType(ArrayRef operands) { if (operands.size() != 2) { return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands"); } Type memberType = getType(operands[1]); if (!memberType) { return emitError(unknownLoc, "OpTypeRuntimeArray references undefined ") << operands[1]; } typeMap[operands[0]] = spirv::RuntimeArrayType::get( memberType, typeDecorations.lookup(operands[0])); return success(); } LogicalResult Deserializer::processStructType(ArrayRef operands) { // TODO: Find a way to handle identified structs when debug info is stripped. if (operands.empty()) { return emitError(unknownLoc, "OpTypeStruct must have at least result "); } if (operands.size() == 1) { // Handle empty struct. typeMap[operands[0]] = spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str()); return success(); } // First element is operand ID, second element is member index in the struct. 
SmallVector, 0> unresolvedMemberTypes; SmallVector memberTypes; for (auto op : llvm::drop_begin(operands, 1)) { Type memberType = getType(op); bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0); if (!memberType && !typeForwardPtr) return emitError(unknownLoc, "OpTypeStruct references undefined ") << op; if (!memberType) unresolvedMemberTypes.emplace_back(op, memberTypes.size()); memberTypes.push_back(memberType); } SmallVector offsetInfo; SmallVector memberDecorationsInfo; if (memberDecorationMap.count(operands[0])) { auto &allMemberDecorations = memberDecorationMap[operands[0]]; for (auto memberIndex : llvm::seq(0, memberTypes.size())) { if (allMemberDecorations.count(memberIndex)) { for (auto &memberDecoration : allMemberDecorations[memberIndex]) { // Check for offset. if (memberDecoration.first == spirv::Decoration::Offset) { // If offset info is empty, resize to the number of members; if (offsetInfo.empty()) { offsetInfo.resize(memberTypes.size()); } offsetInfo[memberIndex] = memberDecoration.second[0]; } else { if (!memberDecoration.second.empty()) { memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1, memberDecoration.first, memberDecoration.second[0]); } else { memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0, memberDecoration.first, 0); } } } } } } uint32_t structID = operands[0]; std::string structIdentifier = nameMap.lookup(structID).str(); if (structIdentifier.empty()) { assert(unresolvedMemberTypes.empty() && "didn't expect unresolved member types"); typeMap[structID] = spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo); } else { auto structTy = spirv::StructType::getIdentified(context, structIdentifier); typeMap[structID] = structTy; if (!unresolvedMemberTypes.empty()) deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes, memberTypes, offsetInfo, memberDecorationsInfo}); else if (failed(structTy.trySetBody(memberTypes, offsetInfo, memberDecorationsInfo))) return failure(); } // TODO: 
Update StructType to have member name as attribute as // well. return success(); } LogicalResult Deserializer::processMatrixType(ArrayRef operands) { if (operands.size() != 3) { // Three operands are needed: result_id, column_type, and column_count return emitError(unknownLoc, "OpTypeMatrix must have 3 operands" " (result_id, column_type, and column_count)"); } // Matrix columns must be of vector type Type elementTy = getType(operands[1]); if (!elementTy) { return emitError(unknownLoc, "OpTypeMatrix references undefined column type.") << operands[1]; } uint32_t colsCount = operands[2]; typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount); return success(); } //===----------------------------------------------------------------------===// // Constant //===----------------------------------------------------------------------===// LogicalResult Deserializer::processConstant(ArrayRef operands, bool isSpec) { StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant"; if (operands.size() < 2) { return emitError(unknownLoc) << opname << " must have type and result "; } if (operands.size() < 3) { return emitError(unknownLoc) << opname << " must have at least 1 more parameter"; } Type resultType = getType(operands[0]); if (!resultType) { return emitError(unknownLoc, "undefined result type from ") << operands[0]; } auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult { if (bitwidth == 64) { if (operands.size() == 4) { return success(); } return emitError(unknownLoc) << opname << " should have 2 parameters for 64-bit values"; } if (bitwidth <= 32) { if (operands.size() == 3) { return success(); } return emitError(unknownLoc) << opname << " should have 1 parameter for values with no more than 32 bits"; } return emitError(unknownLoc, "unsupported OpConstant bitwidth: ") << bitwidth; }; auto resultID = operands[1]; if (auto intType = resultType.dyn_cast()) { auto bitwidth = intType.getWidth(); if 
(failed(checkOperandSizeForBitwidth(bitwidth))) { return failure(); } APInt value; if (bitwidth == 64) { // 64-bit integers are represented with two SPIR-V words. According to // SPIR-V spec: "When the type’s bit width is larger than one word, the // literal’s low-order words appear first." struct DoubleWord { uint32_t word1; uint32_t word2; } words = {operands[2], operands[3]}; value = APInt(64, llvm::bit_cast(words), /*isSigned=*/true); } else if (bitwidth <= 32) { value = APInt(bitwidth, operands[2], /*isSigned=*/true); } auto attr = opBuilder.getIntegerAttr(intType, value); if (isSpec) { createSpecConstant(unknownLoc, resultID, attr); } else { // For normal constants, we just record the attribute (and its type) for // later materialization at use sites. constantMap.try_emplace(resultID, attr, intType); } return success(); } if (auto floatType = resultType.dyn_cast()) { auto bitwidth = floatType.getWidth(); if (failed(checkOperandSizeForBitwidth(bitwidth))) { return failure(); } APFloat value(0.f); if (floatType.isF64()) { // Double values are represented with two SPIR-V words. According to // SPIR-V spec: "When the type’s bit width is larger than one word, the // literal’s low-order words appear first." struct DoubleWord { uint32_t word1; uint32_t word2; } words = {operands[2], operands[3]}; value = APFloat(llvm::bit_cast(words)); } else if (floatType.isF32()) { value = APFloat(llvm::bit_cast(operands[2])); } else if (floatType.isF16()) { APInt data(16, operands[2]); value = APFloat(APFloat::IEEEhalf(), data); } auto attr = opBuilder.getFloatAttr(floatType, value); if (isSpec) { createSpecConstant(unknownLoc, resultID, attr); } else { // For normal constants, we just record the attribute (and its type) for // later materialization at use sites. 
// Tail of Deserializer::processConstant: record the floating-point constant.
      constantMap.try_emplace(resultID, attr, floatType);
    }

    return success();
  }

  // Reached only when the result type is neither of the two scalar kinds
  // handled above.
  return emitError(unknownLoc, "OpConstant can only generate values of "
                               "scalar integer or floating-point type");
}

/// Deserializes OpConstantTrue/OpConstantFalse and their spec-constant
/// variants (`operands`: result type id, result id). `isTrue` selects the
/// boolean value and `isSpec` selects the spec-constant flavor.
LogicalResult Deserializer::processConstantBool(bool isTrue,
                                                ArrayRef operands,
                                                bool isSpec) {
  if (operands.size() != 2) {
    // The opcode name in the message is assembled from the two flags.
    return emitError(unknownLoc, "Op")
           << (isSpec ? "Spec" : "") << "Constant"
           << (isTrue ? "True" : "False")
           << " must have type  and result ";
  }

  auto attr = opBuilder.getBoolAttr(isTrue);
  auto resultID = operands[1];
  if (isSpec) {
    createSpecConstant(unknownLoc, resultID, attr);
  } else {
    // For normal constants, we just record the attribute (and its type) for
    // later materialization at use sites.
    constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
  }

  return success();
}

/// Deserializes OpConstantComposite (`operands`: result type id, result id,
/// then one or more constituent ids). Every constituent must already be
/// recorded in `constantMap`.
LogicalResult Deserializer::processConstantComposite(ArrayRef operands) {
  if (operands.size() < 2) {
    return emitError(unknownLoc,
                     "OpConstantComposite must have type  and result ");
  }
  if (operands.size() < 3) {
    return emitError(unknownLoc,
                     "OpConstantComposite must have at least 1 parameter");
  }

  Type resultType = getType(operands[0]);
  if (!resultType) {
    return emitError(unknownLoc, "undefined result type from  ")
           << operands[0];
  }

  // Collect the attribute of each constituent, in operand order.
  SmallVector elements;
  elements.reserve(operands.size() - 2);
  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
    auto elementInfo = getConstant(operands[i]);
    if (!elementInfo) {
      return emitError(unknownLoc, "OpConstantComposite component  ")
             << operands[i] << " must come from a normal constant";
    }
    elements.push_back(elementInfo->first);
  }

  auto resultID = operands[1];
  if (auto vectorType = resultType.dyn_cast()) {
    auto attr = DenseElementsAttr::get(vectorType, elements);
    // For normal constants, we just record the attribute (and its type) for
    // later materialization at use sites.
constantMap.try_emplace(resultID, attr, resultType); } else if (auto arrayType = resultType.dyn_cast()) { auto attr = opBuilder.getArrayAttr(elements); constantMap.try_emplace(resultID, attr, resultType); } else { return emitError(unknownLoc, "unsupported OpConstantComposite type: ") << resultType; } return success(); } LogicalResult Deserializer::processSpecConstantComposite(ArrayRef operands) { if (operands.size() < 2) { return emitError(unknownLoc, "OpConstantComposite must have type and result "); } if (operands.size() < 3) { return emitError(unknownLoc, "OpConstantComposite must have at least 1 parameter"); } Type resultType = getType(operands[0]); if (!resultType) { return emitError(unknownLoc, "undefined result type from ") << operands[0]; } auto resultID = operands[1]; auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID)); SmallVector elements; elements.reserve(operands.size() - 2); for (unsigned i = 2, e = operands.size(); i < e; ++i) { auto elementInfo = getSpecConstant(operands[i]); elements.push_back(opBuilder.getSymbolRefAttr(elementInfo)); } auto op = opBuilder.create( unknownLoc, TypeAttr::get(resultType), symName, opBuilder.getArrayAttr(elements)); specConstCompositeMap[resultID] = op; return success(); } LogicalResult Deserializer::processConstantNull(ArrayRef operands) { if (operands.size() != 2) { return emitError(unknownLoc, "OpConstantNull must have type and result "); } Type resultType = getType(operands[0]); if (!resultType) { return emitError(unknownLoc, "undefined result type from ") << operands[0]; } auto resultID = operands[1]; if (resultType.isIntOrFloat() || resultType.isa()) { auto attr = opBuilder.getZeroAttr(resultType); // For normal constants, we just record the attribute (and its type) for // later materialization at use sites. 
// Tail of Deserializer::processConstantNull: record the zero constant.
    constantMap.try_emplace(resultID, attr, resultType);
    return success();
  }

  return emitError(unknownLoc, "unsupported OpConstantNull type: ")
         << resultType;
}

//===----------------------------------------------------------------------===//
// Control flow
//===----------------------------------------------------------------------===//

/// Returns the block registered for label `id`, creating a new block in the
/// current function and registering it in `blockMap` when none exists yet.
Block *Deserializer::getOrCreateBlock(uint32_t id) {
  if (auto *block = getBlock(id)) {
    LLVM_DEBUG(llvm::dbgs() << "[block] got exiting block for id = " << id
                            << " @ " << block << "\n");
    return block;
  }

  // We don't know where this block will be placed finally (in a spv.selection
  // or spv.loop or function). Create it into the function for now and sort
  // out the proper place later.
  auto *block = curFunction->addBlock();
  LLVM_DEBUG(llvm::dbgs() << "[block] created block for id = " << id << " @ "
                          << block << "\n");
  return blockMap[id] = block;
}

/// Deserializes OpBranch (`operands`: the single target label id) into a
/// branch op at the current insertion point.
LogicalResult Deserializer::processBranch(ArrayRef operands) {
  if (!curBlock) {
    return emitError(unknownLoc, "OpBranch must appear inside a block");
  }

  if (operands.size() != 1) {
    return emitError(unknownLoc, "OpBranch must take exactly one target label");
  }

  auto *target = getOrCreateBlock(operands[0]);
  auto loc = createFileLineColLoc(opBuilder);
  // The preceding instruction for the OpBranch instruction could be an
  // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
  // the same OpLine information.
  opBuilder.create(loc, target);

  // The pending debug line is cleared only here, after the branch op has been
  // given its location.
  clearDebugLine();
  return success();
}

/// Deserializes OpBranchConditional (`operands`: condition id, true label id,
/// false label id, and optionally two branch weights).
LogicalResult
Deserializer::processBranchConditional(ArrayRef operands) {
  if (!curBlock) {
    return emitError(unknownLoc,
                     "OpBranchConditional must appear inside a block");
  }

  if (operands.size() != 3 && operands.size() != 5) {
    return emitError(unknownLoc,
                     "OpBranchConditional must have condition, true label, "
                     "false label, and optionally two branch weights");
  }

  // NOTE(review): the result of getValue is not checked here -- confirm that
  // an unknown condition id cannot reach this point.
  auto condition = getValue(operands[0]);
  auto *trueBlock = getOrCreateBlock(operands[1]);
  auto *falseBlock = getOrCreateBlock(operands[2]);

  Optional> weights;
  if (operands.size() == 5) {
    weights = std::make_pair(operands[3], operands[4]);
  }
  // The preceding instruction for the OpBranchConditional instruction could be
  // an OpSelectionMerge instruction, in this case they will have the same
  // OpLine information.
  auto loc = createFileLineColLoc(opBuilder);
  opBuilder.create(
      loc, condition, trueBlock,
      /*trueArguments=*/ArrayRef(), falseBlock,
      /*falseArguments=*/ArrayRef(), weights);

  clearDebugLine();
  return success();
}

/// Deserializes OpLabel (`operands`: the single result id) and makes the
/// corresponding block the current block and insertion point.
LogicalResult Deserializer::processLabel(ArrayRef operands) {
  if (!curFunction) {
    return emitError(unknownLoc, "OpLabel must appear inside a function");
  }

  if (operands.size() != 1) {
    return emitError(unknownLoc, "OpLabel should only have result ");
  }

  auto labelID = operands[0];
  // We may have forward declared this block.
  auto *block = getOrCreateBlock(labelID);
  LLVM_DEBUG(llvm::dbgs() << "[block] populating block " << block << "\n");
  // If we have seen this block, make sure it was just a forward declaration.
assert(block->empty() && "re-deserialize the same block!"); opBuilder.setInsertionPointToStart(block); blockMap[labelID] = curBlock = block; return success(); } LogicalResult Deserializer::processSelectionMerge(ArrayRef operands) { if (!curBlock) { return emitError(unknownLoc, "OpSelectionMerge must appear in a block"); } if (operands.size() < 2) { return emitError( unknownLoc, "OpSelectionMerge must specify merge target and selection control"); } auto *mergeBlock = getOrCreateBlock(operands[0]); auto loc = createFileLineColLoc(opBuilder); auto selectionControl = operands[1]; if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock) .second) { return emitError( unknownLoc, "a block cannot have more than one OpSelectionMerge instruction"); } return success(); } LogicalResult Deserializer::processLoopMerge(ArrayRef operands) { if (!curBlock) { return emitError(unknownLoc, "OpLoopMerge must appear in a block"); } if (operands.size() < 3) { return emitError(unknownLoc, "OpLoopMerge must specify merge target, " "continue target and loop control"); } auto *mergeBlock = getOrCreateBlock(operands[0]); auto *continueBlock = getOrCreateBlock(operands[1]); auto loc = createFileLineColLoc(opBuilder); uint32_t loopControl = operands[2]; if (!blockMergeInfo .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock) .second) { return emitError( unknownLoc, "a block cannot have more than one OpLoopMerge instruction"); } return success(); } LogicalResult Deserializer::processPhi(ArrayRef operands) { if (!curBlock) { return emitError(unknownLoc, "OpPhi must appear in a block"); } if (operands.size() < 4) { return emitError(unknownLoc, "OpPhi must specify result type, result , " "and variable-parent pairs"); } // Create a block argument for this OpPhi instruction. 
Type blockArgType = getType(operands[0]); BlockArgument blockArg = curBlock->addArgument(blockArgType); valueMap[operands[1]] = blockArg; LLVM_DEBUG(llvm::dbgs() << "[phi] created block argument " << blockArg << " id = " << operands[1] << " of type " << blockArgType << '\n'); // For each (value, predecessor) pair, insert the value to the predecessor's // blockPhiInfo entry so later we can fix the block argument there. for (unsigned i = 2, e = operands.size(); i < e; i += 2) { uint32_t value = operands[i]; Block *predecessor = getOrCreateBlock(operands[i + 1]); blockPhiInfo[predecessor].push_back(value); LLVM_DEBUG(llvm::dbgs() << "[phi] predecessor @ " << predecessor << " with arg id = " << value << '\n'); } return success(); } namespace { /// A class for putting all blocks in a structured selection/loop in a /// spv.selection/spv.loop op. class ControlFlowStructurizer { public: /// Structurizes the loop at the given `headerBlock`. /// /// This method will create an spv.loop op in the `mergeBlock` and move all /// blocks in the structured loop into the spv.loop's region. All branches to /// the `headerBlock` will be redirected to the `mergeBlock`. /// This method will also update `mergeInfo` by remapping all blocks inside to /// the newly cloned ones inside structured control flow op's regions. static LogicalResult structurize(Location loc, uint32_t control, BlockMergeInfoMap &mergeInfo, Block *headerBlock, Block *mergeBlock, Block *continueBlock) { return ControlFlowStructurizer(loc, control, mergeInfo, headerBlock, mergeBlock, continueBlock) .structurizeImpl(); } private: ControlFlowStructurizer(Location loc, uint32_t control, BlockMergeInfoMap &mergeInfo, Block *header, Block *merge, Block *cont) : location(loc), control(control), blockMergeInfo(mergeInfo), headerBlock(header), mergeBlock(merge), continueBlock(cont) {} /// Creates a new spv.selection op at the beginning of the `mergeBlock`. 
spirv::SelectionOp createSelectionOp(uint32_t selectionControl);

  /// Creates a new spv.loop op at the beginning of the `mergeBlock`.
  spirv::LoopOp createLoopOp(uint32_t loopControl);

  /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
  void collectBlocksInConstruct();

  /// Creates the selection/loop op and populates its region with copies of the
  /// construct's blocks.
  LogicalResult structurizeImpl();

  // Location given to the created selection/loop op.
  Location location;
  // Raw selection/loop control word, attached to the created op as an i32
  // attribute.
  uint32_t control;

  BlockMergeInfoMap &blockMergeInfo;

  Block *headerBlock;
  Block *mergeBlock;
  Block *continueBlock; // nullptr for spv.selection

  // Blocks belonging to the construct, in discovery order starting at the
  // header block.
  llvm::SetVector constructBlocks;
};
} // namespace

/// Creates a selection op carrying `selectionControl` in front of the first
/// operation of `mergeBlock` and adds its merge block.
spirv::SelectionOp
ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
  // Create a builder and set the insertion point to the beginning of the
  // merge block so that the newly created SelectionOp will be inserted there.
  OpBuilder builder(&mergeBlock->front());

  auto control = builder.getI32IntegerAttr(selectionControl);
  auto selectionOp = builder.create(location, control);
  selectionOp.addMergeBlock();

  return selectionOp;
}

/// Creates a loop op carrying `loopControl` in front of the first operation of
/// `mergeBlock` and adds its entry and merge blocks.
spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
  // Create a builder and set the insertion point to the beginning of the
  // merge block so that the newly created LoopOp will be inserted there.
  OpBuilder builder(&mergeBlock->front());

  auto control = builder.getI32IntegerAttr(loopControl);
  auto loopOp = builder.create(location, control);
  loopOp.addEntryAndMergeBlock();

  return loopOp;
}

/// Fills `constructBlocks` with every block reachable from `headerBlock`
/// without passing through `mergeBlock`.
void ControlFlowStructurizer::collectBlocksInConstruct() {
  assert(constructBlocks.empty() && "expected empty constructBlocks");

  // Put the header block in the work list first.
  constructBlocks.insert(headerBlock);

  // For each item in the work list, add its successors excluding the merge
  // block.
  // The container doubles as the work list: it grows while being indexed, so
  // the size is re-read on every iteration.
  for (unsigned i = 0; i < constructBlocks.size(); ++i) {
    for (auto *successor : constructBlocks[i]->getSuccessors())
      if (successor != mergeBlock)
        constructBlocks.insert(successor);
  }
}

/// Builds the selection/loop op for this construct. The remainder of this
/// function lies beyond this span.
LogicalResult ControlFlowStructurizer::structurizeImpl() {
  Operation *op = nullptr;
  // A continue block is only present for loops.
  bool isLoop = continueBlock != nullptr;
  if (isLoop) {
    if (auto loopOp = createLoopOp(control))
      op = loopOp.getOperation();
  } else {
    if (auto selectionOp = createSelectionOp(control))
      op = selectionOp.getOperation();
  }
  if (!op)
    return failure();
  Region &body = op->getRegion(0);

  BlockAndValueMapping mapper;
  // All references to the old merge block should be directed to the
  // selection/loop merge block in the SelectionOp/LoopOp's region.
  mapper.map(mergeBlock, &body.back());

  collectBlocksInConstruct();

  // We've identified all blocks belonging to the selection/loop's region. Now
  // need to "move" them into the selection/loop. Instead of really moving the
  // blocks, in the following we copy them and remap all values and branches.
  // This is because:
  // * Inserting a block into a region requires the block not in any region
  //   before. But selections/loops can nest so we can create selection/loop ops
  //   in a nested manner, which means some blocks may already be in a
  //   selection/loop region when to be moved again.
  // * It's much trickier to fix up the branches into and out of the loop's
  //   region: we need to treat not-moved blocks and moved blocks differently:
  //   Not-moved blocks jumping to the loop header block need to jump to the
  //   merge point containing the new loop op but not the loop continue block's
  //   back edge. Moved blocks jumping out of the loop need to jump to the
  //   merge block inside the loop region but not other not-moved blocks.
  //   We cannot use replaceAllUsesWith clearly and it's harder to follow the
  //   logic.

  // Create a corresponding block in the SelectionOp/LoopOp's region for each
  // block in this loop construct.
OpBuilder builder(body); for (auto *block : constructBlocks) { // Create a block and insert it before the selection/loop merge block in the // SelectionOp/LoopOp's region. auto *newBlock = builder.createBlock(&body.back()); mapper.map(block, newBlock); LLVM_DEBUG(llvm::dbgs() << "[cf] cloned block " << newBlock << " from block " << block << "\n"); if (!isFnEntryBlock(block)) { for (BlockArgument blockArg : block->getArguments()) { auto newArg = newBlock->addArgument(blockArg.getType()); mapper.map(blockArg, newArg); LLVM_DEBUG(llvm::dbgs() << "[cf] remapped block argument " << blockArg << " to " << newArg << '\n'); } } else { LLVM_DEBUG(llvm::dbgs() << "[cf] block " << block << " is a function entry block\n"); } for (auto &op : *block) newBlock->push_back(op.clone(mapper)); } // Go through all ops and remap the operands. auto remapOperands = [&](Operation *op) { for (auto &operand : op->getOpOperands()) if (Value mappedOp = mapper.lookupOrNull(operand.get())) operand.set(mappedOp); for (auto &succOp : op->getBlockOperands()) if (Block *mappedOp = mapper.lookupOrNull(succOp.get())) succOp.set(mappedOp); }; for (auto &block : body) { block.walk(remapOperands); } // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to // the selection/loop construct into its region. Next we need to fix the // connections between this new SelectionOp/LoopOp with existing blocks. // All existing incoming branches should go to the merge block, where the // SelectionOp/LoopOp resides right now. headerBlock->replaceAllUsesWith(mergeBlock); if (isLoop) { // The loop selection/loop header block may have block arguments. Since now // we place the selection/loop op inside the old merge block, we need to // make sure the old merge block has the same block argument list. 
assert(mergeBlock->args_empty() && "OpPhi in loop merge block unsupported"); for (BlockArgument blockArg : headerBlock->getArguments()) { mergeBlock->addArgument(blockArg.getType()); } // If the loop header block has block arguments, make sure the spv.branch op // matches. SmallVector blockArgs; if (!headerBlock->args_empty()) blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()}; // The loop entry block should have a unconditional branch jumping to the // loop header block. builder.setInsertionPointToEnd(&body.front()); builder.create(location, mapper.lookupOrNull(headerBlock), ArrayRef(blockArgs)); } // All the blocks cloned into the SelectionOp/LoopOp's region can now be // cleaned up. LLVM_DEBUG(llvm::dbgs() << "[cf] cleaning up blocks after clone\n"); // First we need to drop all operands' references inside all blocks. This is // needed because we can have blocks referencing SSA values from one another. for (auto *block : constructBlocks) block->dropAllReferences(); // Then erase all old blocks. for (auto *block : constructBlocks) { // We've cloned all blocks belonging to this construct into the structured // control flow op's region. Among these blocks, some may compose another // selection/loop. If so, they will be recorded within blockMergeInfo. // We need to update the pointers there to the newly remapped ones so we can // continue structurizing them later. // TODO: The asserts in the following assumes input SPIR-V blob // forms correctly nested selection/loop constructs. We should relax this // and support error cases better. 
auto it = blockMergeInfo.find(block); if (it != blockMergeInfo.end()) { Block *newHeader = mapper.lookupOrNull(block); assert(newHeader && "nested loop header block should be remapped!"); Block *newContinue = it->second.continueBlock; if (newContinue) { newContinue = mapper.lookupOrNull(newContinue); assert(newContinue && "nested loop continue block should be remapped!"); } Block *newMerge = it->second.mergeBlock; if (Block *mappedTo = mapper.lookupOrNull(newMerge)) newMerge = mappedTo; // Keep original location for nested selection/loop ops. Location loc = it->second.loc; // The iterator should be erased before adding a new entry into // blockMergeInfo to avoid iterator invalidation. blockMergeInfo.erase(it); blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge, newContinue); } // The structured selection/loop's entry block does not have arguments. // If the function's header block is also part of the structured control // flow, we cannot just simply erase it because it may contain arguments // matching the function signature and used by the cloned blocks. if (isFnEntryBlock(block)) { LLVM_DEBUG(llvm::dbgs() << "[cf] changing entry block " << block << " to only contain a spv.Branch op\n"); // Still keep the function entry block for the potential block arguments, // but replace all ops inside with a branch to the merge block. 
block->clear(); builder.setInsertionPointToEnd(block); builder.create(location, mergeBlock); } else { LLVM_DEBUG(llvm::dbgs() << "[cf] erasing block " << block << "\n"); block->erase(); } } LLVM_DEBUG( llvm::dbgs() << "[cf] after structurizing construct with header block " << headerBlock << ":\n" << *op << '\n'); return success(); } LogicalResult Deserializer::wireUpBlockArgument() { LLVM_DEBUG(llvm::dbgs() << "[phi] start wiring up block arguments\n"); OpBuilder::InsertionGuard guard(opBuilder); for (const auto &info : blockPhiInfo) { Block *block = info.first; const BlockPhiInfo &phiInfo = info.second; LLVM_DEBUG(llvm::dbgs() << "[phi] block " << block << "\n"); LLVM_DEBUG(llvm::dbgs() << "[phi] before creating block argument:\n"); LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs())); LLVM_DEBUG(llvm::dbgs() << '\n'); // Set insertion point to before this block's terminator early because we // may materialize ops via getValue() call. auto *op = block->getTerminator(); opBuilder.setInsertionPoint(op); SmallVector blockArgs; blockArgs.reserve(phiInfo.size()); for (uint32_t valueId : phiInfo) { if (Value value = getValue(valueId)) { blockArgs.push_back(value); LLVM_DEBUG(llvm::dbgs() << "[phi] block argument " << value << " id = " << valueId << '\n'); } else { return emitError(unknownLoc, "OpPhi references undefined value!"); } } if (auto branchOp = dyn_cast(op)) { // Replace the previous branch op with a new one with block arguments. 
opBuilder.create(branchOp.getLoc(), branchOp.getTarget(), blockArgs); branchOp.erase(); } else { return emitError(unknownLoc, "unimplemented terminator for Phi creation"); } LLVM_DEBUG(llvm::dbgs() << "[phi] after creating block argument:\n"); LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs())); LLVM_DEBUG(llvm::dbgs() << '\n'); } blockPhiInfo.clear(); LLVM_DEBUG(llvm::dbgs() << "[phi] completed wiring up block arguments\n"); return success(); } LogicalResult Deserializer::structurizeControlFlow() { LLVM_DEBUG(llvm::dbgs() << "[cf] start structurizing control flow\n"); while (!blockMergeInfo.empty()) { Block *headerBlock = blockMergeInfo.begin()->first; BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second; LLVM_DEBUG(llvm::dbgs() << "[cf] header block " << headerBlock << ":\n"); LLVM_DEBUG(headerBlock->print(llvm::dbgs())); auto *mergeBlock = mergeInfo.mergeBlock; assert(mergeBlock && "merge block cannot be nullptr"); if (!mergeBlock->args_empty()) return emitError(unknownLoc, "OpPhi in loop merge block unimplemented"); LLVM_DEBUG(llvm::dbgs() << "[cf] merge block " << mergeBlock << ":\n"); LLVM_DEBUG(mergeBlock->print(llvm::dbgs())); auto *continueBlock = mergeInfo.continueBlock; if (continueBlock) { LLVM_DEBUG(llvm::dbgs() << "[cf] continue block " << continueBlock << ":\n"); LLVM_DEBUG(continueBlock->print(llvm::dbgs())); } // Erase this case before calling into structurizer, who will update // blockMergeInfo. 
blockMergeInfo.erase(blockMergeInfo.begin()); if (failed(ControlFlowStructurizer::structurize( mergeInfo.loc, mergeInfo.control, blockMergeInfo, headerBlock, mergeBlock, continueBlock))) return failure(); } LLVM_DEBUG(llvm::dbgs() << "[cf] completed structurizing control flow\n"); return success(); } //===----------------------------------------------------------------------===// // Debug //===----------------------------------------------------------------------===// Location Deserializer::createFileLineColLoc(OpBuilder opBuilder) { if (!debugLine) return unknownLoc; auto fileName = debugInfoMap.lookup(debugLine->fileID).str(); if (fileName.empty()) fileName = ""; return opBuilder.getFileLineColLoc(opBuilder.getIdentifier(fileName), debugLine->line, debugLine->col); } LogicalResult Deserializer::processDebugLine(ArrayRef operands) { // According to SPIR-V spec: // "This location information applies to the instructions physically // following this instruction, up to the first occurrence of any of the // following: the next end of block, the next OpLine instruction, or the next // OpNoLine instruction." 
if (operands.size() != 3) return emitError(unknownLoc, "OpLine must have 3 operands"); debugLine = DebugLine(operands[0], operands[1], operands[2]); return success(); } LogicalResult Deserializer::clearDebugLine() { debugLine = llvm::None; return success(); } LogicalResult Deserializer::processDebugString(ArrayRef operands) { if (operands.size() < 2) return emitError(unknownLoc, "OpString needs at least 2 operands"); if (!debugInfoMap.lookup(operands[0]).empty()) return emitError(unknownLoc, "duplicate debug string found for result ") << operands[0]; unsigned wordIndex = 1; StringRef debugString = decodeStringLiteral(operands, wordIndex); if (wordIndex != operands.size()) return emitError(unknownLoc, "unexpected trailing words in OpString instruction"); debugInfoMap[operands[0]] = debugString; return success(); } //===----------------------------------------------------------------------===// // Instruction //===----------------------------------------------------------------------===// Value Deserializer::getValue(uint32_t id) { if (auto constInfo = getConstant(id)) { // Materialize a `spv.constant` op at every use site. 
return opBuilder.create(unknownLoc, constInfo->second, constInfo->first); } if (auto varOp = getGlobalVariable(id)) { auto addressOfOp = opBuilder.create( unknownLoc, varOp.type(), opBuilder.getSymbolRefAttr(varOp.getOperation())); return addressOfOp.pointer(); } if (auto constOp = getSpecConstant(id)) { auto referenceOfOp = opBuilder.create( unknownLoc, constOp.default_value().getType(), opBuilder.getSymbolRefAttr(constOp.getOperation())); return referenceOfOp.reference(); } if (auto constCompositeOp = getSpecConstantComposite(id)) { auto referenceOfOp = opBuilder.create( unknownLoc, constCompositeOp.type(), opBuilder.getSymbolRefAttr(constCompositeOp.getOperation())); return referenceOfOp.reference(); } if (auto undef = getUndefType(id)) { return opBuilder.create(unknownLoc, undef); } return valueMap.lookup(id); } LogicalResult Deserializer::sliceInstruction(spirv::Opcode &opcode, ArrayRef &operands, Optional expectedOpcode) { auto binarySize = binary.size(); if (curOffset >= binarySize) { return emitError(unknownLoc, "expected ") << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode) : "more") << " instruction"; } // For each instruction, get its word count from the first word to slice it // from the stream properly, and then dispatch to the instruction handler. 
uint32_t wordCount = binary[curOffset] >> 16; if (wordCount == 0) return emitError(unknownLoc, "word count cannot be zero"); uint32_t nextOffset = curOffset + wordCount; if (nextOffset > binarySize) return emitError(unknownLoc, "insufficient words for the last instruction"); opcode = extractOpcode(binary[curOffset]); operands = binary.slice(curOffset + 1, wordCount - 1); curOffset = nextOffset; return success(); } LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, ArrayRef operands, bool deferInstructions) { LLVM_DEBUG(llvm::dbgs() << "[inst] processing instruction " << spirv::stringifyOpcode(opcode) << "\n"); // First dispatch all the instructions whose opcode does not correspond to // those that have a direct mirror in the SPIR-V dialect switch (opcode) { case spirv::Opcode::OpCapability: return processCapability(operands); case spirv::Opcode::OpExtension: return processExtension(operands); case spirv::Opcode::OpExtInst: return processExtInst(operands); case spirv::Opcode::OpExtInstImport: return processExtInstImport(operands); case spirv::Opcode::OpMemberName: return processMemberName(operands); case spirv::Opcode::OpMemoryModel: return processMemoryModel(operands); case spirv::Opcode::OpEntryPoint: case spirv::Opcode::OpExecutionMode: if (deferInstructions) { deferredInstructions.emplace_back(opcode, operands); return success(); } break; case spirv::Opcode::OpVariable: if (isa(opBuilder.getBlock()->getParentOp())) { return processGlobalVariable(operands); } break; case spirv::Opcode::OpLine: return processDebugLine(operands); case spirv::Opcode::OpNoLine: return clearDebugLine(); case spirv::Opcode::OpName: return processName(operands); case spirv::Opcode::OpString: return processDebugString(operands); case spirv::Opcode::OpModuleProcessed: case spirv::Opcode::OpSource: case spirv::Opcode::OpSourceContinued: case spirv::Opcode::OpSourceExtension: // TODO: This is debug information embedded in the binary which should be // translated into the 
spv.module. return success(); case spirv::Opcode::OpTypeVoid: case spirv::Opcode::OpTypeBool: case spirv::Opcode::OpTypeInt: case spirv::Opcode::OpTypeFloat: case spirv::Opcode::OpTypeVector: case spirv::Opcode::OpTypeMatrix: case spirv::Opcode::OpTypeArray: case spirv::Opcode::OpTypeFunction: case spirv::Opcode::OpTypeRuntimeArray: case spirv::Opcode::OpTypeStruct: case spirv::Opcode::OpTypePointer: case spirv::Opcode::OpTypeCooperativeMatrixNV: return processType(opcode, operands); case spirv::Opcode::OpConstant: return processConstant(operands, /*isSpec=*/false); case spirv::Opcode::OpSpecConstant: return processConstant(operands, /*isSpec=*/true); case spirv::Opcode::OpConstantComposite: return processConstantComposite(operands); case spirv::Opcode::OpSpecConstantComposite: return processSpecConstantComposite(operands); case spirv::Opcode::OpConstantTrue: return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false); case spirv::Opcode::OpSpecConstantTrue: return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true); case spirv::Opcode::OpConstantFalse: return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false); case spirv::Opcode::OpSpecConstantFalse: return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true); case spirv::Opcode::OpConstantNull: return processConstantNull(operands); case spirv::Opcode::OpDecorate: return processDecoration(operands); case spirv::Opcode::OpMemberDecorate: return processMemberDecoration(operands); case spirv::Opcode::OpFunction: return processFunction(operands); case spirv::Opcode::OpLabel: return processLabel(operands); case spirv::Opcode::OpBranch: return processBranch(operands); case spirv::Opcode::OpBranchConditional: return processBranchConditional(operands); case spirv::Opcode::OpSelectionMerge: return processSelectionMerge(operands); case spirv::Opcode::OpLoopMerge: return processLoopMerge(operands); case spirv::Opcode::OpPhi: return processPhi(operands); case 
spirv::Opcode::OpUndef: return processUndef(operands); case spirv::Opcode::OpTypeForwardPointer: return processTypeForwardPointer(operands); default: break; } return dispatchToAutogenDeserialization(opcode, operands); } LogicalResult Deserializer::processUndef(ArrayRef operands) { if (operands.size() != 2) { return emitError(unknownLoc, "OpUndef instruction must have two operands"); } auto type = getType(operands[0]); if (!type) { return emitError(unknownLoc, "unknown type with OpUndef instruction"); } undefMap[operands[1]] = type; return success(); } LogicalResult Deserializer::processTypeForwardPointer(ArrayRef operands) { if (operands.size() != 2) return emitError(unknownLoc, "OpTypeForwardPointer instruction must have two operands"); typeForwardPointerIDs.insert(operands[0]); // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer // instruction that defines the actual type. return success(); } LogicalResult Deserializer::processExtInst(ArrayRef operands) { if (operands.size() < 4) { return emitError(unknownLoc, "OpExtInst must have at least 4 operands, result type " ", result , set and instruction opcode"); } if (!extendedInstSets.count(operands[2])) { return emitError(unknownLoc, "undefined set in OpExtInst"); } SmallVector slicedOperands; slicedOperands.append(operands.begin(), std::next(operands.begin(), 2)); slicedOperands.append(std::next(operands.begin(), 4), operands.end()); return dispatchToExtensionSetAutogenDeserialization( extendedInstSets[operands[2]], operands[3], slicedOperands); } namespace { template <> LogicalResult Deserializer::processOp(ArrayRef words) { unsigned wordIndex = 0; if (wordIndex >= words.size()) { return emitError(unknownLoc, "missing Execution Model specification in OpEntryPoint"); } auto execModel = opBuilder.getI32IntegerAttr(words[wordIndex++]); if (wordIndex >= words.size()) { return emitError(unknownLoc, "missing in OpEntryPoint"); } // Get the function auto fnID = words[wordIndex++]; // Get the 
function name auto fnName = decodeStringLiteral(words, wordIndex); // Verify that the function matches the fnName auto parsedFunc = getFunction(fnID); if (!parsedFunc) { return emitError(unknownLoc, "no function matching ") << fnID; } if (parsedFunc.getName() != fnName) { return emitError(unknownLoc, "function name mismatch between OpEntryPoint " "and OpFunction with ") << fnID << ": " << fnName << " vs. " << parsedFunc.getName(); } SmallVector interface; while (wordIndex < words.size()) { auto arg = getGlobalVariable(words[wordIndex]); if (!arg) { return emitError(unknownLoc, "undefined result ") << words[wordIndex] << " while decoding OpEntryPoint"; } interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation())); wordIndex++; } opBuilder.create(unknownLoc, execModel, opBuilder.getSymbolRefAttr(fnName), opBuilder.getArrayAttr(interface)); return success(); } template <> LogicalResult Deserializer::processOp(ArrayRef words) { unsigned wordIndex = 0; if (wordIndex >= words.size()) { return emitError(unknownLoc, "missing function result in OpExecutionMode"); } // Get the function to get the name of the function auto fnID = words[wordIndex++]; auto fn = getFunction(fnID); if (!fn) { return emitError(unknownLoc, "no function matching ") << fnID; } // Get the Execution mode if (wordIndex >= words.size()) { return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode"); } auto execMode = opBuilder.getI32IntegerAttr(words[wordIndex++]); // Get the values SmallVector attrListElems; while (wordIndex < words.size()) { attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++])); } auto values = opBuilder.getArrayAttr(attrListElems); opBuilder.create( unknownLoc, opBuilder.getSymbolRefAttr(fn.getName()), execMode, values); return success(); } template <> LogicalResult Deserializer::processOp(ArrayRef operands) { if (operands.size() != 3) { return emitError( unknownLoc, "OpControlBarrier must have execution scope , memory scope " "and memory 
semantics "); } SmallVector argAttrs; for (auto operand : operands) { auto argAttr = getConstantInt(operand); if (!argAttr) { return emitError(unknownLoc, "expected 32-bit integer constant from ") << operand << " for OpControlBarrier"; } argAttrs.push_back(argAttr); } opBuilder.create(unknownLoc, argAttrs[0], argAttrs[1], argAttrs[2]); return success(); } template <> LogicalResult Deserializer::processOp(ArrayRef operands) { if (operands.size() < 3) { return emitError(unknownLoc, "OpFunctionCall must have at least 3 operands"); } Type resultType = getType(operands[0]); if (!resultType) { return emitError(unknownLoc, "undefined result type from ") << operands[0]; } // Use null type to mean no result type. if (isVoidType(resultType)) resultType = nullptr; auto resultID = operands[1]; auto functionID = operands[2]; auto functionName = getFunctionSymbol(functionID); SmallVector arguments; for (auto operand : llvm::drop_begin(operands, 3)) { auto value = getValue(operand); if (!value) { return emitError(unknownLoc, "unknown ") << operand << " used by OpFunctionCall"; } arguments.push_back(value); } auto opFunctionCall = opBuilder.create( unknownLoc, resultType, opBuilder.getSymbolRefAttr(functionName), arguments); if (resultType) valueMap[resultID] = opFunctionCall.getResult(0); return success(); } template <> LogicalResult Deserializer::processOp(ArrayRef operands) { if (operands.size() != 2) { return emitError(unknownLoc, "OpMemoryBarrier must have memory scope " "and memory semantics "); } SmallVector argAttrs; for (auto operand : operands) { auto argAttr = getConstantInt(operand); if (!argAttr) { return emitError(unknownLoc, "expected 32-bit integer constant from ") << operand << " for OpMemoryBarrier"; } argAttrs.push_back(argAttr); } opBuilder.create(unknownLoc, argAttrs[0], argAttrs[1]); return success(); } template <> LogicalResult Deserializer::processOp(ArrayRef words) { SmallVector resultTypes; size_t wordIndex = 0; SmallVector operands; SmallVector 
attributes; if (wordIndex < words.size()) { auto arg = getValue(words[wordIndex]); if (!arg) { return emitError(unknownLoc, "unknown result : ") << words[wordIndex]; } operands.push_back(arg); wordIndex++; } if (wordIndex < words.size()) { auto arg = getValue(words[wordIndex]); if (!arg) { return emitError(unknownLoc, "unknown result : ") << words[wordIndex]; } operands.push_back(arg); wordIndex++; } bool isAlignedAttr = false; if (wordIndex < words.size()) { auto attrValue = words[wordIndex++]; attributes.push_back(opBuilder.getNamedAttr( "memory_access", opBuilder.getI32IntegerAttr(attrValue))); isAlignedAttr = (attrValue == 2); } if (isAlignedAttr && wordIndex < words.size()) { attributes.push_back(opBuilder.getNamedAttr( "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++]))); } if (wordIndex < words.size()) { attributes.push_back(opBuilder.getNamedAttr( "source_memory_access", opBuilder.getI32IntegerAttr(words[wordIndex++]))); } if (wordIndex < words.size()) { attributes.push_back(opBuilder.getNamedAttr( "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++]))); } if (wordIndex != words.size()) { return emitError(unknownLoc, "found more operands than expected when deserializing " "spirv::CopyMemoryOp, only ") << wordIndex << " of " << words.size() << " processed"; } Location loc = createFileLineColLoc(opBuilder); opBuilder.create(loc, resultTypes, operands, attributes); return success(); } // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and // various Deserializer::processOp<...>() specializations. #define GET_DESERIALIZATION_FNS #include "mlir/Dialect/SPIRV/SPIRVSerialization.inc" } // namespace spirv::OwningSPIRVModuleRef spirv::deserialize(ArrayRef binary, MLIRContext *context) { Deserializer deserializer(binary, context); if (failed(deserializer.deserialize())) return nullptr; return deserializer.collect(); }