1 //===- TranslateRegistration.cpp - hooks to mlir-translate ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
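// This file registers the SPIR-V translations with mlir-translate:
// deserialization of a SPIR-V binary module into an MLIR module containing a
// `spv.module` op, serialization of a `spv.module` op into a SPIR-V binary,
// and test-only round-trip (serialize + deserialize) translations.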
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/SPIRVModule.h"
16 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/Serialization.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/IR/Dialect.h"
21 #include "mlir/Parser.h"
22 #include "mlir/Support/FileUtilities.h"
23 #include "mlir/Translation.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/Support/MemoryBuffer.h"
26 #include "llvm/Support/SMLoc.h"
27 #include "llvm/Support/SourceMgr.h"
28 #include "llvm/Support/ToolOutputFile.h"
29
30 using namespace mlir;
31
32 //===----------------------------------------------------------------------===//
33 // Deserialization registration
34 //===----------------------------------------------------------------------===//
35
36 // Deserializes the SPIR-V binary module stored in the file named as
37 // `inputFilename` and returns a module containing the SPIR-V module.
deserializeModule(const llvm::MemoryBuffer * input,MLIRContext * context)38 static OwningModuleRef deserializeModule(const llvm::MemoryBuffer *input,
39 MLIRContext *context) {
40 context->loadDialect<spirv::SPIRVDialect>();
41
42 // Make sure the input stream can be treated as a stream of SPIR-V words
43 auto start = input->getBufferStart();
44 auto size = input->getBufferSize();
45 if (size % sizeof(uint32_t) != 0) {
46 emitError(UnknownLoc::get(context))
47 << "SPIR-V binary module must contain integral number of 32-bit words";
48 return {};
49 }
50
51 auto binary = llvm::makeArrayRef(reinterpret_cast<const uint32_t *>(start),
52 size / sizeof(uint32_t));
53
54 spirv::OwningSPIRVModuleRef spirvModule = spirv::deserialize(binary, context);
55 if (!spirvModule)
56 return {};
57
58 OwningModuleRef module(ModuleOp::create(FileLineColLoc::get(
59 input->getBufferIdentifier(), /*line=*/0, /*column=*/0, context)));
60 module->getBody()->push_front(spirvModule.release());
61
62 return module;
63 }
64
65 namespace mlir {
registerFromSPIRVTranslation()66 void registerFromSPIRVTranslation() {
67 TranslateToMLIRRegistration fromBinary(
68 "deserialize-spirv",
69 [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
70 assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer");
71 return deserializeModule(
72 sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), context);
73 });
74 }
75 } // namespace mlir
76
77 //===----------------------------------------------------------------------===//
78 // Serialization registration
79 //===----------------------------------------------------------------------===//
80
serializeModule(ModuleOp module,raw_ostream & output)81 static LogicalResult serializeModule(ModuleOp module, raw_ostream &output) {
82 if (!module)
83 return failure();
84
85 SmallVector<uint32_t, 0> binary;
86
87 SmallVector<spirv::ModuleOp, 1> spirvModules;
88 module.walk([&](spirv::ModuleOp op) { spirvModules.push_back(op); });
89
90 if (spirvModules.empty())
91 return module.emitError("found no 'spv.module' op");
92
93 if (spirvModules.size() != 1)
94 return module.emitError("found more than one 'spv.module' op");
95
96 if (failed(
97 spirv::serialize(spirvModules[0], binary, /*emitDebuginfo=*/false)))
98 return failure();
99
100 output.write(reinterpret_cast<char *>(binary.data()),
101 binary.size() * sizeof(uint32_t));
102
103 return mlir::success();
104 }
105
106 namespace mlir {
registerToSPIRVTranslation()107 void registerToSPIRVTranslation() {
108 TranslateFromMLIRRegistration toBinary(
109 "serialize-spirv",
110 [](ModuleOp module, raw_ostream &output) {
111 return serializeModule(module, output);
112 },
113 [](DialectRegistry ®istry) {
114 registry.insert<spirv::SPIRVDialect>();
115 });
116 }
117 } // namespace mlir
118
119 //===----------------------------------------------------------------------===//
120 // Round-trip registration
121 //===----------------------------------------------------------------------===//
122
roundTripModule(ModuleOp srcModule,bool emitDebugInfo,raw_ostream & output)123 static LogicalResult roundTripModule(ModuleOp srcModule, bool emitDebugInfo,
124 raw_ostream &output) {
125 SmallVector<uint32_t, 0> binary;
126 MLIRContext *context = srcModule.getContext();
127 auto spirvModules = srcModule.getOps<spirv::ModuleOp>();
128
129 if (spirvModules.begin() == spirvModules.end())
130 return srcModule.emitError("found no 'spv.module' op");
131
132 if (std::next(spirvModules.begin()) != spirvModules.end())
133 return srcModule.emitError("found more than one 'spv.module' op");
134
135 if (failed(spirv::serialize(*spirvModules.begin(), binary, emitDebugInfo)))
136 return failure();
137
138 MLIRContext deserializationContext;
139 context->getDialectRegistry().loadAll(&deserializationContext);
140 // Then deserialize to get back a SPIR-V module.
141 spirv::OwningSPIRVModuleRef spirvModule =
142 spirv::deserialize(binary, &deserializationContext);
143 if (!spirvModule)
144 return failure();
145
146 // Wrap around in a new MLIR module.
147 OwningModuleRef dstModule(ModuleOp::create(FileLineColLoc::get(
148 /*filename=*/"", /*line=*/0, /*column=*/0, &deserializationContext)));
149 dstModule->getBody()->push_front(spirvModule.release());
150 dstModule->print(output);
151
152 return mlir::success();
153 }
154
155 namespace mlir {
registerTestRoundtripSPIRV()156 void registerTestRoundtripSPIRV() {
157 TranslateFromMLIRRegistration roundtrip(
158 "test-spirv-roundtrip",
159 [](ModuleOp module, raw_ostream &output) {
160 return roundTripModule(module, /*emitDebugInfo=*/false, output);
161 },
162 [](DialectRegistry ®istry) {
163 registry.insert<spirv::SPIRVDialect>();
164 });
165 }
166
registerTestRoundtripDebugSPIRV()167 void registerTestRoundtripDebugSPIRV() {
168 TranslateFromMLIRRegistration roundtrip(
169 "test-spirv-roundtrip-debug",
170 [](ModuleOp module, raw_ostream &output) {
171 return roundTripModule(module, /*emitDebugInfo=*/true, output);
172 },
173 [](DialectRegistry ®istry) {
174 registry.insert<spirv::SPIRVDialect>();
175 });
176 }
177 } // namespace mlir
178