• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- Serializer.cpp - MLIR SPIR-V Serialization -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the MLIR SPIR-V module to SPIR-V binary serialization.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/Serialization.h"
14 
15 #include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
17 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
18 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
19 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/RegionGraphTraits.h"
22 #include "mlir/Support/LogicalResult.h"
23 #include "llvm/ADT/DepthFirstIterator.h"
24 #include "llvm/ADT/Sequence.h"
25 #include "llvm/ADT/SetVector.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringExtras.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 #include "llvm/ADT/bit.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/raw_ostream.h"
33 
34 #define DEBUG_TYPE "spirv-serialization"
35 
36 using namespace mlir;
37 
38 /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
39 /// the given `binary` vector.
encodeInstructionInto(SmallVectorImpl<uint32_t> & binary,spirv::Opcode op,ArrayRef<uint32_t> operands)40 static LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary,
41                                            spirv::Opcode op,
42                                            ArrayRef<uint32_t> operands) {
43   uint32_t wordCount = 1 + operands.size();
44   binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
45   binary.append(operands.begin(), operands.end());
46   return success();
47 }
48 
49 /// A pre-order depth-first visitor function for processing basic blocks.
50 ///
51 /// Visits the basic blocks starting from the given `headerBlock` in pre-order
52 /// depth-first manner and calls `blockHandler` on each block. Skips handling
53 /// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler`
54 /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s
55 /// successors.
56 ///
57 /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
58 /// of blocks in a function must satisfy the rule that blocks appear before
59 /// all blocks they dominate." This can be achieved by a pre-order CFG
60 /// traversal algorithm. To make the serialization output more logical and
61 /// readable to human, we perform depth-first CFG traversal and delay the
62 /// serialization of the merge block and the continue block, if exists, until
63 /// after all other blocks have been processed.
64 static LogicalResult
visitInPrettyBlockOrder(Block * headerBlock,function_ref<LogicalResult (Block *)> blockHandler,bool skipHeader=false,BlockRange skipBlocks={})65 visitInPrettyBlockOrder(Block *headerBlock,
66                         function_ref<LogicalResult(Block *)> blockHandler,
67                         bool skipHeader = false, BlockRange skipBlocks = {}) {
68   llvm::df_iterator_default_set<Block *, 4> doneBlocks;
69   doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
70 
71   for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) {
72     if (skipHeader && block == headerBlock)
73       continue;
74     if (failed(blockHandler(block)))
75       return failure();
76   }
77   return success();
78 }
79 
80 /// Returns the merge block if the given `op` is a structured control flow op.
81 /// Otherwise returns nullptr.
getStructuredControlFlowOpMergeBlock(Operation * op)82 static Block *getStructuredControlFlowOpMergeBlock(Operation *op) {
83   if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
84     return selectionOp.getMergeBlock();
85   if (auto loopOp = dyn_cast<spirv::LoopOp>(op))
86     return loopOp.getMergeBlock();
87   return nullptr;
88 }
89 
90 /// Given a predecessor `block` for a block with arguments, returns the block
91 /// that should be used as the parent block for SPIR-V OpPhi instructions
92 /// corresponding to the block arguments.
getPhiIncomingBlock(Block * block)93 static Block *getPhiIncomingBlock(Block *block) {
94   // If the predecessor block in question is the entry block for a spv.loop,
95   // we jump to this spv.loop from its enclosing block.
96   if (block->isEntryBlock()) {
97     if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) {
98       // Then the incoming parent block for OpPhi should be the merge block of
99       // the structured control flow op before this loop.
100       Operation *op = loopOp.getOperation();
101       while ((op = op->getPrevNode()) != nullptr)
102         if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op))
103           return incomingBlock;
104       // Or the enclosing block itself if no structured control flow ops
105       // exists before this loop.
106       return loopOp->getBlock();
107     }
108   }
109 
110   // Otherwise, we jump from the given predecessor block. Try to see if there is
111   // a structured control flow op inside it.
112   for (Operation &op : llvm::reverse(block->getOperations())) {
113     if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op))
114       return incomingBlock;
115   }
116   return block;
117 }
118 
119 namespace {
120 
121 /// A SPIR-V module serializer.
122 ///
123 /// A SPIR-V binary module is a single linear stream of instructions; each
124 /// instruction is composed of 32-bit words with the layout:
125 ///
126 ///   | <word-count>|<opcode> |  <operand>   |  <operand>   | ... |
127 ///   | <------ word -------> | <-- word --> | <-- word --> | ... |
128 ///
129 /// For the first word, the 16 high-order bits are the word count of the
130 /// instruction, the 16 low-order bits are the opcode enumerant. The
131 /// instructions then belong to different sections, which must be laid out in
132 /// the particular order as specified in "2.4 Logical Layout of a Module" of
133 /// the SPIR-V spec.
134 class Serializer {
135 public:
136   /// Creates a serializer for the given SPIR-V `module`.
137   explicit Serializer(spirv::ModuleOp module, bool emitDebugInfo = false);
138 
139   /// Serializes the remembered SPIR-V module.
140   LogicalResult serialize();
141 
142   /// Collects the final SPIR-V `binary`.
143   void collect(SmallVectorImpl<uint32_t> &binary);
144 
145 #ifndef NDEBUG
146   /// (For debugging) prints each value and its corresponding result <id>.
147   void printValueIDMap(raw_ostream &os);
148 #endif
149 
150 private:
151   // Note that there are two main categories of methods in this class:
152   // * process*() methods are meant to fully serialize a SPIR-V module entity
153   //   (header, type, op, etc.). They update internal vectors containing
154   //   different binary sections. They are not meant to be called except the
155   //   top-level serialization loop.
156   // * prepare*() methods are meant to be helpers that prepare for serializing
157   //   certain entity. They may or may not update internal vectors containing
158   //   different binary sections. They are meant to be called among themselves
159   //   or by other process*() methods for subtasks.
160 
161   //===--------------------------------------------------------------------===//
162   // <id>
163   //===--------------------------------------------------------------------===//
164 
165   // Note that it is illegal to use id <0> in SPIR-V binary module. Various
166   // methods in this class, if using SPIR-V word (uint32_t) as interface,
167   // check or return id <0> to indicate error in processing.
168 
169   /// Consumes the next unused <id>. This method will never return 0.
getNextID()170   uint32_t getNextID() { return nextID++; }
171 
172   //===--------------------------------------------------------------------===//
173   // Module structure
174   //===--------------------------------------------------------------------===//
175 
getSpecConstID(StringRef constName) const176   uint32_t getSpecConstID(StringRef constName) const {
177     return specConstIDMap.lookup(constName);
178   }
179 
getVariableID(StringRef varName) const180   uint32_t getVariableID(StringRef varName) const {
181     return globalVarIDMap.lookup(varName);
182   }
183 
getFunctionID(StringRef fnName) const184   uint32_t getFunctionID(StringRef fnName) const {
185     return funcIDMap.lookup(fnName);
186   }
187 
188   /// Gets the <id> for the function with the given name. Assigns the next
189   /// available <id> if the function haven't been deserialized.
190   uint32_t getOrCreateFunctionID(StringRef fnName);
191 
192   void processCapability();
193 
194   void processDebugInfo();
195 
196   void processExtension();
197 
198   void processMemoryModel();
199 
200   LogicalResult processConstantOp(spirv::ConstantOp op);
201 
202   LogicalResult processSpecConstantOp(spirv::SpecConstantOp op);
203 
204   LogicalResult
205   processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op);
206 
207   /// SPIR-V dialect supports OpUndef using spv.UndefOp that produces a SSA
208   /// value to use with other operations. The SPIR-V spec recommends that
209   /// OpUndef be generated at module level. The serialization generates an
210   /// OpUndef for each type needed at module level.
211   LogicalResult processUndefOp(spirv::UndefOp op);
212 
213   /// Emit OpName for the given `resultID`.
214   LogicalResult processName(uint32_t resultID, StringRef name);
215 
216   /// Processes a SPIR-V function op.
217   LogicalResult processFuncOp(spirv::FuncOp op);
218 
219   LogicalResult processVariableOp(spirv::VariableOp op);
220 
221   /// Process a SPIR-V GlobalVariableOp
222   LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp);
223 
224   /// Process attributes that translate to decorations on the result <id>
225   LogicalResult processDecoration(Location loc, uint32_t resultID,
226                                   NamedAttribute attr);
227 
228   template <typename DType>
processTypeDecoration(Location loc,DType type,uint32_t resultId)229   LogicalResult processTypeDecoration(Location loc, DType type,
230                                       uint32_t resultId) {
231     return emitError(loc, "unhandled decoration for type:") << type;
232   }
233 
234   /// Process member decoration
235   LogicalResult processMemberDecoration(
236       uint32_t structID,
237       const spirv::StructType::MemberDecorationInfo &memberDecorationInfo);
238 
239   //===--------------------------------------------------------------------===//
240   // Types
241   //===--------------------------------------------------------------------===//
242 
getTypeID(Type type) const243   uint32_t getTypeID(Type type) const { return typeIDMap.lookup(type); }
244 
getVoidType()245   Type getVoidType() { return mlirBuilder.getNoneType(); }
246 
isVoidType(Type type) const247   bool isVoidType(Type type) const { return type.isa<NoneType>(); }
248 
249   /// Returns true if the given type is a pointer type to a struct in some
250   /// interface storage class.
251   bool isInterfaceStructPtrType(Type type) const;
252 
253   /// Main dispatch method for serializing a type. The result <id> of the
254   /// serialized type will be returned as `typeID`.
255   LogicalResult processType(Location loc, Type type, uint32_t &typeID);
256   LogicalResult processTypeImpl(Location loc, Type type, uint32_t &typeID,
257                                 llvm::SetVector<StringRef> &serializationCtx);
258 
259   /// Method for preparing basic SPIR-V type serialization. Returns the type's
260   /// opcode and operands for the instruction via `typeEnum` and `operands`.
261   LogicalResult prepareBasicType(Location loc, Type type, uint32_t resultID,
262                                  spirv::Opcode &typeEnum,
263                                  SmallVectorImpl<uint32_t> &operands,
264                                  bool &deferSerialization,
265                                  llvm::SetVector<StringRef> &serializationCtx);
266 
267   LogicalResult prepareFunctionType(Location loc, FunctionType type,
268                                     spirv::Opcode &typeEnum,
269                                     SmallVectorImpl<uint32_t> &operands);
270 
271   //===--------------------------------------------------------------------===//
272   // Constant
273   //===--------------------------------------------------------------------===//
274 
getConstantID(Attribute value) const275   uint32_t getConstantID(Attribute value) const {
276     return constIDMap.lookup(value);
277   }
278 
279   /// Main dispatch method for processing a constant with the given `constType`
280   /// and `valueAttr`. `constType` is needed here because we can interpret the
281   /// `valueAttr` as a different type than the type of `valueAttr` itself; for
282   /// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType
283   /// constants.
284   uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr);
285 
286   /// Prepares array attribute serialization. This method emits corresponding
287   /// OpConstant* and returns the result <id> associated with it. Returns 0 if
288   /// failed.
289   uint32_t prepareArrayConstant(Location loc, Type constType, ArrayAttr attr);
290 
291   /// Prepares bool/int/float DenseElementsAttr serialization. This method
292   /// iterates the DenseElementsAttr to construct the constant array, and
293   /// returns the result <id>  associated with it. Returns 0 if failed. Note
294   /// that the size of `index` must match the rank.
295   /// TODO: Consider to enhance splat elements cases. For splat cases,
296   /// we don't need to loop over all elements, especially when the splat value
297   /// is zero. We can use OpConstantNull when the value is zero.
298   uint32_t prepareDenseElementsConstant(Location loc, Type constType,
299                                         DenseElementsAttr valueAttr, int dim,
300                                         MutableArrayRef<uint64_t> index);
301 
302   /// Prepares scalar attribute serialization. This method emits corresponding
303   /// OpConstant* and returns the result <id> associated with it. Returns 0 if
304   /// the attribute is not for a scalar bool/integer/float value. If `isSpec` is
305   /// true, then the constant will be serialized as a specialization constant.
306   uint32_t prepareConstantScalar(Location loc, Attribute valueAttr,
307                                  bool isSpec = false);
308 
309   uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr,
310                                bool isSpec = false);
311 
312   uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr,
313                               bool isSpec = false);
314 
315   uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr,
316                              bool isSpec = false);
317 
318   //===--------------------------------------------------------------------===//
319   // Control flow
320   //===--------------------------------------------------------------------===//
321 
322   /// Returns the result <id> for the given block.
getBlockID(Block * block) const323   uint32_t getBlockID(Block *block) const { return blockIDMap.lookup(block); }
324 
325   /// Returns the result <id> for the given block. If no <id> has been assigned,
326   /// assigns the next available <id>
327   uint32_t getOrCreateBlockID(Block *block);
328 
329   /// Processes the given `block` and emits SPIR-V instructions for all ops
330   /// inside. Does not emit OpLabel for this block if `omitLabel` is true.
331   /// `actionBeforeTerminator` is a callback that will be invoked before
332   /// handling the terminator op. It can be used to inject the Op*Merge
333   /// instruction if this is a SPIR-V selection/loop header block.
334   LogicalResult
335   processBlock(Block *block, bool omitLabel = false,
336                function_ref<void()> actionBeforeTerminator = nullptr);
337 
338   /// Emits OpPhi instructions for the given block if it has block arguments.
339   LogicalResult emitPhiForBlockArguments(Block *block);
340 
341   LogicalResult processSelectionOp(spirv::SelectionOp selectionOp);
342 
343   LogicalResult processLoopOp(spirv::LoopOp loopOp);
344 
345   LogicalResult processBranchConditionalOp(spirv::BranchConditionalOp);
346 
347   LogicalResult processBranchOp(spirv::BranchOp branchOp);
348 
349   //===--------------------------------------------------------------------===//
350   // Operations
351   //===--------------------------------------------------------------------===//
352 
353   LogicalResult encodeExtensionInstruction(Operation *op,
354                                            StringRef extensionSetName,
355                                            uint32_t opcode,
356                                            ArrayRef<uint32_t> operands);
357 
getValueID(Value val) const358   uint32_t getValueID(Value val) const { return valueIDMap.lookup(val); }
359 
360   LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp);
361 
362   LogicalResult processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp);
363 
364   /// Main dispatch method for serializing an operation.
365   LogicalResult processOperation(Operation *op);
366 
367   /// Method to dispatch to the serialization function for an operation in
368   /// SPIR-V dialect that is a mirror of an instruction in the SPIR-V spec.
369   /// This is auto-generated from ODS. Dispatch is handled for all operations
370   /// in SPIR-V dialect that have hasOpcode == 1.
371   LogicalResult dispatchToAutogenSerialization(Operation *op);
372 
373   /// Method to serialize an operation in the SPIR-V dialect that is a mirror of
374   /// an instruction in the SPIR-V spec. This is auto generated if hasOpcode ==
375   /// 1 and autogenSerialization == 1 in ODS.
376   template <typename OpTy>
processOp(OpTy op)377   LogicalResult processOp(OpTy op) {
378     return op.emitError("unsupported op serialization");
379   }
380 
381   //===--------------------------------------------------------------------===//
382   // Utilities
383   //===--------------------------------------------------------------------===//
384 
385   /// Emits an OpDecorate instruction to decorate the given `target` with the
386   /// given `decoration`.
387   LogicalResult emitDecoration(uint32_t target, spirv::Decoration decoration,
388                                ArrayRef<uint32_t> params = {});
389 
390   /// Emits an OpLine instruction with the given `loc` location information into
391   /// the given `binary` vector.
392   LogicalResult emitDebugLine(SmallVectorImpl<uint32_t> &binary, Location loc);
393 
394 private:
395   /// The SPIR-V module to be serialized.
396   spirv::ModuleOp module;
397 
398   /// An MLIR builder for getting MLIR constructs.
399   mlir::Builder mlirBuilder;
400 
401   /// A flag which indicates if the debuginfo should be emitted.
402   bool emitDebugInfo = false;
403 
404   /// A flag which indicates if the last processed instruction was a merge
405   /// instruction.
406   /// According to SPIR-V spec: "If a branch merge instruction is used, the last
407   /// OpLine in the block must be before its merge instruction".
408   bool lastProcessedWasMergeInst = false;
409 
410   /// The <id> of the OpString instruction, which specifies a file name, for
411   /// use by other debug instructions.
412   uint32_t fileID = 0;
413 
414   /// The next available result <id>.
415   uint32_t nextID = 1;
416 
417   // The following are for different SPIR-V instruction sections. They follow
418   // the logical layout of a SPIR-V module.
419 
420   SmallVector<uint32_t, 4> capabilities;
421   SmallVector<uint32_t, 0> extensions;
422   SmallVector<uint32_t, 0> extendedSets;
423   SmallVector<uint32_t, 3> memoryModel;
424   SmallVector<uint32_t, 0> entryPoints;
425   SmallVector<uint32_t, 4> executionModes;
426   SmallVector<uint32_t, 0> debug;
427   SmallVector<uint32_t, 0> names;
428   SmallVector<uint32_t, 0> decorations;
429   SmallVector<uint32_t, 0> typesGlobalValues;
430   SmallVector<uint32_t, 0> functions;
431 
432   /// Recursive struct references are serialized as OpTypePointer instructions
433   /// to the recursive struct type. However, the OpTypePointer instruction
434   /// cannot be emitted before the recursive struct's OpTypeStruct.
435   /// RecursiveStructPointerInfo stores the data needed to emit such
436   /// OpTypePointer instructions after forward references to such types.
437   struct RecursiveStructPointerInfo {
438     uint32_t pointerTypeID;
439     spirv::StorageClass storageClass;
440   };
441 
442   // Maps spirv::StructType to its recursive reference member info.
443   DenseMap<Type, SmallVector<RecursiveStructPointerInfo, 0>>
444       recursiveStructInfos;
445 
446   /// `functionHeader` contains all the instructions that must be in the first
447   /// block in the function, and `functionBody` contains the rest. After
448   /// processing FuncOp, the encoded instructions of a function are appended to
449   /// `functions`. An example of instructions in `functionHeader` in order:
450   /// OpFunction ...
451   /// OpFunctionParameter ...
452   /// OpFunctionParameter ...
453   /// OpLabel ...
454   /// OpVariable ...
455   /// OpVariable ...
456   SmallVector<uint32_t, 0> functionHeader;
457   SmallVector<uint32_t, 0> functionBody;
458 
459   /// Map from type used in SPIR-V module to their <id>s.
460   DenseMap<Type, uint32_t> typeIDMap;
461 
462   /// Map from constant values to their <id>s.
463   DenseMap<Attribute, uint32_t> constIDMap;
464 
465   /// Map from specialization constant names to their <id>s.
466   llvm::StringMap<uint32_t> specConstIDMap;
467 
468   /// Map from GlobalVariableOps name to <id>s.
469   llvm::StringMap<uint32_t> globalVarIDMap;
470 
471   /// Map from FuncOps name to <id>s.
472   llvm::StringMap<uint32_t> funcIDMap;
473 
474   /// Map from blocks to their <id>s.
475   DenseMap<Block *, uint32_t> blockIDMap;
476 
477   /// Map from the Type to the <id> that represents undef value of that type.
478   DenseMap<Type, uint32_t> undefValIDMap;
479 
480   /// Map from results of normal operations to their <id>s.
481   DenseMap<Value, uint32_t> valueIDMap;
482 
483   /// Map from extended instruction set name to <id>s.
484   llvm::StringMap<uint32_t> extendedInstSetIDMap;
485 
486   /// Map from values used in OpPhi instructions to their offset in the
487   /// `functions` section.
488   ///
489   /// When processing a block with arguments, we need to emit OpPhi
490   /// instructions to record the predecessor block <id>s and the values they
491   /// send to the block in question. But it's not guaranteed all values are
492   /// visited and thus assigned result <id>s. So we need this list to capture
493   /// the offsets into `functions` where a value is used so that we can fix it
494   /// up later after processing all the blocks in a function.
495   ///
496   /// More concretely, say if we are visiting the following blocks:
497   ///
498   /// ```mlir
499   /// ^phi(%arg0: i32):
500   ///   ...
501   /// ^parent1:
502   ///   ...
503   ///   spv.Branch ^phi(%val0: i32)
504   /// ^parent2:
505   ///   ...
506   ///   spv.Branch ^phi(%val1: i32)
507   /// ```
508   ///
509   /// When we are serializing the `^phi` block, we need to emit at the beginning
510   /// of the block OpPhi instructions which has the following parameters:
511   ///
512   /// OpPhi id-for-i32 id-for-%arg0 id-for-%val0 id-for-^parent1
513   ///                               id-for-%val1 id-for-^parent2
514   ///
515   /// But we don't know the <id> for %val0 and %val1 yet. One way is to visit
516   /// all the blocks twice and use the first visit to assign an <id> to each
517   /// value. But it's paying the overheads just for OpPhi emission. Instead,
518   /// we still visit the blocks once for emission. When we emit the OpPhi
519   /// instructions, we use 0 as a placeholder for the <id>s for %val0 and %val1.
520   /// At the same time, we record their offsets in the emitted binary (which is
521   /// placed inside `functions`) here. And then after emitting all blocks, we
522   /// replace the dummy <id> 0 with the real result <id> by overwriting
523   /// `functions[offset]`.
524   DenseMap<Value, SmallVector<size_t, 1>> deferredPhiValues;
525 };
526 } // namespace
527 
Serializer(spirv::ModuleOp module,bool emitDebugInfo)528 Serializer::Serializer(spirv::ModuleOp module, bool emitDebugInfo)
529     : module(module), mlirBuilder(module.getContext()),
530       emitDebugInfo(emitDebugInfo) {}
531 
serialize()532 LogicalResult Serializer::serialize() {
533   LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
534 
535   if (failed(module.verify()))
536     return failure();
537 
538   // TODO: handle the other sections
539   processCapability();
540   processExtension();
541   processMemoryModel();
542   processDebugInfo();
543 
544   // Iterate over the module body to serialize it. Assumptions are that there is
545   // only one basic block in the moduleOp
546   for (auto &op : module.getBlock()) {
547     if (failed(processOperation(&op))) {
548       return failure();
549     }
550   }
551 
552   LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n");
553   return success();
554 }
555 
collect(SmallVectorImpl<uint32_t> & binary)556 void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
557   auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
558                     extensions.size() + extendedSets.size() +
559                     memoryModel.size() + entryPoints.size() +
560                     executionModes.size() + decorations.size() +
561                     typesGlobalValues.size() + functions.size();
562 
563   binary.clear();
564   binary.reserve(moduleSize);
565 
566   spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID);
567   binary.append(capabilities.begin(), capabilities.end());
568   binary.append(extensions.begin(), extensions.end());
569   binary.append(extendedSets.begin(), extendedSets.end());
570   binary.append(memoryModel.begin(), memoryModel.end());
571   binary.append(entryPoints.begin(), entryPoints.end());
572   binary.append(executionModes.begin(), executionModes.end());
573   binary.append(debug.begin(), debug.end());
574   binary.append(names.begin(), names.end());
575   binary.append(decorations.begin(), decorations.end());
576   binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
577   binary.append(functions.begin(), functions.end());
578 }
579 
580 #ifndef NDEBUG
printValueIDMap(raw_ostream & os)581 void Serializer::printValueIDMap(raw_ostream &os) {
582   os << "\n= Value <id> Map =\n\n";
583   for (auto valueIDPair : valueIDMap) {
584     Value val = valueIDPair.first;
585     os << "  " << val << " "
586        << "id = " << valueIDPair.second << ' ';
587     if (auto *op = val.getDefiningOp()) {
588       os << "from op '" << op->getName() << "'";
589     } else if (auto arg = val.dyn_cast<BlockArgument>()) {
590       Block *block = arg.getOwner();
591       os << "from argument of block " << block << ' ';
592       os << " in op '" << block->getParentOp()->getName() << "'";
593     }
594     os << '\n';
595   }
596 }
597 #endif
598 
599 //===----------------------------------------------------------------------===//
600 // Module structure
601 //===----------------------------------------------------------------------===//
602 
getOrCreateFunctionID(StringRef fnName)603 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
604   auto funcID = funcIDMap.lookup(fnName);
605   if (!funcID) {
606     funcID = getNextID();
607     funcIDMap[fnName] = funcID;
608   }
609   return funcID;
610 }
611 
processCapability()612 void Serializer::processCapability() {
613   for (auto cap : module.vce_triple()->getCapabilities())
614     encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
615                           {static_cast<uint32_t>(cap)});
616 }
617 
processDebugInfo()618 void Serializer::processDebugInfo() {
619   if (!emitDebugInfo)
620     return;
621   auto fileLoc = module.getLoc().dyn_cast<FileLineColLoc>();
622   auto fileName = fileLoc ? fileLoc.getFilename() : "<unknown>";
623   fileID = getNextID();
624   SmallVector<uint32_t, 16> operands;
625   operands.push_back(fileID);
626   spirv::encodeStringLiteralInto(operands, fileName);
627   encodeInstructionInto(debug, spirv::Opcode::OpString, operands);
628   // TODO: Encode more debug instructions.
629 }
630 
processExtension()631 void Serializer::processExtension() {
632   llvm::SmallVector<uint32_t, 16> extName;
633   for (spirv::Extension ext : module.vce_triple()->getExtensions()) {
634     extName.clear();
635     spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
636     encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
637   }
638 }
639 
processMemoryModel()640 void Serializer::processMemoryModel() {
641   uint32_t mm = module->getAttrOfType<IntegerAttr>("memory_model").getInt();
642   uint32_t am = module->getAttrOfType<IntegerAttr>("addressing_model").getInt();
643 
644   encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
645 }
646 
processConstantOp(spirv::ConstantOp op)647 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
648   if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) {
649     valueIDMap[op.getResult()] = resultID;
650     return success();
651   }
652   return failure();
653 }
654 
processSpecConstantOp(spirv::SpecConstantOp op)655 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
656   if (auto resultID = prepareConstantScalar(op.getLoc(), op.default_value(),
657                                             /*isSpec=*/true)) {
658     // Emit the OpDecorate instruction for SpecId.
659     if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) {
660       auto val = static_cast<uint32_t>(specID.getInt());
661       emitDecoration(resultID, spirv::Decoration::SpecId, {val});
662     }
663 
664     specConstIDMap[op.sym_name()] = resultID;
665     return processName(resultID, op.sym_name());
666   }
667   return failure();
668 }
669 
670 LogicalResult
processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op)671 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
672   uint32_t typeID = 0;
673   if (failed(processType(op.getLoc(), op.type(), typeID))) {
674     return failure();
675   }
676 
677   auto resultID = getNextID();
678 
679   SmallVector<uint32_t, 8> operands;
680   operands.push_back(typeID);
681   operands.push_back(resultID);
682 
683   auto constituents = op.constituents();
684 
685   for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
686     auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>();
687 
688     auto constituentName = constituent.getValue();
689     auto constituentID = getSpecConstID(constituentName);
690 
691     if (!constituentID) {
692       return op.emitError("unknown result <id> for specialization constant ")
693              << constituentName;
694     }
695 
696     operands.push_back(constituentID);
697   }
698 
699   encodeInstructionInto(typesGlobalValues,
700                         spirv::Opcode::OpSpecConstantComposite, operands);
701   specConstIDMap[op.sym_name()] = resultID;
702 
703   return processName(resultID, op.sym_name());
704 }
705 
processUndefOp(spirv::UndefOp op)706 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
707   auto undefType = op.getType();
708   auto &id = undefValIDMap[undefType];
709   if (!id) {
710     id = getNextID();
711     uint32_t typeID = 0;
712     if (failed(processType(op.getLoc(), undefType, typeID)) ||
713         failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef,
714                                      {typeID, id}))) {
715       return failure();
716     }
717   }
718   valueIDMap[op.getResult()] = id;
719   return success();
720 }
721 
processDecoration(Location loc,uint32_t resultID,NamedAttribute attr)722 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
723                                             NamedAttribute attr) {
724   auto attrName = attr.first.strref();
725   auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true);
726   auto decoration = spirv::symbolizeDecoration(decorationName);
727   if (!decoration) {
728     return emitError(
729                loc, "non-argument attributes expected to have snake-case-ified "
730                     "decoration name, unhandled attribute with name : ")
731            << attrName;
732   }
733   SmallVector<uint32_t, 1> args;
734   switch (decoration.getValue()) {
735   case spirv::Decoration::Binding:
736   case spirv::Decoration::DescriptorSet:
737   case spirv::Decoration::Location:
738     if (auto intAttr = attr.second.dyn_cast<IntegerAttr>()) {
739       args.push_back(intAttr.getValue().getZExtValue());
740       break;
741     }
742     return emitError(loc, "expected integer attribute for ") << attrName;
743   case spirv::Decoration::BuiltIn:
744     if (auto strAttr = attr.second.dyn_cast<StringAttr>()) {
745       auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
746       if (enumVal) {
747         args.push_back(static_cast<uint32_t>(enumVal.getValue()));
748         break;
749       }
750       return emitError(loc, "invalid ")
751              << attrName << " attribute " << strAttr.getValue();
752     }
753     return emitError(loc, "expected string attribute for ") << attrName;
754   case spirv::Decoration::Aliased:
755   case spirv::Decoration::Flat:
756   case spirv::Decoration::NonReadable:
757   case spirv::Decoration::NonWritable:
758   case spirv::Decoration::NoPerspective:
759   case spirv::Decoration::Restrict:
760     // For unit attributes, the args list has no values so we do nothing
761     if (auto unitAttr = attr.second.dyn_cast<UnitAttr>())
762       break;
763     return emitError(loc, "expected unit attribute for ") << attrName;
764   default:
765     return emitError(loc, "unhandled decoration ") << decorationName;
766   }
767   return emitDecoration(resultID, decoration.getValue(), args);
768 }
769 
processName(uint32_t resultID,StringRef name)770 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
771   assert(!name.empty() && "unexpected empty string for OpName");
772 
773   SmallVector<uint32_t, 4> nameOperands;
774   nameOperands.push_back(resultID);
775   if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) {
776     return failure();
777   }
778   return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
779 }
780 
781 namespace {
782 template <>
processTypeDecoration(Location loc,spirv::ArrayType type,uint32_t resultID)783 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
784     Location loc, spirv::ArrayType type, uint32_t resultID) {
785   if (unsigned stride = type.getArrayStride()) {
786     // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
787     return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
788   }
789   return success();
790 }
791 
792 template <>
processTypeDecoration(Location Loc,spirv::RuntimeArrayType type,uint32_t resultID)793 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
794     Location Loc, spirv::RuntimeArrayType type, uint32_t resultID) {
795   if (unsigned stride = type.getArrayStride()) {
796     // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
797     return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
798   }
799   return success();
800 }
801 
processMemberDecoration(uint32_t structID,const spirv::StructType::MemberDecorationInfo & memberDecoration)802 LogicalResult Serializer::processMemberDecoration(
803     uint32_t structID,
804     const spirv::StructType::MemberDecorationInfo &memberDecoration) {
805   SmallVector<uint32_t, 4> args(
806       {structID, memberDecoration.memberIndex,
807        static_cast<uint32_t>(memberDecoration.decoration)});
808   if (memberDecoration.hasValue) {
809     args.push_back(memberDecoration.decorationValue);
810   }
811   return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate,
812                                args);
813 }
814 } // namespace
815 
processFuncOp(spirv::FuncOp op)816 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
817   LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
818   assert(functionHeader.empty() && functionBody.empty());
819 
820   uint32_t fnTypeID = 0;
821   // Generate type of the function.
822   processType(op.getLoc(), op.getType(), fnTypeID);
823 
824   // Add the function definition.
825   SmallVector<uint32_t, 4> operands;
826   uint32_t resTypeID = 0;
827   auto resultTypes = op.getType().getResults();
828   if (resultTypes.size() > 1) {
829     return op.emitError("cannot serialize function with multiple return types");
830   }
831   if (failed(processType(op.getLoc(),
832                          (resultTypes.empty() ? getVoidType() : resultTypes[0]),
833                          resTypeID))) {
834     return failure();
835   }
836   operands.push_back(resTypeID);
837   auto funcID = getOrCreateFunctionID(op.getName());
838   operands.push_back(funcID);
839   operands.push_back(static_cast<uint32_t>(op.function_control()));
840   operands.push_back(fnTypeID);
841   encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands);
842 
843   // Add function name.
844   if (failed(processName(funcID, op.getName()))) {
845     return failure();
846   }
847 
848   // Declare the parameters.
849   for (auto arg : op.getArguments()) {
850     uint32_t argTypeID = 0;
851     if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
852       return failure();
853     }
854     auto argValueID = getNextID();
855     valueIDMap[arg] = argValueID;
856     encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
857                           {argTypeID, argValueID});
858   }
859 
860   // Process the body.
861   if (op.isExternal()) {
862     return op.emitError("external function is unhandled");
863   }
864 
865   // Some instructions (e.g., OpVariable) in a function must be in the first
866   // block in the function. These instructions will be put in functionHeader.
867   // Thus, we put the label in functionHeader first, and omit it from the first
868   // block.
869   encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel,
870                         {getOrCreateBlockID(&op.front())});
871   processBlock(&op.front(), /*omitLabel=*/true);
872   if (failed(visitInPrettyBlockOrder(
873           &op.front(), [&](Block *block) { return processBlock(block); },
874           /*skipHeader=*/true))) {
875     return failure();
876   }
877 
878   // There might be OpPhi instructions who have value references needing to fix.
879   for (auto deferredValue : deferredPhiValues) {
880     Value value = deferredValue.first;
881     uint32_t id = getValueID(value);
882     LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value
883                             << " to id = " << id << '\n');
884     assert(id && "OpPhi references undefined value!");
885     for (size_t offset : deferredValue.second)
886       functionBody[offset] = id;
887   }
888   deferredPhiValues.clear();
889 
890   LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName()
891                           << "' --\n");
892   // Insert OpFunctionEnd.
893   if (failed(encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd,
894                                    {}))) {
895     return failure();
896   }
897 
898   functions.append(functionHeader.begin(), functionHeader.end());
899   functions.append(functionBody.begin(), functionBody.end());
900   functionHeader.clear();
901   functionBody.clear();
902 
903   return success();
904 }
905 
processVariableOp(spirv::VariableOp op)906 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
907   SmallVector<uint32_t, 4> operands;
908   SmallVector<StringRef, 2> elidedAttrs;
909   uint32_t resultID = 0;
910   uint32_t resultTypeID = 0;
911   if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) {
912     return failure();
913   }
914   operands.push_back(resultTypeID);
915   resultID = getNextID();
916   valueIDMap[op.getResult()] = resultID;
917   operands.push_back(resultID);
918   auto attr = op.getAttr(spirv::attributeName<spirv::StorageClass>());
919   if (attr) {
920     operands.push_back(static_cast<uint32_t>(
921         attr.cast<IntegerAttr>().getValue().getZExtValue()));
922   }
923   elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
924   for (auto arg : op.getODSOperands(0)) {
925     auto argID = getValueID(arg);
926     if (!argID) {
927       return emitError(op.getLoc(), "operand 0 has a use before def");
928     }
929     operands.push_back(argID);
930   }
931   emitDebugLine(functionHeader, op.getLoc());
932   encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands);
933   for (auto attr : op.getAttrs()) {
934     if (llvm::any_of(elidedAttrs,
935                      [&](StringRef elided) { return attr.first == elided; })) {
936       continue;
937     }
938     if (failed(processDecoration(op.getLoc(), resultID, attr))) {
939       return failure();
940     }
941   }
942   return success();
943 }
944 
945 LogicalResult
processGlobalVariableOp(spirv::GlobalVariableOp varOp)946 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
947   // Get TypeID.
948   uint32_t resultTypeID = 0;
949   SmallVector<StringRef, 4> elidedAttrs;
950   if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) {
951     return failure();
952   }
953 
954   if (isInterfaceStructPtrType(varOp.type())) {
955     auto structType = varOp.type()
956                           .cast<spirv::PointerType>()
957                           .getPointeeType()
958                           .cast<spirv::StructType>();
959     if (failed(
960             emitDecoration(getTypeID(structType), spirv::Decoration::Block))) {
961       return varOp.emitError("cannot decorate ")
962              << structType << " with Block decoration";
963     }
964   }
965 
966   elidedAttrs.push_back("type");
967   SmallVector<uint32_t, 4> operands;
968   operands.push_back(resultTypeID);
969   auto resultID = getNextID();
970 
971   // Encode the name.
972   auto varName = varOp.sym_name();
973   elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
974   if (failed(processName(resultID, varName))) {
975     return failure();
976   }
977   globalVarIDMap[varName] = resultID;
978   operands.push_back(resultID);
979 
980   // Encode StorageClass.
981   operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
982 
983   // Encode initialization.
984   if (auto initializer = varOp.initializer()) {
985     auto initializerID = getVariableID(initializer.getValue());
986     if (!initializerID) {
987       return emitError(varOp.getLoc(),
988                        "invalid usage of undefined variable as initializer");
989     }
990     operands.push_back(initializerID);
991     elidedAttrs.push_back("initializer");
992   }
993 
994   emitDebugLine(typesGlobalValues, varOp.getLoc());
995   if (failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable,
996                                    operands))) {
997     elidedAttrs.push_back("initializer");
998     return failure();
999   }
1000 
1001   // Encode decorations.
1002   for (auto attr : varOp.getAttrs()) {
1003     if (llvm::any_of(elidedAttrs,
1004                      [&](StringRef elided) { return attr.first == elided; })) {
1005       continue;
1006     }
1007     if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
1008       return failure();
1009     }
1010   }
1011   return success();
1012 }
1013 
1014 //===----------------------------------------------------------------------===//
1015 // Type
1016 //===----------------------------------------------------------------------===//
1017 
1018 // According to the SPIR-V spec "Validation Rules for Shader Capabilities":
1019 // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
1020 // PushConstant Storage Classes must be explicitly laid out."
isInterfaceStructPtrType(Type type) const1021 bool Serializer::isInterfaceStructPtrType(Type type) const {
1022   if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
1023     switch (ptrType.getStorageClass()) {
1024     case spirv::StorageClass::PhysicalStorageBuffer:
1025     case spirv::StorageClass::PushConstant:
1026     case spirv::StorageClass::StorageBuffer:
1027     case spirv::StorageClass::Uniform:
1028       return ptrType.getPointeeType().isa<spirv::StructType>();
1029     default:
1030       break;
1031     }
1032   }
1033   return false;
1034 }
1035 
processType(Location loc,Type type,uint32_t & typeID)1036 LogicalResult Serializer::processType(Location loc, Type type,
1037                                       uint32_t &typeID) {
1038   // Maintains a set of names for nested identified struct types. This is used
1039   // to properly serialize resursive references.
1040   llvm::SetVector<StringRef> serializationCtx;
1041   return processTypeImpl(loc, type, typeID, serializationCtx);
1042 }
1043 
1044 LogicalResult
processTypeImpl(Location loc,Type type,uint32_t & typeID,llvm::SetVector<StringRef> & serializationCtx)1045 Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
1046                             llvm::SetVector<StringRef> &serializationCtx) {
1047   typeID = getTypeID(type);
1048   if (typeID) {
1049     return success();
1050   }
1051   typeID = getNextID();
1052   SmallVector<uint32_t, 4> operands;
1053 
1054   operands.push_back(typeID);
1055   auto typeEnum = spirv::Opcode::OpTypeVoid;
1056   bool deferSerialization = false;
1057 
1058   if ((type.isa<FunctionType>() &&
1059        succeeded(prepareFunctionType(loc, type.cast<FunctionType>(), typeEnum,
1060                                      operands))) ||
1061       succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
1062                                  deferSerialization, serializationCtx))) {
1063     if (deferSerialization)
1064       return success();
1065 
1066     typeIDMap[type] = typeID;
1067 
1068     if (failed(encodeInstructionInto(typesGlobalValues, typeEnum, operands)))
1069       return failure();
1070 
1071     if (recursiveStructInfos.count(type) != 0) {
1072       // This recursive struct type is emitted already, now the OpTypePointer
1073       // instructions referring to recursive references are emitted as well.
1074       for (auto &ptrInfo : recursiveStructInfos[type]) {
1075         // TODO: This might not work if more than 1 recursive reference is
1076         // present in the struct.
1077         SmallVector<uint32_t, 4> ptrOperands;
1078         ptrOperands.push_back(ptrInfo.pointerTypeID);
1079         ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass));
1080         ptrOperands.push_back(typeIDMap[type]);
1081 
1082         if (failed(encodeInstructionInto(
1083                 typesGlobalValues, spirv::Opcode::OpTypePointer, ptrOperands)))
1084           return failure();
1085       }
1086 
1087       recursiveStructInfos[type].clear();
1088     }
1089 
1090     return success();
1091   }
1092 
1093   return failure();
1094 }
1095 
prepareBasicType(Location loc,Type type,uint32_t resultID,spirv::Opcode & typeEnum,SmallVectorImpl<uint32_t> & operands,bool & deferSerialization,llvm::SetVector<StringRef> & serializationCtx)1096 LogicalResult Serializer::prepareBasicType(
1097     Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
1098     SmallVectorImpl<uint32_t> &operands, bool &deferSerialization,
1099     llvm::SetVector<StringRef> &serializationCtx) {
1100   deferSerialization = false;
1101 
1102   if (isVoidType(type)) {
1103     typeEnum = spirv::Opcode::OpTypeVoid;
1104     return success();
1105   }
1106 
1107   if (auto intType = type.dyn_cast<IntegerType>()) {
1108     if (intType.getWidth() == 1) {
1109       typeEnum = spirv::Opcode::OpTypeBool;
1110       return success();
1111     }
1112 
1113     typeEnum = spirv::Opcode::OpTypeInt;
1114     operands.push_back(intType.getWidth());
1115     // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
1116     // to preserve or validate.
1117     // 0 indicates unsigned, or no signedness semantics
1118     // 1 indicates signed semantics."
1119     operands.push_back(intType.isSigned() ? 1 : 0);
1120     return success();
1121   }
1122 
1123   if (auto floatType = type.dyn_cast<FloatType>()) {
1124     typeEnum = spirv::Opcode::OpTypeFloat;
1125     operands.push_back(floatType.getWidth());
1126     return success();
1127   }
1128 
1129   if (auto vectorType = type.dyn_cast<VectorType>()) {
1130     uint32_t elementTypeID = 0;
1131     if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
1132                                serializationCtx))) {
1133       return failure();
1134     }
1135     typeEnum = spirv::Opcode::OpTypeVector;
1136     operands.push_back(elementTypeID);
1137     operands.push_back(vectorType.getNumElements());
1138     return success();
1139   }
1140 
1141   if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) {
1142     typeEnum = spirv::Opcode::OpTypeArray;
1143     uint32_t elementTypeID = 0;
1144     if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
1145                                serializationCtx))) {
1146       return failure();
1147     }
1148     operands.push_back(elementTypeID);
1149     if (auto elementCountID = prepareConstantInt(
1150             loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
1151       operands.push_back(elementCountID);
1152     }
1153     return processTypeDecoration(loc, arrayType, resultID);
1154   }
1155 
1156   if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
1157     uint32_t pointeeTypeID = 0;
1158     spirv::StructType pointeeStruct =
1159         ptrType.getPointeeType().dyn_cast<spirv::StructType>();
1160 
1161     if (pointeeStruct && pointeeStruct.isIdentified() &&
1162         serializationCtx.count(pointeeStruct.getIdentifier()) != 0) {
1163       // A recursive reference to an enclosing struct is found.
1164       //
1165       // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage
1166       // class as operands.
1167       SmallVector<uint32_t, 2> forwardPtrOperands;
1168       forwardPtrOperands.push_back(resultID);
1169       forwardPtrOperands.push_back(
1170           static_cast<uint32_t>(ptrType.getStorageClass()));
1171 
1172       encodeInstructionInto(typesGlobalValues,
1173                             spirv::Opcode::OpTypeForwardPointer,
1174                             forwardPtrOperands);
1175 
1176       // 2. Find the pointee (enclosing) struct.
1177       auto structType = spirv::StructType::getIdentified(
1178           module.getContext(), pointeeStruct.getIdentifier());
1179 
1180       if (!structType)
1181         return failure();
1182 
1183       // 3. Mark the OpTypePointer that is supposed to be emitted by this call
1184       // as deferred.
1185       deferSerialization = true;
1186 
1187       // 4. Record the info needed to emit the deferred OpTypePointer
1188       // instruction when the enclosing struct is completely serialized.
1189       recursiveStructInfos[structType].push_back(
1190           {resultID, ptrType.getStorageClass()});
1191     } else {
1192       if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
1193                                  serializationCtx)))
1194         return failure();
1195     }
1196 
1197     typeEnum = spirv::Opcode::OpTypePointer;
1198     operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
1199     operands.push_back(pointeeTypeID);
1200     return success();
1201   }
1202 
1203   if (auto runtimeArrayType = type.dyn_cast<spirv::RuntimeArrayType>()) {
1204     uint32_t elementTypeID = 0;
1205     if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
1206                                elementTypeID, serializationCtx))) {
1207       return failure();
1208     }
1209     typeEnum = spirv::Opcode::OpTypeRuntimeArray;
1210     operands.push_back(elementTypeID);
1211     return processTypeDecoration(loc, runtimeArrayType, resultID);
1212   }
1213 
1214   if (auto structType = type.dyn_cast<spirv::StructType>()) {
1215     if (structType.isIdentified()) {
1216       processName(resultID, structType.getIdentifier());
1217       serializationCtx.insert(structType.getIdentifier());
1218     }
1219 
1220     bool hasOffset = structType.hasOffset();
1221     for (auto elementIndex :
1222          llvm::seq<uint32_t>(0, structType.getNumElements())) {
1223       uint32_t elementTypeID = 0;
1224       if (failed(processTypeImpl(loc, structType.getElementType(elementIndex),
1225                                  elementTypeID, serializationCtx))) {
1226         return failure();
1227       }
1228       operands.push_back(elementTypeID);
1229       if (hasOffset) {
1230         // Decorate each struct member with an offset
1231         spirv::StructType::MemberDecorationInfo offsetDecoration{
1232             elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
1233             static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
1234         if (failed(processMemberDecoration(resultID, offsetDecoration))) {
1235           return emitError(loc, "cannot decorate ")
1236                  << elementIndex << "-th member of " << structType
1237                  << " with its offset";
1238         }
1239       }
1240     }
1241     SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
1242     structType.getMemberDecorations(memberDecorations);
1243 
1244     for (auto &memberDecoration : memberDecorations) {
1245       if (failed(processMemberDecoration(resultID, memberDecoration))) {
1246         return emitError(loc, "cannot decorate ")
1247                << static_cast<uint32_t>(memberDecoration.memberIndex)
1248                << "-th member of " << structType << " with "
1249                << stringifyDecoration(memberDecoration.decoration);
1250       }
1251     }
1252 
1253     typeEnum = spirv::Opcode::OpTypeStruct;
1254 
1255     if (structType.isIdentified())
1256       serializationCtx.remove(structType.getIdentifier());
1257 
1258     return success();
1259   }
1260 
1261   if (auto cooperativeMatrixType =
1262           type.dyn_cast<spirv::CooperativeMatrixNVType>()) {
1263     uint32_t elementTypeID = 0;
1264     if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
1265                                elementTypeID, serializationCtx))) {
1266       return failure();
1267     }
1268     typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
1269     auto getConstantOp = [&](uint32_t id) {
1270       auto attr = IntegerAttr::get(IntegerType::get(32, type.getContext()), id);
1271       return prepareConstantInt(loc, attr);
1272     };
1273     operands.push_back(elementTypeID);
1274     operands.push_back(
1275         getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())));
1276     operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
1277     operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
1278     return success();
1279   }
1280 
1281   if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) {
1282     uint32_t elementTypeID = 0;
1283     if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
1284                                serializationCtx))) {
1285       return failure();
1286     }
1287     typeEnum = spirv::Opcode::OpTypeMatrix;
1288     operands.push_back(elementTypeID);
1289     operands.push_back(matrixType.getNumColumns());
1290     return success();
1291   }
1292 
1293   // TODO: Handle other types.
1294   return emitError(loc, "unhandled type in serialization: ") << type;
1295 }
1296 
1297 LogicalResult
prepareFunctionType(Location loc,FunctionType type,spirv::Opcode & typeEnum,SmallVectorImpl<uint32_t> & operands)1298 Serializer::prepareFunctionType(Location loc, FunctionType type,
1299                                 spirv::Opcode &typeEnum,
1300                                 SmallVectorImpl<uint32_t> &operands) {
1301   typeEnum = spirv::Opcode::OpTypeFunction;
1302   assert(type.getNumResults() <= 1 &&
1303          "serialization supports only a single return value");
1304   uint32_t resultID = 0;
1305   if (failed(processType(
1306           loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
1307           resultID))) {
1308     return failure();
1309   }
1310   operands.push_back(resultID);
1311   for (auto &res : type.getInputs()) {
1312     uint32_t argTypeID = 0;
1313     if (failed(processType(loc, res, argTypeID))) {
1314       return failure();
1315     }
1316     operands.push_back(argTypeID);
1317   }
1318   return success();
1319 }
1320 
1321 //===----------------------------------------------------------------------===//
1322 // Constant
1323 //===----------------------------------------------------------------------===//
1324 
prepareConstant(Location loc,Type constType,Attribute valueAttr)1325 uint32_t Serializer::prepareConstant(Location loc, Type constType,
1326                                      Attribute valueAttr) {
1327   if (auto id = prepareConstantScalar(loc, valueAttr)) {
1328     return id;
1329   }
1330 
1331   // This is a composite literal. We need to handle each component separately
1332   // and then emit an OpConstantComposite for the whole.
1333 
1334   if (auto id = getConstantID(valueAttr)) {
1335     return id;
1336   }
1337 
1338   uint32_t typeID = 0;
1339   if (failed(processType(loc, constType, typeID))) {
1340     return 0;
1341   }
1342 
1343   uint32_t resultID = 0;
1344   if (auto attr = valueAttr.dyn_cast<DenseElementsAttr>()) {
1345     int rank = attr.getType().dyn_cast<ShapedType>().getRank();
1346     SmallVector<uint64_t, 4> index(rank);
1347     resultID = prepareDenseElementsConstant(loc, constType, attr,
1348                                             /*dim=*/0, index);
1349   } else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) {
1350     resultID = prepareArrayConstant(loc, constType, arrayAttr);
1351   }
1352 
1353   if (resultID == 0) {
1354     emitError(loc, "cannot serialize attribute: ") << valueAttr;
1355     return 0;
1356   }
1357 
1358   constIDMap[valueAttr] = resultID;
1359   return resultID;
1360 }
1361 
prepareArrayConstant(Location loc,Type constType,ArrayAttr attr)1362 uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
1363                                           ArrayAttr attr) {
1364   uint32_t typeID = 0;
1365   if (failed(processType(loc, constType, typeID))) {
1366     return 0;
1367   }
1368 
1369   uint32_t resultID = getNextID();
1370   SmallVector<uint32_t, 4> operands = {typeID, resultID};
1371   operands.reserve(attr.size() + 2);
1372   auto elementType = constType.cast<spirv::ArrayType>().getElementType();
1373   for (Attribute elementAttr : attr) {
1374     if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
1375       operands.push_back(elementID);
1376     } else {
1377       return 0;
1378     }
1379   }
1380   spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
1381   encodeInstructionInto(typesGlobalValues, opcode, operands);
1382 
1383   return resultID;
1384 }
1385 
1386 // TODO: Turn the below function into iterative function, instead of
1387 // recursive function.
1388 uint32_t
prepareDenseElementsConstant(Location loc,Type constType,DenseElementsAttr valueAttr,int dim,MutableArrayRef<uint64_t> index)1389 Serializer::prepareDenseElementsConstant(Location loc, Type constType,
1390                                          DenseElementsAttr valueAttr, int dim,
1391                                          MutableArrayRef<uint64_t> index) {
1392   auto shapedType = valueAttr.getType().dyn_cast<ShapedType>();
1393   assert(dim <= shapedType.getRank());
1394   if (shapedType.getRank() == dim) {
1395     if (auto attr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
1396       return attr.getType().getElementType().isInteger(1)
1397                  ? prepareConstantBool(loc, attr.getValue<BoolAttr>(index))
1398                  : prepareConstantInt(loc, attr.getValue<IntegerAttr>(index));
1399     }
1400     if (auto attr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
1401       return prepareConstantFp(loc, attr.getValue<FloatAttr>(index));
1402     }
1403     return 0;
1404   }
1405 
1406   uint32_t typeID = 0;
1407   if (failed(processType(loc, constType, typeID))) {
1408     return 0;
1409   }
1410 
1411   uint32_t resultID = getNextID();
1412   SmallVector<uint32_t, 4> operands = {typeID, resultID};
1413   operands.reserve(shapedType.getDimSize(dim) + 2);
1414   auto elementType = constType.cast<spirv::CompositeType>().getElementType(0);
1415   for (int i = 0; i < shapedType.getDimSize(dim); ++i) {
1416     index[dim] = i;
1417     if (auto elementID = prepareDenseElementsConstant(
1418             loc, elementType, valueAttr, dim + 1, index)) {
1419       operands.push_back(elementID);
1420     } else {
1421       return 0;
1422     }
1423   }
1424   spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
1425   encodeInstructionInto(typesGlobalValues, opcode, operands);
1426 
1427   return resultID;
1428 }
1429 
prepareConstantScalar(Location loc,Attribute valueAttr,bool isSpec)1430 uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
1431                                            bool isSpec) {
1432   if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) {
1433     return prepareConstantFp(loc, floatAttr, isSpec);
1434   }
1435   if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) {
1436     return prepareConstantBool(loc, boolAttr, isSpec);
1437   }
1438   if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
1439     return prepareConstantInt(loc, intAttr, isSpec);
1440   }
1441 
1442   return 0;
1443 }
1444 
prepareConstantBool(Location loc,BoolAttr boolAttr,bool isSpec)1445 uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
1446                                          bool isSpec) {
1447   if (!isSpec) {
1448     // We can de-duplicate normal constants, but not specialization constants.
1449     if (auto id = getConstantID(boolAttr)) {
1450       return id;
1451     }
1452   }
1453 
1454   // Process the type for this bool literal
1455   uint32_t typeID = 0;
1456   if (failed(processType(loc, boolAttr.getType(), typeID))) {
1457     return 0;
1458   }
1459 
1460   auto resultID = getNextID();
1461   auto opcode = boolAttr.getValue()
1462                     ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
1463                               : spirv::Opcode::OpConstantTrue)
1464                     : (isSpec ? spirv::Opcode::OpSpecConstantFalse
1465                               : spirv::Opcode::OpConstantFalse);
1466   encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
1467 
1468   if (!isSpec) {
1469     constIDMap[boolAttr] = resultID;
1470   }
1471   return resultID;
1472 }
1473 
prepareConstantInt(Location loc,IntegerAttr intAttr,bool isSpec)1474 uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
1475                                         bool isSpec) {
1476   if (!isSpec) {
1477     // We can de-duplicate normal constants, but not specialization constants.
1478     if (auto id = getConstantID(intAttr)) {
1479       return id;
1480     }
1481   }
1482 
1483   // Process the type for this integer literal
1484   uint32_t typeID = 0;
1485   if (failed(processType(loc, intAttr.getType(), typeID))) {
1486     return 0;
1487   }
1488 
1489   auto resultID = getNextID();
1490   APInt value = intAttr.getValue();
1491   unsigned bitwidth = value.getBitWidth();
1492   bool isSigned = value.isSignedIntN(bitwidth);
1493 
1494   auto opcode =
1495       isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1496 
1497   // According to SPIR-V spec, "When the type's bit width is less than 32-bits,
1498   // the literal's value appears in the low-order bits of the word, and the
1499   // high-order bits must be 0 for a floating-point type, or 0 for an integer
1500   // type with Signedness of 0, or sign extended when Signedness is 1."
1501   if (bitwidth == 32 || bitwidth == 16) {
1502     uint32_t word = 0;
1503     if (isSigned) {
1504       word = static_cast<int32_t>(value.getSExtValue());
1505     } else {
1506       word = static_cast<uint32_t>(value.getZExtValue());
1507     }
1508     encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1509   }
1510   // According to SPIR-V spec: "When the type's bit width is larger than one
1511   // word, the literal’s low-order words appear first."
1512   else if (bitwidth == 64) {
1513     struct DoubleWord {
1514       uint32_t word1;
1515       uint32_t word2;
1516     } words;
1517     if (isSigned) {
1518       words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
1519     } else {
1520       words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
1521     }
1522     encodeInstructionInto(typesGlobalValues, opcode,
1523                           {typeID, resultID, words.word1, words.word2});
1524   } else {
1525     std::string valueStr;
1526     llvm::raw_string_ostream rss(valueStr);
1527     value.print(rss, /*isSigned=*/false);
1528 
1529     emitError(loc, "cannot serialize ")
1530         << bitwidth << "-bit integer literal: " << rss.str();
1531     return 0;
1532   }
1533 
1534   if (!isSpec) {
1535     constIDMap[intAttr] = resultID;
1536   }
1537   return resultID;
1538 }
1539 
prepareConstantFp(Location loc,FloatAttr floatAttr,bool isSpec)1540 uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
1541                                        bool isSpec) {
1542   if (!isSpec) {
1543     // We can de-duplicate normal constants, but not specialization constants.
1544     if (auto id = getConstantID(floatAttr)) {
1545       return id;
1546     }
1547   }
1548 
1549   // Process the type for this float literal
1550   uint32_t typeID = 0;
1551   if (failed(processType(loc, floatAttr.getType(), typeID))) {
1552     return 0;
1553   }
1554 
1555   auto resultID = getNextID();
1556   APFloat value = floatAttr.getValue();
1557   APInt intValue = value.bitcastToAPInt();
1558 
1559   auto opcode =
1560       isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1561 
1562   if (&value.getSemantics() == &APFloat::IEEEsingle()) {
1563     uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
1564     encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1565   } else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
1566     struct DoubleWord {
1567       uint32_t word1;
1568       uint32_t word2;
1569     } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
1570     encodeInstructionInto(typesGlobalValues, opcode,
1571                           {typeID, resultID, words.word1, words.word2});
1572   } else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
1573     uint32_t word =
1574         static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
1575     encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1576   } else {
1577     std::string valueStr;
1578     llvm::raw_string_ostream rss(valueStr);
1579     value.print(rss);
1580 
1581     emitError(loc, "cannot serialize ")
1582         << floatAttr.getType() << "-typed float literal: " << rss.str();
1583     return 0;
1584   }
1585 
1586   if (!isSpec) {
1587     constIDMap[floatAttr] = resultID;
1588   }
1589   return resultID;
1590 }
1591 
1592 //===----------------------------------------------------------------------===//
1593 // Control flow
1594 //===----------------------------------------------------------------------===//
1595 
getOrCreateBlockID(Block * block)1596 uint32_t Serializer::getOrCreateBlockID(Block *block) {
1597   if (uint32_t id = getBlockID(block))
1598     return id;
1599   return blockIDMap[block] = getNextID();
1600 }
1601 
1602 LogicalResult
processBlock(Block * block,bool omitLabel,function_ref<void ()> actionBeforeTerminator)1603 Serializer::processBlock(Block *block, bool omitLabel,
1604                          function_ref<void()> actionBeforeTerminator) {
1605   LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
1606   LLVM_DEBUG(block->print(llvm::dbgs()));
1607   LLVM_DEBUG(llvm::dbgs() << '\n');
1608   if (!omitLabel) {
1609     uint32_t blockID = getOrCreateBlockID(block);
1610     LLVM_DEBUG(llvm::dbgs()
1611                << "[block] " << block << " (id = " << blockID << ")\n");
1612 
1613     // Emit OpLabel for this block.
1614     encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1615   }
1616 
1617   // Emit OpPhi instructions for block arguments, if any.
1618   if (failed(emitPhiForBlockArguments(block)))
1619     return failure();
1620 
1621   // Process each op in this block except the terminator.
1622   for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) {
1623     if (failed(processOperation(&op)))
1624       return failure();
1625   }
1626 
1627   // Process the terminator.
1628   if (actionBeforeTerminator)
1629     actionBeforeTerminator();
1630   if (failed(processOperation(&block->back())))
1631     return failure();
1632 
1633   return success();
1634 }
1635 
emitPhiForBlockArguments(Block * block)1636 LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
1637   // Nothing to do if this block has no arguments or it's the entry block, which
1638   // always has the same arguments as the function signature.
1639   if (block->args_empty() || block->isEntryBlock())
1640     return success();
1641 
1642   // If the block has arguments, we need to create SPIR-V OpPhi instructions.
1643   // A SPIR-V OpPhi instruction is of the syntax:
1644   //   OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
1645   // So we need to collect all predecessor blocks and the arguments they send
1646   // to this block.
1647   SmallVector<std::pair<Block *, Operation::operand_iterator>, 4> predecessors;
1648   for (Block *predecessor : block->getPredecessors()) {
1649     auto *terminator = predecessor->getTerminator();
1650     // The predecessor here is the immediate one according to MLIR's IR
1651     // structure. It does not directly map to the incoming parent block for the
1652     // OpPhi instructions at SPIR-V binary level. This is because structured
1653     // control flow ops are serialized to multiple SPIR-V blocks. If there is a
1654     // spv.selection/spv.loop op in the MLIR predecessor block, the branch op
1655     // jumping to the OpPhi's block then resides in the previous structured
1656     // control flow op's merge block.
1657     predecessor = getPhiIncomingBlock(predecessor);
1658     if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
1659       predecessors.emplace_back(predecessor, branchOp.operand_begin());
1660     } else {
1661       return terminator->emitError("unimplemented terminator for Phi creation");
1662     }
1663   }
1664 
1665   // Then create OpPhi instruction for each of the block argument.
1666   for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) {
1667     BlockArgument arg = block->getArgument(argIndex);
1668 
1669     // Get the type <id> and result <id> for this OpPhi instruction.
1670     uint32_t phiTypeID = 0;
1671     if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID)))
1672       return failure();
1673     uint32_t phiID = getNextID();
1674 
1675     LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
1676                             << arg << " (id = " << phiID << ")\n");
1677 
1678     // Prepare the (value <id>, parent block <id>) pairs.
1679     SmallVector<uint32_t, 8> phiArgs;
1680     phiArgs.push_back(phiTypeID);
1681     phiArgs.push_back(phiID);
1682 
1683     for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
1684       Value value = *(predecessors[predIndex].second + argIndex);
1685       uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
1686       LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
1687                               << ") value " << value << ' ');
1688       // Each pair is a value <id> ...
1689       uint32_t valueId = getValueID(value);
1690       if (valueId == 0) {
1691         // The op generating this value hasn't been visited yet so we don't have
1692         // an <id> assigned yet. Record this to fix up later.
1693         LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
1694         deferredPhiValues[value].push_back(functionBody.size() + 1 +
1695                                            phiArgs.size());
1696       } else {
1697         LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
1698       }
1699       phiArgs.push_back(valueId);
1700       // ... and a parent block <id>.
1701       phiArgs.push_back(predBlockId);
1702     }
1703 
1704     encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs);
1705     valueIDMap[arg] = phiID;
1706   }
1707 
1708   return success();
1709 }
1710 
processSelectionOp(spirv::SelectionOp selectionOp)1711 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
1712   // Assign <id>s to all blocks so that branches inside the SelectionOp can
1713   // resolve properly.
1714   auto &body = selectionOp.body();
1715   for (Block &block : body)
1716     getOrCreateBlockID(&block);
1717 
1718   auto *headerBlock = selectionOp.getHeaderBlock();
1719   auto *mergeBlock = selectionOp.getMergeBlock();
1720   auto mergeID = getBlockID(mergeBlock);
1721   auto loc = selectionOp.getLoc();
1722 
1723   // Emit the selection header block, which dominates all other blocks, first.
1724   // We need to emit an OpSelectionMerge instruction before the selection header
1725   // block's terminator.
1726   auto emitSelectionMerge = [&]() {
1727     emitDebugLine(functionBody, loc);
1728     lastProcessedWasMergeInst = true;
1729     encodeInstructionInto(
1730         functionBody, spirv::Opcode::OpSelectionMerge,
1731         {mergeID, static_cast<uint32_t>(selectionOp.selection_control())});
1732   };
1733   // For structured selection, we cannot have blocks in the selection construct
1734   // branching to the selection header block. Entering the selection (and
1735   // reaching the selection header) must be from the block containing the
1736   // spv.selection op. If there are ops ahead of the spv.selection op in the
1737   // block, we can "merge" them into the selection header. So here we don't need
1738   // to emit a separate block; just continue with the existing block.
1739   if (failed(processBlock(headerBlock, /*omitLabel=*/true, emitSelectionMerge)))
1740     return failure();
1741 
1742   // Process all blocks with a depth-first visitor starting from the header
1743   // block. The selection header block and merge block are skipped by this
1744   // visitor.
1745   if (failed(visitInPrettyBlockOrder(
1746           headerBlock, [&](Block *block) { return processBlock(block); },
1747           /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock})))
1748     return failure();
1749 
1750   // There is nothing to do for the merge block in the selection, which just
1751   // contains a spv.mlir.merge op, itself. But we need to have an OpLabel
1752   // instruction to start a new SPIR-V block for ops following this SelectionOp.
1753   // The block should use the <id> for the merge block.
1754   return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
1755 }
1756 
processLoopOp(spirv::LoopOp loopOp)1757 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
1758   // Assign <id>s to all blocks so that branches inside the LoopOp can resolve
1759   // properly. We don't need to assign for the entry block, which is just for
1760   // satisfying MLIR region's structural requirement.
1761   auto &body = loopOp.body();
1762   for (Block &block :
1763        llvm::make_range(std::next(body.begin(), 1), body.end())) {
1764     getOrCreateBlockID(&block);
1765   }
1766   auto *headerBlock = loopOp.getHeaderBlock();
1767   auto *continueBlock = loopOp.getContinueBlock();
1768   auto *mergeBlock = loopOp.getMergeBlock();
1769   auto headerID = getBlockID(headerBlock);
1770   auto continueID = getBlockID(continueBlock);
1771   auto mergeID = getBlockID(mergeBlock);
1772   auto loc = loopOp.getLoc();
1773 
1774   // This LoopOp is in some MLIR block with preceding and following ops. In the
1775   // binary format, it should reside in separate SPIR-V blocks from its
1776   // preceding and following ops. So we need to emit unconditional branches to
1777   // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
1778   // afterwards.
1779   encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
1780 
1781   // LoopOp's entry block is just there for satisfying MLIR's structural
1782   // requirements so we omit it and start serialization from the loop header
1783   // block.
1784 
1785   // Emit the loop header block, which dominates all other blocks, first. We
1786   // need to emit an OpLoopMerge instruction before the loop header block's
1787   // terminator.
1788   auto emitLoopMerge = [&]() {
1789     emitDebugLine(functionBody, loc);
1790     lastProcessedWasMergeInst = true;
1791     encodeInstructionInto(
1792         functionBody, spirv::Opcode::OpLoopMerge,
1793         {mergeID, continueID, static_cast<uint32_t>(loopOp.loop_control())});
1794   };
1795   if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
1796     return failure();
1797 
1798   // Process all blocks with a depth-first visitor starting from the header
1799   // block. The loop header block, loop continue block, and loop merge block are
1800   // skipped by this visitor and handled later in this function.
1801   if (failed(visitInPrettyBlockOrder(
1802           headerBlock, [&](Block *block) { return processBlock(block); },
1803           /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock})))
1804     return failure();
1805 
1806   // We have handled all other blocks. Now get to the loop continue block.
1807   if (failed(processBlock(continueBlock)))
1808     return failure();
1809 
1810   // There is nothing to do for the merge block in the loop, which just contains
1811   // a spv.mlir.merge op, itself. But we need to have an OpLabel instruction to
1812   // start a new SPIR-V block for ops following this LoopOp. The block should
1813   // use the <id> for the merge block.
1814   return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
1815 }
1816 
processBranchConditionalOp(spirv::BranchConditionalOp condBranchOp)1817 LogicalResult Serializer::processBranchConditionalOp(
1818     spirv::BranchConditionalOp condBranchOp) {
1819   auto conditionID = getValueID(condBranchOp.condition());
1820   auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
1821   auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
1822   SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID};
1823 
1824   if (auto weights = condBranchOp.branch_weights()) {
1825     for (auto val : weights->getValue())
1826       arguments.push_back(val.cast<IntegerAttr>().getInt());
1827   }
1828 
1829   emitDebugLine(functionBody, condBranchOp.getLoc());
1830   return encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional,
1831                                arguments);
1832 }
1833 
processBranchOp(spirv::BranchOp branchOp)1834 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
1835   emitDebugLine(functionBody, branchOp.getLoc());
1836   return encodeInstructionInto(functionBody, spirv::Opcode::OpBranch,
1837                                {getOrCreateBlockID(branchOp.getTarget())});
1838 }
1839 
1840 //===----------------------------------------------------------------------===//
1841 // Operation
1842 //===----------------------------------------------------------------------===//
1843 
encodeExtensionInstruction(Operation * op,StringRef extensionSetName,uint32_t extensionOpcode,ArrayRef<uint32_t> operands)1844 LogicalResult Serializer::encodeExtensionInstruction(
1845     Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1846     ArrayRef<uint32_t> operands) {
1847   // Check if the extension has been imported.
1848   auto &setID = extendedInstSetIDMap[extensionSetName];
1849   if (!setID) {
1850     setID = getNextID();
1851     SmallVector<uint32_t, 16> importOperands;
1852     importOperands.push_back(setID);
1853     if (failed(
1854             spirv::encodeStringLiteralInto(importOperands, extensionSetName)) ||
1855         failed(encodeInstructionInto(
1856             extendedSets, spirv::Opcode::OpExtInstImport, importOperands))) {
1857       return failure();
1858     }
1859   }
1860 
1861   // The first two operands are the result type <id> and result <id>. The set
1862   // <id> and the opcode need to be insert after this.
1863   if (operands.size() < 2) {
1864     return op->emitError("extended instructions must have a result encoding");
1865   }
1866   SmallVector<uint32_t, 8> extInstOperands;
1867   extInstOperands.reserve(operands.size() + 2);
1868   extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1869   extInstOperands.push_back(setID);
1870   extInstOperands.push_back(extensionOpcode);
1871   extInstOperands.append(std::next(operands.begin(), 2), operands.end());
1872   return encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst,
1873                                extInstOperands);
1874 }
1875 
processAddressOfOp(spirv::AddressOfOp addressOfOp)1876 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
1877   auto varName = addressOfOp.variable();
1878   auto variableID = getVariableID(varName);
1879   if (!variableID) {
1880     return addressOfOp.emitError("unknown result <id> for variable ")
1881            << varName;
1882   }
1883   valueIDMap[addressOfOp.pointer()] = variableID;
1884   return success();
1885 }
1886 
1887 LogicalResult
processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp)1888 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
1889   auto constName = referenceOfOp.spec_const();
1890   auto constID = getSpecConstID(constName);
1891   if (!constID) {
1892     return referenceOfOp.emitError(
1893                "unknown result <id> for specialization constant ")
1894            << constName;
1895   }
1896   valueIDMap[referenceOfOp.reference()] = constID;
1897   return success();
1898 }
1899 
processOperation(Operation * opInst)1900 LogicalResult Serializer::processOperation(Operation *opInst) {
1901   LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n");
1902 
1903   // First dispatch the ops that do not directly mirror an instruction from
1904   // the SPIR-V spec.
1905   return TypeSwitch<Operation *, LogicalResult>(opInst)
1906       .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); })
1907       .Case([&](spirv::BranchOp op) { return processBranchOp(op); })
1908       .Case([&](spirv::BranchConditionalOp op) {
1909         return processBranchConditionalOp(op);
1910       })
1911       .Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
1912       .Case([&](spirv::FuncOp op) { return processFuncOp(op); })
1913       .Case([&](spirv::GlobalVariableOp op) {
1914         return processGlobalVariableOp(op);
1915       })
1916       .Case([&](spirv::LoopOp op) { return processLoopOp(op); })
1917       .Case([&](spirv::ModuleEndOp) { return success(); })
1918       .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
1919       .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
1920       .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
1921       .Case([&](spirv::SpecConstantCompositeOp op) {
1922         return processSpecConstantCompositeOp(op);
1923       })
1924       .Case([&](spirv::UndefOp op) { return processUndefOp(op); })
1925       .Case([&](spirv::VariableOp op) { return processVariableOp(op); })
1926 
1927       // Then handle all the ops that directly mirror SPIR-V instructions with
1928       // auto-generated methods.
1929       .Default(
1930           [&](Operation *op) { return dispatchToAutogenSerialization(op); });
1931 }
1932 
1933 namespace {
1934 template <>
1935 LogicalResult
processOp(spirv::EntryPointOp op)1936 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
1937   SmallVector<uint32_t, 4> operands;
1938   // Add the ExecutionModel.
1939   operands.push_back(static_cast<uint32_t>(op.execution_model()));
1940   // Add the function <id>.
1941   auto funcID = getFunctionID(op.fn());
1942   if (!funcID) {
1943     return op.emitError("missing <id> for function ")
1944            << op.fn()
1945            << "; function needs to be defined before spv.EntryPoint is "
1946               "serialized";
1947   }
1948   operands.push_back(funcID);
1949   // Add the name of the function.
1950   spirv::encodeStringLiteralInto(operands, op.fn());
1951 
1952   // Add the interface values.
1953   if (auto interface = op.interface()) {
1954     for (auto var : interface.getValue()) {
1955       auto id = getVariableID(var.cast<FlatSymbolRefAttr>().getValue());
1956       if (!id) {
1957         return op.emitError("referencing undefined global variable."
1958                             "spv.EntryPoint is at the end of spv.module. All "
1959                             "referenced variables should already be defined");
1960       }
1961       operands.push_back(id);
1962     }
1963   }
1964   return encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint,
1965                                operands);
1966 }
1967 
1968 template <>
1969 LogicalResult
processOp(spirv::ControlBarrierOp op)1970 Serializer::processOp<spirv::ControlBarrierOp>(spirv::ControlBarrierOp op) {
1971   StringRef argNames[] = {"execution_scope", "memory_scope",
1972                           "memory_semantics"};
1973   SmallVector<uint32_t, 3> operands;
1974 
1975   for (auto argName : argNames) {
1976     auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName);
1977     auto operand = prepareConstantInt(op.getLoc(), argIntAttr);
1978     if (!operand) {
1979       return failure();
1980     }
1981     operands.push_back(operand);
1982   }
1983 
1984   return encodeInstructionInto(functionBody, spirv::Opcode::OpControlBarrier,
1985                                operands);
1986 }
1987 
1988 template <>
1989 LogicalResult
processOp(spirv::ExecutionModeOp op)1990 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
1991   SmallVector<uint32_t, 4> operands;
1992   // Add the function <id>.
1993   auto funcID = getFunctionID(op.fn());
1994   if (!funcID) {
1995     return op.emitError("missing <id> for function ")
1996            << op.fn()
1997            << "; function needs to be serialized before ExecutionModeOp is "
1998               "serialized";
1999   }
2000   operands.push_back(funcID);
2001   // Add the ExecutionMode.
2002   operands.push_back(static_cast<uint32_t>(op.execution_mode()));
2003 
2004   // Serialize values if any.
2005   auto values = op.values();
2006   if (values) {
2007     for (auto &intVal : values.getValue()) {
2008       operands.push_back(static_cast<uint32_t>(
2009           intVal.cast<IntegerAttr>().getValue().getZExtValue()));
2010     }
2011   }
2012   return encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
2013                                operands);
2014 }
2015 
2016 template <>
2017 LogicalResult
processOp(spirv::MemoryBarrierOp op)2018 Serializer::processOp<spirv::MemoryBarrierOp>(spirv::MemoryBarrierOp op) {
2019   StringRef argNames[] = {"memory_scope", "memory_semantics"};
2020   SmallVector<uint32_t, 2> operands;
2021 
2022   for (auto argName : argNames) {
2023     auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName);
2024     auto operand = prepareConstantInt(op.getLoc(), argIntAttr);
2025     if (!operand) {
2026       return failure();
2027     }
2028     operands.push_back(operand);
2029   }
2030 
2031   return encodeInstructionInto(functionBody, spirv::Opcode::OpMemoryBarrier,
2032                                operands);
2033 }
2034 
2035 template <>
2036 LogicalResult
processOp(spirv::FunctionCallOp op)2037 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
2038   auto funcName = op.callee();
2039   uint32_t resTypeID = 0;
2040 
2041   Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
2042   if (failed(processType(op.getLoc(), resultTy, resTypeID)))
2043     return failure();
2044 
2045   auto funcID = getOrCreateFunctionID(funcName);
2046   auto funcCallID = getNextID();
2047   SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
2048 
2049   for (auto value : op.arguments()) {
2050     auto valueID = getValueID(value);
2051     assert(valueID && "cannot find a value for spv.FunctionCall");
2052     operands.push_back(valueID);
2053   }
2054 
2055   if (!resultTy.isa<NoneType>())
2056     valueIDMap[op.getResult(0)] = funcCallID;
2057 
2058   return encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall,
2059                                operands);
2060 }
2061 
2062 template <>
2063 LogicalResult
processOp(spirv::CopyMemoryOp op)2064 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
2065   SmallVector<uint32_t, 4> operands;
2066   SmallVector<StringRef, 2> elidedAttrs;
2067 
2068   for (Value operand : op->getOperands()) {
2069     auto id = getValueID(operand);
2070     assert(id && "use before def!");
2071     operands.push_back(id);
2072   }
2073 
2074   if (auto attr = op.getAttr("memory_access")) {
2075     operands.push_back(static_cast<uint32_t>(
2076         attr.cast<IntegerAttr>().getValue().getZExtValue()));
2077   }
2078 
2079   elidedAttrs.push_back("memory_access");
2080 
2081   if (auto attr = op.getAttr("alignment")) {
2082     operands.push_back(static_cast<uint32_t>(
2083         attr.cast<IntegerAttr>().getValue().getZExtValue()));
2084   }
2085 
2086   elidedAttrs.push_back("alignment");
2087 
2088   if (auto attr = op.getAttr("source_memory_access")) {
2089     operands.push_back(static_cast<uint32_t>(
2090         attr.cast<IntegerAttr>().getValue().getZExtValue()));
2091   }
2092 
2093   elidedAttrs.push_back("source_memory_access");
2094 
2095   if (auto attr = op.getAttr("source_alignment")) {
2096     operands.push_back(static_cast<uint32_t>(
2097         attr.cast<IntegerAttr>().getValue().getZExtValue()));
2098   }
2099 
2100   elidedAttrs.push_back("source_alignment");
2101   emitDebugLine(functionBody, op.getLoc());
2102   encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);
2103 
2104   return success();
2105 }
2106 
2107 // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
2108 // various Serializer::processOp<...>() specializations.
2109 #define GET_SERIALIZATION_FNS
2110 #include "mlir/Dialect/SPIRV/SPIRVSerialization.inc"
2111 } // namespace
2112 
emitDecoration(uint32_t target,spirv::Decoration decoration,ArrayRef<uint32_t> params)2113 LogicalResult Serializer::emitDecoration(uint32_t target,
2114                                          spirv::Decoration decoration,
2115                                          ArrayRef<uint32_t> params) {
2116   uint32_t wordCount = 3 + params.size();
2117   decorations.push_back(
2118       spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate));
2119   decorations.push_back(target);
2120   decorations.push_back(static_cast<uint32_t>(decoration));
2121   decorations.append(params.begin(), params.end());
2122   return success();
2123 }
2124 
emitDebugLine(SmallVectorImpl<uint32_t> & binary,Location loc)2125 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
2126                                         Location loc) {
2127   if (!emitDebugInfo)
2128     return success();
2129 
2130   if (lastProcessedWasMergeInst) {
2131     lastProcessedWasMergeInst = false;
2132     return success();
2133   }
2134 
2135   auto fileLoc = loc.dyn_cast<FileLineColLoc>();
2136   if (fileLoc)
2137     encodeInstructionInto(binary, spirv::Opcode::OpLine,
2138                           {fileID, fileLoc.getLine(), fileLoc.getColumn()});
2139   return success();
2140 }
2141 
serialize(spirv::ModuleOp module,SmallVectorImpl<uint32_t> & binary,bool emitDebugInfo)2142 LogicalResult spirv::serialize(spirv::ModuleOp module,
2143                                SmallVectorImpl<uint32_t> &binary,
2144                                bool emitDebugInfo) {
2145   if (!module.vce_triple().hasValue())
2146     return module.emitError(
2147         "module must have 'vce_triple' attribute to be serializeable");
2148 
2149   Serializer serializer(module, emitDebugInfo);
2150 
2151   if (failed(serializer.serialize()))
2152     return failure();
2153 
2154   LLVM_DEBUG(serializer.printValueIDMap(llvm::dbgs()));
2155 
2156   serializer.collect(binary);
2157   return success();
2158 }
2159