• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <sstream>
17 
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/Sequence.h"
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/ADT/StringMap.h"
22 #include "llvm/Support/CommandLine.h"
23 #include "llvm/Support/FormatVariadic.h"
24 #include "llvm/Support/InitLLVM.h"
25 #include "llvm/Support/Signals.h"
26 #include "llvm/TableGen/Error.h"
27 #include "llvm/TableGen/Main.h"
28 #include "llvm/TableGen/Record.h"
29 #include "llvm/TableGen/TableGenBackend.h"
30 #include "mlir/TableGen/Operator.h"  // from @llvm-project
31 
32 using llvm::interleaveComma;
33 using llvm::raw_ostream;
34 using llvm::RecordKeeper;
35 using llvm::StringRef;
36 using mlir::tblgen::Attribute;
37 using mlir::tblgen::NamedAttribute;
38 using mlir::tblgen::NamedTypeConstraint;
39 using mlir::tblgen::Operator;
40 
GetDefaultAttrExport(const mlir::tblgen::NamedAttribute & named_attr)41 static std::string GetDefaultAttrExport(
42     const mlir::tblgen::NamedAttribute& named_attr) {
43   Attribute attr = named_attr.attr;
44   StringRef storage_type = attr.getStorageType();
45   // For some attribute types we have a general conversion, so use that.
46   if (!attr.isEnumAttr() && (storage_type.endswith("BoolAttr") ||
47                              storage_type.endswith("FloatAttr") ||
48                              storage_type.endswith("IntegerAttr") ||
49                              storage_type.endswith("StringAttr"))) {
50     // The return type may contains qualified namespaces. Split to remove them.
51     std::pair<StringRef, StringRef> splits = attr.getReturnType().rsplit("::");
52     StringRef symbol = splits.second;
53     if (symbol.empty()) symbol = splits.first;
54     return "Convert" + symbol.str();
55   }
56   return "Convert_" + named_attr.name.str();
57 }
58 
GetClientBuilder(const Operator & op)59 static StringRef GetClientBuilder(const Operator& op) {
60   static const auto* kOpToXLABuilderMap =
61       new llvm::StringMap<StringRef>{{"ReverseOp", "Rev"},
62                                      {"ConcatenateOp", "ConcatInDim"},
63                                      {"ConvOp", "ConvGeneralDilated"}};
64 
65   StringRef op_name = op.getCppClassName();
66 
67   // Default case where the client builder method names closely follow the op
68   // names in the dialect. For e.g., AddOp -> xla::Add method.
69   if (!kOpToXLABuilderMap->count(op_name)) return op_name.drop_back(2);
70 
71   // Otherwise, if the op to client builder method mapping is provided.
72   return kOpToXLABuilderMap->lookup(op_name);
73 }
74 
BuildOperator(const Operator & op,raw_ostream & os)75 static void BuildOperator(const Operator& op, raw_ostream& os) {
76   os << "mlir::LogicalResult ExportXlaOp(mlir::mhlo::" << op.getCppClassName()
77      << " op, OpLoweringContext ctx) {\n"
78      << "  auto& value_map = *ctx.values;\n"
79      << "  auto result = op.getResult();\n";
80 
81   // Build a conversion for each of the arguments.
82   int operand_number = 0;
83   for (int index : llvm::seq<int>(0, op.getNumArgs())) {
84     auto arg = op.getArg(index);
85 
86     // Emit an argument for an operand.
87     if (auto* operand_cst = arg.dyn_cast<NamedTypeConstraint*>()) {
88       std::string xla_arg = "xla_arg_" + std::to_string(index);
89       // Handle a non-variadic operand.
90       if (!operand_cst->isVariableLength()) {
91         os << "  xla::XlaOp " << xla_arg << ";\n";
92         os << "  if (failed(GetXlaOp(*op.getODSOperands(" << operand_number++
93            << ").begin(), value_map, &" << xla_arg << ", op)))\n";
94         os << "    return mlir::failure();\n";
95         continue;
96       }
97 
98       // Otherwise, this is a varidiac operand list.
99       os << "  std::vector<xla::XlaOp> " << xla_arg << ";\n"
100          << "  for (auto operand : op.getODSOperands(" << operand_number++
101          << ")) {\n";
102       os << "    xla::XlaOp result;\n";
103       os << "    if (failed(GetXlaOp(operand, value_map, &result, op)))\n";
104       os << "      return mlir::failure();\n";
105       os << "    " << xla_arg << ".push_back(result);\n";
106       os << "  }\n";
107       continue;
108     }
109 
110     // Otherwise, this is an attribute.
111     auto named_attr = arg.get<NamedAttribute*>();
112     os << "  auto xla_arg_" << index << " = "
113        << GetDefaultAttrExport(*named_attr) << "(op." << op.getArgName(index)
114        << "());\n";
115   }
116 
117   // Emit call to client API
118   os << "  auto xla_result = xla::" << GetClientBuilder(op) << "(";
119 
120   // If all operands are variadic, then pass the builder explicitly to xla
121   // client API call
122   if (op.getNumOperands() == op.getNumVariableLengthOperands()) {
123     os << "ctx.builder";
124     if (op.getNumArgs() != 0) os << ", ";
125   }
126 
127   // Emit each of the arguments.
128   interleaveComma(llvm::seq<int>(0, op.getNumArgs()), os,
129                   [&](int i) { os << "Unwrap(xla_arg_" << i << ')'; });
130   os << ");\n";
131 
132   os << "  value_map[result] = xla_result;\n";
133   os << "  return mlir::success();\n";
134   os << "}\n";
135 }
136 
137 // The function below has a non-constant reference as that is required by LLVM's
138 // TableGenMain.
139 // NOLINTNEXTLINE
OperatorWritersMain(raw_ostream & os,RecordKeeper & records)140 static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) {
141   emitSourceFileHeader("MLIR XLA Builders", os);
142 
143   // Emit all the helper functions.
144   for (const auto* def : records.getAllDerivedDefinitions("HLO_Op")) {
145     Operator op(def);
146 
147     // Skip operations that have a custom exporter.
148     if (!def->getValueAsBit("hasCustomHLOConverter")) BuildOperator(op, os);
149   }
150 
151   // Emit a function to generate an XLA operation for the operations with
152   // auto-generated builders.
153   os << "mlir::LogicalResult ExportXlaOperator(\n"
154         "mlir::Operation* op, OpLoweringContext lowering_context) {\n\n";
155 
156   // Create a scoped object to assign sharding to generated XLA ops. Any HLO
157   // can have an attribute of "sharding".
158   os << "  xla::XlaScopedShardingAssignment sharding(lowering_context.builder, "
159         "CreateOpShardingFromAttribute(op));\n\n";
160 
161   // Create a scoped object to assign frontend attributes to generated XLA ops.
162   // Any HLO can have an attribute of "frontend_attributes", which are used to
163   // pass hints / configuration options.
164   os << "  xla::XlaScopedFrontendAttributesAssignment "
165         "frontend_attributes(lowering_context.builder, "
166         "CreateOpFrontendAttributesFromAttribute(op));\n\n";
167 
168   // Create a scoped object to assign op metadata to generated XLA ops.
169   os << "  xla::XlaScopedOpMetadataAssignment "
170         "op_metadata(lowering_context.builder, "
171         "CreateOpMetadataFromLocation(op));\n\n";
172 
173   // Retrieve all the definitions derived from HLO_Op and sort by record name.
174   for (const auto* def : records.getAllDerivedDefinitions("HLO_Op")) {
175     // Skip operations that have a custom exporter.
176     Operator op(def);
177 
178     // Cast to the current operation and build the exporter.
179     os << "  if (auto xla_op = llvm::dyn_cast<mlir::mhlo::"
180        << op.getCppClassName() << ">(op)) {\n";
181     os << "    return ";
182     // The autogenerated converters aren't in the same namespace.
183     // TODO(jpienaar): Reconsider this.
184     if (def->getValueAsBit("hasCustomHLOConverter")) os << "mlir::mhlo::";
185     os << "ExportXlaOp(xla_op, lowering_context);\n";
186     os << "  }\n";
187   }
188 
189   os << "  return mlir::failure();\n"
190         "}\n";
191   return false;
192 }
193 
main(int argc,char ** argv)194 int main(int argc, char** argv) {
195   llvm::InitLLVM y(argc, argv);
196   llvm::cl::ParseCommandLineOptions(argc, argv);
197   return TableGenMain(argv[0], &OperatorWritersMain);
198 }
199