1 //===- Translation.cpp - Translation registry -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Definitions of the translation registry.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Translation.h"
14 #include "mlir/IR/AsmState.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/Verifier.h"
17 #include "mlir/Parser.h"
18 #include "mlir/Support/FileUtilities.h"
19 #include "mlir/Support/ToolUtilities.h"
20 #include "llvm/Support/InitLLVM.h"
21 #include "llvm/Support/SourceMgr.h"
22 #include "llvm/Support/ToolOutputFile.h"
23
24 using namespace mlir;
25
26 //===----------------------------------------------------------------------===//
27 // Translation Registry
28 //===----------------------------------------------------------------------===//
29
30 /// Get the mutable static map between registered file-to-file MLIR translations
31 /// and the TranslateFunctions that perform those translations.
getTranslationRegistry()32 static llvm::StringMap<TranslateFunction> &getTranslationRegistry() {
33 static llvm::StringMap<TranslateFunction> translationRegistry;
34 return translationRegistry;
35 }
36
37 /// Register the given translation.
registerTranslation(StringRef name,const TranslateFunction & function)38 static void registerTranslation(StringRef name,
39 const TranslateFunction &function) {
40 auto &translationRegistry = getTranslationRegistry();
41 if (translationRegistry.find(name) != translationRegistry.end())
42 llvm::report_fatal_error(
43 "Attempting to overwrite an existing <file-to-file> function");
44 assert(function &&
45 "Attempting to register an empty translate <file-to-file> function");
46 translationRegistry[name] = function;
47 }
48
TranslateRegistration(StringRef name,const TranslateFunction & function)49 TranslateRegistration::TranslateRegistration(
50 StringRef name, const TranslateFunction &function) {
51 registerTranslation(name, function);
52 }
53
54 //===----------------------------------------------------------------------===//
55 // Translation to MLIR
56 //===----------------------------------------------------------------------===//
57
58 // Puts `function` into the to-MLIR translation registry unless there is already
59 // a function registered for the same name.
registerTranslateToMLIRFunction(StringRef name,const TranslateSourceMgrToMLIRFunction & function)60 static void registerTranslateToMLIRFunction(
61 StringRef name, const TranslateSourceMgrToMLIRFunction &function) {
62 auto wrappedFn = [function](llvm::SourceMgr &sourceMgr, raw_ostream &output,
63 MLIRContext *context) {
64 OwningModuleRef module = function(sourceMgr, context);
65 if (!module || failed(verify(*module)))
66 return failure();
67 module->print(output);
68 return success();
69 };
70 registerTranslation(name, wrappedFn);
71 }
72
TranslateToMLIRRegistration(StringRef name,const TranslateSourceMgrToMLIRFunction & function)73 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
74 StringRef name, const TranslateSourceMgrToMLIRFunction &function) {
75 registerTranslateToMLIRFunction(name, function);
76 }
77
78 /// Wraps `function` with a lambda that extracts a StringRef from a source
79 /// manager and registers the wrapper lambda as a to-MLIR conversion.
TranslateToMLIRRegistration(StringRef name,const TranslateStringRefToMLIRFunction & function)80 TranslateToMLIRRegistration::TranslateToMLIRRegistration(
81 StringRef name, const TranslateStringRefToMLIRFunction &function) {
82 registerTranslateToMLIRFunction(
83 name, [function](llvm::SourceMgr &sourceMgr, MLIRContext *ctx) {
84 const llvm::MemoryBuffer *buffer =
85 sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
86 return function(buffer->getBuffer(), ctx);
87 });
88 }
89
90 //===----------------------------------------------------------------------===//
91 // Translation from MLIR
92 //===----------------------------------------------------------------------===//
93
TranslateFromMLIRRegistration(StringRef name,const TranslateFromMLIRFunction & function,std::function<void (DialectRegistry &)> dialectRegistration)94 TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
95 StringRef name, const TranslateFromMLIRFunction &function,
96 std::function<void(DialectRegistry &)> dialectRegistration) {
97 registerTranslation(name, [function, dialectRegistration](
98 llvm::SourceMgr &sourceMgr, raw_ostream &output,
99 MLIRContext *context) {
100 dialectRegistration(context->getDialectRegistry());
101 auto module = OwningModuleRef(parseSourceFile(sourceMgr, context));
102 if (!module)
103 return failure();
104 return function(module.get(), output);
105 });
106 }
107
108 //===----------------------------------------------------------------------===//
109 // Translation Parser
110 //===----------------------------------------------------------------------===//
111
TranslationParser(llvm::cl::Option & opt)112 TranslationParser::TranslationParser(llvm::cl::Option &opt)
113 : llvm::cl::parser<const TranslateFunction *>(opt) {
114 for (const auto &kv : getTranslationRegistry())
115 addLiteralOption(kv.first(), &kv.second, kv.first());
116 }
117
printOptionInfo(const llvm::cl::Option & o,size_t globalWidth) const118 void TranslationParser::printOptionInfo(const llvm::cl::Option &o,
119 size_t globalWidth) const {
120 TranslationParser *tp = const_cast<TranslationParser *>(this);
121 llvm::array_pod_sort(tp->Values.begin(), tp->Values.end(),
122 [](const TranslationParser::OptionInfo *lhs,
123 const TranslationParser::OptionInfo *rhs) {
124 return lhs->Name.compare(rhs->Name);
125 });
126 llvm::cl::parser<const TranslateFunction *>::printOptionInfo(o, globalWidth);
127 }
128
mlirTranslateMain(int argc,char ** argv,llvm::StringRef toolName)129 LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
130 llvm::StringRef toolName) {
131
132 static llvm::cl::opt<std::string> inputFilename(
133 llvm::cl::Positional, llvm::cl::desc("<input file>"),
134 llvm::cl::init("-"));
135
136 static llvm::cl::opt<std::string> outputFilename(
137 "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
138 llvm::cl::init("-"));
139
140 static llvm::cl::opt<bool> splitInputFile(
141 "split-input-file",
142 llvm::cl::desc("Split the input file into pieces and "
143 "process each chunk independently"),
144 llvm::cl::init(false));
145
146 static llvm::cl::opt<bool> verifyDiagnostics(
147 "verify-diagnostics",
148 llvm::cl::desc("Check that emitted diagnostics match "
149 "expected-* lines on the corresponding line"),
150 llvm::cl::init(false));
151
152 llvm::InitLLVM y(argc, argv);
153
154 // Add flags for all the registered translations.
155 llvm::cl::opt<const TranslateFunction *, false, TranslationParser>
156 translationRequested("", llvm::cl::desc("Translation to perform"),
157 llvm::cl::Required);
158 registerAsmPrinterCLOptions();
159 registerMLIRContextCLOptions();
160 llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR translation driver\n");
161
162 std::string errorMessage;
163 auto input = openInputFile(inputFilename, &errorMessage);
164 if (!input) {
165 llvm::errs() << errorMessage << "\n";
166 return failure();
167 }
168
169 auto output = openOutputFile(outputFilename, &errorMessage);
170 if (!output) {
171 llvm::errs() << errorMessage << "\n";
172 return failure();
173 }
174
175 // Processes the memory buffer with a new MLIRContext.
176 auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
177 raw_ostream &os) {
178 MLIRContext context;
179 context.printOpOnDiagnostic(!verifyDiagnostics);
180 llvm::SourceMgr sourceMgr;
181 sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
182
183 if (!verifyDiagnostics) {
184 SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
185 return (*translationRequested)(sourceMgr, os, &context);
186 }
187
188 // In the diagnostic verification flow, we ignore whether the translation
189 // failed (in most cases, it is expected to fail). Instead, we check if the
190 // diagnostics were produced as expected.
191 SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
192 (*translationRequested)(sourceMgr, os, &context);
193 return sourceMgrHandler.verify();
194 };
195
196 if (splitInputFile) {
197 if (failed(splitAndProcessBuffer(std::move(input), processBuffer,
198 output->os())))
199 return failure();
200 } else if (failed(processBuffer(std::move(input), output->os()))) {
201 return failure();
202 }
203
204 output->keep();
205 return success();
206 }
207