1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <sstream>
17
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/Sequence.h"
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/ADT/StringMap.h"
22 #include "llvm/Support/CommandLine.h"
23 #include "llvm/Support/FormatVariadic.h"
24 #include "llvm/Support/InitLLVM.h"
25 #include "llvm/Support/Signals.h"
26 #include "llvm/TableGen/Error.h"
27 #include "llvm/TableGen/Main.h"
28 #include "llvm/TableGen/Record.h"
29 #include "llvm/TableGen/TableGenBackend.h"
30 #include "mlir/TableGen/Operator.h" // from @llvm-project
31
32 using llvm::interleaveComma;
33 using llvm::raw_ostream;
34 using llvm::RecordKeeper;
35 using llvm::StringRef;
36 using mlir::tblgen::Attribute;
37 using mlir::tblgen::NamedAttribute;
38 using mlir::tblgen::NamedTypeConstraint;
39 using mlir::tblgen::Operator;
40
GetDefaultAttrExport(const mlir::tblgen::NamedAttribute & named_attr)41 static std::string GetDefaultAttrExport(
42 const mlir::tblgen::NamedAttribute& named_attr) {
43 Attribute attr = named_attr.attr;
44 StringRef storage_type = attr.getStorageType();
45 // For some attribute types we have a general conversion, so use that.
46 if (!attr.isEnumAttr() && (storage_type.endswith("BoolAttr") ||
47 storage_type.endswith("FloatAttr") ||
48 storage_type.endswith("IntegerAttr") ||
49 storage_type.endswith("StringAttr"))) {
50 // The return type may contains qualified namespaces. Split to remove them.
51 std::pair<StringRef, StringRef> splits = attr.getReturnType().rsplit("::");
52 StringRef symbol = splits.second;
53 if (symbol.empty()) symbol = splits.first;
54 return "Convert" + symbol.str();
55 }
56 return "Convert_" + named_attr.name.str();
57 }
58
GetClientBuilder(const Operator & op)59 static StringRef GetClientBuilder(const Operator& op) {
60 static const auto* kOpToXLABuilderMap =
61 new llvm::StringMap<StringRef>{{"ReverseOp", "Rev"},
62 {"ConcatenateOp", "ConcatInDim"},
63 {"ConvOp", "ConvGeneralDilated"}};
64
65 StringRef op_name = op.getCppClassName();
66
67 // Default case where the client builder method names closely follow the op
68 // names in the dialect. For e.g., AddOp -> xla::Add method.
69 if (!kOpToXLABuilderMap->count(op_name)) return op_name.drop_back(2);
70
71 // Otherwise, if the op to client builder method mapping is provided.
72 return kOpToXLABuilderMap->lookup(op_name);
73 }
74
BuildOperator(const Operator & op,raw_ostream & os)75 static void BuildOperator(const Operator& op, raw_ostream& os) {
76 os << "mlir::LogicalResult ExportXlaOp(mlir::mhlo::" << op.getCppClassName()
77 << " op, OpLoweringContext ctx) {\n"
78 << " auto& value_map = *ctx.values;\n"
79 << " auto result = op.getResult();\n";
80
81 // Build a conversion for each of the arguments.
82 int operand_number = 0;
83 for (int index : llvm::seq<int>(0, op.getNumArgs())) {
84 auto arg = op.getArg(index);
85
86 // Emit an argument for an operand.
87 if (auto* operand_cst = arg.dyn_cast<NamedTypeConstraint*>()) {
88 std::string xla_arg = "xla_arg_" + std::to_string(index);
89 // Handle a non-variadic operand.
90 if (!operand_cst->isVariableLength()) {
91 os << " xla::XlaOp " << xla_arg << ";\n";
92 os << " if (failed(GetXlaOp(*op.getODSOperands(" << operand_number++
93 << ").begin(), value_map, &" << xla_arg << ", op)))\n";
94 os << " return mlir::failure();\n";
95 continue;
96 }
97
98 // Otherwise, this is a varidiac operand list.
99 os << " std::vector<xla::XlaOp> " << xla_arg << ";\n"
100 << " for (auto operand : op.getODSOperands(" << operand_number++
101 << ")) {\n";
102 os << " xla::XlaOp result;\n";
103 os << " if (failed(GetXlaOp(operand, value_map, &result, op)))\n";
104 os << " return mlir::failure();\n";
105 os << " " << xla_arg << ".push_back(result);\n";
106 os << " }\n";
107 continue;
108 }
109
110 // Otherwise, this is an attribute.
111 auto named_attr = arg.get<NamedAttribute*>();
112 os << " auto xla_arg_" << index << " = "
113 << GetDefaultAttrExport(*named_attr) << "(op." << op.getArgName(index)
114 << "());\n";
115 }
116
117 // Emit call to client API
118 os << " auto xla_result = xla::" << GetClientBuilder(op) << "(";
119
120 // If all operands are variadic, then pass the builder explicitly to xla
121 // client API call
122 if (op.getNumOperands() == op.getNumVariableLengthOperands()) {
123 os << "ctx.builder";
124 if (op.getNumArgs() != 0) os << ", ";
125 }
126
127 // Emit each of the arguments.
128 interleaveComma(llvm::seq<int>(0, op.getNumArgs()), os,
129 [&](int i) { os << "Unwrap(xla_arg_" << i << ')'; });
130 os << ");\n";
131
132 os << " value_map[result] = xla_result;\n";
133 os << " return mlir::success();\n";
134 os << "}\n";
135 }
136
137 // The function below has a non-constant reference as that is required by LLVM's
138 // TableGenMain.
139 // NOLINTNEXTLINE
OperatorWritersMain(raw_ostream & os,RecordKeeper & records)140 static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) {
141 emitSourceFileHeader("MLIR XLA Builders", os);
142
143 // Emit all the helper functions.
144 for (const auto* def : records.getAllDerivedDefinitions("HLO_Op")) {
145 Operator op(def);
146
147 // Skip operations that have a custom exporter.
148 if (!def->getValueAsBit("hasCustomHLOConverter")) BuildOperator(op, os);
149 }
150
151 // Emit a function to generate an XLA operation for the operations with
152 // auto-generated builders.
153 os << "mlir::LogicalResult ExportXlaOperator(\n"
154 "mlir::Operation* op, OpLoweringContext lowering_context) {\n\n";
155
156 // Create a scoped object to assign sharding to generated XLA ops. Any HLO
157 // can have an attribute of "sharding".
158 os << " xla::XlaScopedShardingAssignment sharding(lowering_context.builder, "
159 "CreateOpShardingFromAttribute(op));\n\n";
160
161 // Create a scoped object to assign frontend attributes to generated XLA ops.
162 // Any HLO can have an attribute of "frontend_attributes", which are used to
163 // pass hints / configuration options.
164 os << " xla::XlaScopedFrontendAttributesAssignment "
165 "frontend_attributes(lowering_context.builder, "
166 "CreateOpFrontendAttributesFromAttribute(op));\n\n";
167
168 // Create a scoped object to assign op metadata to generated XLA ops.
169 os << " xla::XlaScopedOpMetadataAssignment "
170 "op_metadata(lowering_context.builder, "
171 "CreateOpMetadataFromLocation(op));\n\n";
172
173 // Retrieve all the definitions derived from HLO_Op and sort by record name.
174 for (const auto* def : records.getAllDerivedDefinitions("HLO_Op")) {
175 // Skip operations that have a custom exporter.
176 Operator op(def);
177
178 // Cast to the current operation and build the exporter.
179 os << " if (auto xla_op = llvm::dyn_cast<mlir::mhlo::"
180 << op.getCppClassName() << ">(op)) {\n";
181 os << " return ";
182 // The autogenerated converters aren't in the same namespace.
183 // TODO(jpienaar): Reconsider this.
184 if (def->getValueAsBit("hasCustomHLOConverter")) os << "mlir::mhlo::";
185 os << "ExportXlaOp(xla_op, lowering_context);\n";
186 os << " }\n";
187 }
188
189 os << " return mlir::failure();\n"
190 "}\n";
191 return false;
192 }
193
main(int argc,char ** argv)194 int main(int argc, char** argv) {
195 llvm::InitLLVM y(argc, argv);
196 llvm::cl::ParseCommandLineOptions(argc, argv);
197 return TableGenMain(argv[0], &OperatorWritersMain);
198 }
199