1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for lowering HLO dialect to LHLO dialect.
17
18 #include <algorithm>
19 #include <utility>
20
21 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
22 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
23 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
24 #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
25 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
26 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
27 #include "mlir/Dialect/MemRef/IR/MemRef.h"
28 #include "mlir/Dialect/Shape/IR/Shape.h"
29 #include "mlir/Dialect/Shape/Transforms/Passes.h"
30 #include "mlir/Dialect/StandardOps/IR/Ops.h"
31 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
32 #include "mlir/Dialect/Tensor/IR/Tensor.h"
33 #include "mlir/IR/AffineMap.h"
34 #include "mlir/IR/Attributes.h"
35 #include "mlir/IR/BlockAndValueMapping.h"
36 #include "mlir/IR/Builders.h"
37 #include "mlir/IR/BuiltinOps.h"
38 #include "mlir/IR/BuiltinTypes.h"
39 #include "mlir/IR/Location.h"
40 #include "mlir/IR/MLIRContext.h"
41 #include "mlir/IR/Operation.h"
42 #include "mlir/IR/PatternMatch.h"
43 #include "mlir/Pass/Pass.h"
44 #include "mlir/Transforms/Bufferize.h"
45 #include "mlir/Transforms/DialectConversion.h"
46
47 namespace mlir {
48 namespace mhlo {
49 namespace {
50
51 template <typename T>
52 using BaseOpConversion = OpConversionPattern<T>;
53
InsertDynamicAlloc(Location loc,Value result,Value shape_operand,ConversionPatternRewriter * rewriter)54 Value InsertDynamicAlloc(Location loc, Value result, Value shape_operand,
55 ConversionPatternRewriter* rewriter) {
56 auto result_type = result.getType().dyn_cast<RankedTensorType>();
57 if (!result_type) {
58 result.getDefiningOp()->emitOpError()
59 << "tensor to buffer conversion expects ranked results";
60 }
61 auto memref_type =
62 MemRefType::get(result_type.getShape(), result_type.getElementType());
63
64 // Extract the required element out of the vector.
65 SmallVector<Value, 4> dynamic_operands;
66 for (auto shape_element : llvm::enumerate(result_type.getShape())) {
67 if (shape_element.value() != ShapedType::kDynamicSize) continue;
68 Value index = rewriter->create<ConstantIndexOp>(loc, shape_element.index());
69 Value alloc_operand =
70 rewriter->create<tensor::ExtractOp>(loc, shape_operand, index);
71 if (!alloc_operand.getType().isIndex()) {
72 alloc_operand = rewriter->create<IndexCastOp>(loc, alloc_operand,
73 rewriter->getIndexType());
74 }
75 dynamic_operands.push_back(alloc_operand);
76 }
77
78 return rewriter->create<memref::AllocOp>(loc, memref_type, dynamic_operands);
79 }
80
InsertAlloc(Location loc,OpResult result,ConversionPatternRewriter * rewriter)81 Value InsertAlloc(Location loc, OpResult result,
82 ConversionPatternRewriter* rewriter) {
83 auto result_type = result.getType().dyn_cast<RankedTensorType>();
84 if (!result_type || !result_type.hasStaticShape()) {
85 result.getDefiningOp()->emitOpError()
86 << "tensor to buffer conversion expects statically shaped results";
87 }
88 auto memref_type =
89 MemRefType::get(result_type.getShape(), result_type.getElementType());
90 OpBuilder::InsertionGuard guard(*rewriter);
91 rewriter->setInsertionPoint(result.getDefiningOp());
92 auto alloc = rewriter->create<memref::AllocOp>(loc, memref_type);
93 return alloc;
94 }
95
96 /// Converts the results of the operation `op` to memref types and append them
97 /// to the `results` vector.
ConvertResults(Operation * op,SmallVectorImpl<Value> & results,ConversionPatternRewriter & rewriter)98 LogicalResult ConvertResults(Operation* op, SmallVectorImpl<Value>& results,
99 ConversionPatternRewriter& rewriter) {
100 size_t num_operands = results.size();
101 SmallVector<Value, 2> tensor_operands;
102 for (auto result : llvm::enumerate(op->getResults())) {
103 RankedTensorType resultType =
104 result.value().getType().dyn_cast<RankedTensorType>();
105 if (!resultType) return failure();
106
107 if (resultType.hasStaticShape()) {
108 results.push_back(InsertAlloc(op->getLoc(), result.value(), &rewriter));
109 continue;
110 }
111 auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
112 if (!shape_type_op) return failure();
113
114 if (tensor_operands.empty()) {
115 for (auto operand : ArrayRef<Value>(results).take_front(num_operands)) {
116 auto operand_type = operand.getType().dyn_cast<MemRefType>();
117 if (!operand_type) return failure();
118 tensor_operands.push_back(rewriter.create<memref::TensorLoadOp>(
119 op->getLoc(),
120 RankedTensorType::get(operand_type.getShape(),
121 operand_type.getElementType()),
122 operand));
123 }
124 }
125
126 SmallVector<Value, 1> results_shape;
127 auto status = shape_type_op.reifyReturnTypeShapes(rewriter, tensor_operands,
128 results_shape);
129 if (failed(status)) return failure();
130 results.push_back(InsertDynamicAlloc(op->getLoc(), result.value(),
131 results_shape[result.index()],
132 &rewriter));
133 }
134 return success();
135 }
136
137 template <typename HloOpTy>
138 class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
139 public:
140 using BaseOpConversion<HloOpTy>::BaseOpConversion;
matchAndRewrite(HloOpTy hloOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const141 LogicalResult matchAndRewrite(
142 HloOpTy hloOp, ArrayRef<Value> operands,
143 ConversionPatternRewriter& rewriter) const final {
144 Operation* op = hloOp.getOperation();
145 SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
146 if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
147 rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
148 buffer_args, op->getAttrs());
149 rewriter.replaceOp(
150 op, llvm::makeArrayRef(buffer_args).drop_front(operands.size()));
151 return success();
152 }
153 };
154
155 // This specialization exists so that LMHLO's Dot can be given a specific set of
156 // dimension numbers, when lowering from MHLO's Dot, which does not have
157 // dimension numbers (it uses DotGeneral for this generalized notion of dot
158 // products). When these two dialects are in sync with respect to the
159 // Dot/DotGeneral issue, this specialization should be deleted.
160 template <>
161 class HloToLhloOpConverter<mhlo::DotOp> : public BaseOpConversion<mhlo::DotOp> {
162 public:
163 using BaseOpConversion<mhlo::DotOp>::BaseOpConversion;
matchAndRewrite(mhlo::DotOp hloOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const164 LogicalResult matchAndRewrite(
165 mhlo::DotOp hloOp, ArrayRef<Value> operands,
166 ConversionPatternRewriter& rewriter) const final {
167 Operation* op = hloOp.getOperation();
168 SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
169 if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
170
171 // TODO(silvasean): Move this helper to MLIR core.
172 auto make_elements_attr = [&rewriter](ArrayRef<int64_t> integers) {
173 auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
174 rewriter.getIntegerType(64));
175 return DenseIntElementsAttr::get(type, integers);
176 };
177 auto dotOp = rewriter.create<lmhlo::DotOp>(op->getLoc(), llvm::None,
178 buffer_args, op->getAttrs());
179 // MHLO's Dot uses rank-2 operands, of the form ([N, M], [M, O]) -> [N, O].
180 auto dimension_numbers = mhlo::DotDimensionNumbers::get(
181 make_elements_attr({}), make_elements_attr({}), make_elements_attr({1}),
182 make_elements_attr({0}), rewriter.getContext());
183 dotOp.dot_dimension_numbersAttr(dimension_numbers);
184 rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
185 return success();
186 }
187 };
188
189 struct HloToLhloCustomCallOpConverter
190 : public BaseOpConversion<mhlo::CustomCallOp> {
191 public:
192 using BaseOpConversion<mhlo::CustomCallOp>::BaseOpConversion;
193
matchAndRewritemlir::mhlo::__anon6f36e01b0111::HloToLhloCustomCallOpConverter194 LogicalResult matchAndRewrite(
195 mhlo::CustomCallOp hloOp, ArrayRef<Value> operands,
196 ConversionPatternRewriter& rewriter) const final {
197 Operation* op = hloOp.getOperation();
198 SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
199 if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
200
201 auto lhloOp = rewriter.create<lmhlo::CustomCallOp>(
202 op->getLoc(), llvm::None, buffer_args, op->getAttrs());
203 // Setup AttrSizedOperandSegments attribute to indicate number of operands
204 // for args and outputs.
205 const int32_t segments[2] = {static_cast<int32_t>(operands.size()),
206 static_cast<int32_t>(op->getNumResults())};
207 lhloOp->setAttr(lhloOp.getOperandSegmentSizeAttr(),
208 rewriter.getI32VectorAttr(segments));
209
210 rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
211 return success();
212 }
213 };
214
215 struct HloToLhloDotGeneralOpConverter
216 : public BaseOpConversion<mhlo::DotGeneralOp> {
217 using BaseOpConversion<mhlo::DotGeneralOp>::BaseOpConversion;
matchAndRewritemlir::mhlo::__anon6f36e01b0111::HloToLhloDotGeneralOpConverter218 LogicalResult matchAndRewrite(
219 mhlo::DotGeneralOp dotGeneralOp, ArrayRef<Value> operands,
220 ConversionPatternRewriter& rewriter) const final {
221 Operation* op = dotGeneralOp.getOperation();
222
223 if (op->getResults().empty()) return failure();
224 OpResult result = op->getResults()[0];
225 RankedTensorType resultType = result.getType().dyn_cast<RankedTensorType>();
226 if (!resultType) return failure();
227
228 // The third buffer argument will be filled with what used to be the return
229 // type of the DotGeneral.
230 if (operands.size() != 2) return failure();
231 std::array<Value, 3> bufferArgs = {operands[0], operands[1], {}};
232
233 if (resultType.hasStaticShape()) {
234 bufferArgs[2] = InsertAlloc(op->getLoc(), result, &rewriter);
235 } else {
236 SmallVector<Value, 1> results_shape;
237 auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
238 if (failed(shape_type_op.reifyReturnTypeShapes(rewriter, operands,
239 results_shape)))
240 return failure();
241
242 bufferArgs[2] = InsertDynamicAlloc(op->getLoc(), result,
243 results_shape.front(), &rewriter);
244 }
245
246 rewriter.create<lmhlo::DotOp>(op->getLoc(), llvm::None, bufferArgs,
247 op->getAttrs());
248 rewriter.replaceOp(op, bufferArgs[2]);
249 return success();
250 }
251 };
252
253 struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
254 public:
255 using BaseOpConversion<mhlo::ReduceOp>::BaseOpConversion;
256
matchAndRewritemlir::mhlo::__anon6f36e01b0111::HloToLhloReduceOpConverter257 LogicalResult matchAndRewrite(
258 mhlo::ReduceOp op, ArrayRef<Value> operands,
259 ConversionPatternRewriter& rewriter) const final {
260 auto loc = op.getLoc();
261 // TODO(b/137624192) Implement variadic reduce.
262 if (op.getNumResults() != 1) return failure();
263 if (!llvm::hasSingleElement(op.body())) {
264 return op.emitOpError()
265 << "tensor to buffer conversion expects a single block "
266 "in the region containing the operation";
267 }
268 SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
269 if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
270 auto new_op = rewriter.create<lmhlo::ReduceOp>(loc, llvm::None, buffer_args,
271 op->getAttrs());
272
273 // Copy over the operations inside the region.
274 rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end());
275
276 // Convert the region signature to memref and add extra result.
277 auto& entry_block = new_op.body().front();
278 TypeConverter::SignatureConversion sig_conversion(
279 entry_block.getNumArguments() + 1);
280 for (auto arg : entry_block.getArguments()) {
281 auto old_type = arg.getType().cast<TensorType>();
282 auto new_type =
283 MemRefType::get(old_type.getShape(), old_type.getElementType());
284 sig_conversion.addInputs(arg.getArgNumber(), new_type);
285 }
286 auto return_op = cast<mhlo::ReturnOp>(entry_block.getTerminator());
287 auto result_type = return_op.results().front().getType().cast<TensorType>();
288 sig_conversion.addInputs({MemRefType::get(result_type.getShape(),
289 result_type.getElementType())});
290 rewriter.applySignatureConversion(&new_op.body(), sig_conversion);
291
292 rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
293
294 return success();
295 }
296 };
297
298 // Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator.
299 struct HloToLhloReturnOpConverter : public BaseOpConversion<mhlo::ReturnOp> {
300 public:
301 using BaseOpConversion<mhlo::ReturnOp>::BaseOpConversion;
302
matchAndRewritemlir::mhlo::__anon6f36e01b0111::HloToLhloReturnOpConverter303 LogicalResult matchAndRewrite(
304 mhlo::ReturnOp op, ArrayRef<Value> operands,
305 ConversionPatternRewriter& rewriter) const final {
306 auto loc = op.getLoc();
307 auto& entry_block = op->getParentRegion()->front();
308 auto num_arguments = entry_block.getNumArguments();
309 if (operands.size() > num_arguments) {
310 return op.emitError(
311 "The number of operands that need Copy operations is more "
312 "than the number of target function arguments.");
313 }
314
315 // The index of the first output block argument.
316 auto dest_arg_idx = num_arguments - operands.size();
317
318 // Create a lmhlo.copy for each operand of mhlo.return.
319 for (Value operand : operands) {
320 rewriter.create<lmhlo::CopyOp>(loc, operand,
321 entry_block.getArgument(dest_arg_idx));
322 ++dest_arg_idx;
323 }
324 rewriter.replaceOpWithNewOp<lmhlo::TerminatorOp>(op);
325 return success();
326 }
327 };
328
329 // TODO(b/175789537) Remove this pattern.
330 class HloToLhloTensorStoreOpLegacyConverter
331 : public BaseOpConversion<mlir::memref::TensorStoreOp> {
332 public:
333 using BaseOpConversion<mlir::memref::TensorStoreOp>::BaseOpConversion;
334
matchAndRewrite(mlir::memref::TensorStoreOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const335 LogicalResult matchAndRewrite(
336 mlir::memref::TensorStoreOp op, ArrayRef<Value> operands,
337 ConversionPatternRewriter& rewriter) const final {
338 rewriter.replaceOpWithNewOp<lmhlo::CopyOp>(op, llvm::None, operands.front(),
339 operands.back());
340 return success();
341 }
342 };
343
// Lowers from HLO dialect to LHLO dialect, allocating temporary buffers where
// necessary (this pass inserts allocations only; it does not insert deallocs).
346 //
347 // Example fusion with HLO ops.
348 //
349 // func @fusion(%arg0: memref<2x2xf32>,
350 // %arg1: memref<2x2xf32>,
351 // %arg2: memref<2x2xf32>,
352 // %arg3: memref<2x2xf32>) {
353 // "lmhlo.fusion"() ({
354 // %0 = tensor_load %arg1 : memref<2x2xf32>
355 // %1 = tensor_load %arg2 : memref<2x2xf32>
356 // %2 = "mhlo.add"(%0, %1) :
357 // (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
358 // %3 = tensor_load %arg0 : memref<2x2xf32>
359 // %4 = "mhlo.multiply"(%2, %3) :
360 // (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
361 // tensor_store %4, %arg3 : memref<2x2xf32>
362 // "lmhlo.terminator"() : () -> ()
363 // }) : () -> ()
364 // return
365 // }
366 //
367 // Transformed fusion with LHLO ops.
368 // func @fusion(%arg0: memref<2x2xf32>,
369 // %arg1: memref<2x2xf32>,
370 // %arg2: memref<2x2xf32>,
371 // %arg3: memref<2x2xf32>) {
372 // "lmhlo.fusion"() ( {
373 // %0 = alloc() : memref<2x2xf32>
374 // "lmhlo.add"(%arg1, %arg2, %0) :
375 // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
376 // "lmhlo.multiply"(%0, %arg0, %arg3) :
377 // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
378 // "lmhlo.terminator"() : () -> ()
379 // }) : () -> ()
380 // return
381 // }
382 //
383 // FuncOp signature conversion example:
384 //
// func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
//   %0 = "mhlo.maximum"(%arg0, %arg1) :
//       (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
//   %1 = "mhlo.add"(%arg0, %0) :
//       (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
//   return %1 : tensor<4xf32>
// }
390 //
391 // Transformed function with an extra argument for the result. The types have
392 // been converted from tensor to memref.
393 //
394 // func @func_op(%arg0: memref<4xf32>,
395 // %arg1: memref<4xf32>,
396 // %arg2: memref<4xf32>) {
//   %0 = alloc() : memref<4xf32>
//   "lmhlo.maximum"(%arg0, %arg1, %0) :
400 // (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
401 // %1 = alloc() : memref<4xf32>
402 // "lmhlo.add"(%arg0, %0, %1) :
403 // (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
404 // "lmhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
405 // "lmhlo.terminator"() : () -> ()
406 // }
407
408 struct HloLegalizeToLhlo : public HloLegalizeToLhloPassBase<HloLegalizeToLhlo> {
409 using HloLegalizeToLhloPassBase<HloLegalizeToLhlo>::HloLegalizeToLhloPassBase;
410
getDependentDialectsmlir::mhlo::__anon6f36e01b0111::HloLegalizeToLhlo411 void getDependentDialects(DialectRegistry& registry) const override {
412 registry.insert<lmhlo::LmhloDialect, memref::MemRefDialect,
413 shape::ShapeDialect>();
414 }
415
416 public:
417 HloLegalizeToLhlo() = default;
HloLegalizeToLhlomlir::mhlo::__anon6f36e01b0111::HloLegalizeToLhlo418 HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {}
419
runOnOperationmlir::mhlo::__anon6f36e01b0111::HloLegalizeToLhlo420 void runOnOperation() override {
421 auto& context = getContext();
422 OwningRewritePatternList patterns(&context);
423 ConversionTarget target(context);
424 target.addLegalDialect<lmhlo::LmhloDialect>();
425 target.addLegalDialect<StandardOpsDialect>();
426 target.addLegalDialect<memref::MemRefDialect>();
427 target.addLegalDialect<shape::ShapeDialect>();
428 target.addLegalDialect<tensor::TensorDialect>();
429 target.addIllegalDialect<mhlo::MhloDialect>();
430 // Declare tensor_store illegal. tensor_load may be used to reify output
431 // shape computation during dialect conversion and will be handled later.
432 target.addIllegalOp<mlir::memref::TensorStoreOp>();
433 // buffer_cast is illegal if it has uses.
434 // TODO(b/175670649) Make buffer_cast illegal.
435 target.addDynamicallyLegalOp<mlir::memref::BufferCastOp>(
436 [](auto op) { return op->use_empty(); });
437
438 BufferizeTypeConverter converter;
439 auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); };
440 target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
441 return converter.isSignatureLegal(op.getType()) &&
442 converter.isLegal(&op.getBody());
443 });
444 target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
445 return std::all_of(op.operand_type_begin(), op.operand_type_end(),
446 isMemRefType) &&
447 std::all_of(op.result_type_begin(), op.result_type_end(),
448 isMemRefType);
449 });
450 target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp op) {
451 return std::all_of(op.operand_type_begin(), op.operand_type_end(),
452 isMemRefType);
453 });
454
455 populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
456 populateFuncOpTypeConversionPattern(patterns, converter);
457 populateCallOpTypeConversionPattern(patterns, converter);
458 populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
459 populateReturnOpTypeConversionPattern(patterns, converter);
460 populateEliminateBufferizeMaterializationsPatterns(converter, patterns);
461
462 populateShapeStructuralTypeConversionsAndLegality(converter, patterns,
463 target);
464
465 // TODO(b/175789537) Remove this pattern.
466 patterns.insert<HloToLhloTensorStoreOpLegacyConverter>(&context);
467
468 if (failed(applyPartialConversion(getOperation(), target,
469 std::move(patterns))))
470 signalPassFailure();
471 }
472 };
473 } // namespace
474
475 // Simply lowers all mhlo ops to their lmhlo counterparts.
populateDynamicHLOToLHLOConversionPattern(MLIRContext * context,BufferizeTypeConverter * converter,OwningRewritePatternList * patterns)476 void populateDynamicHLOToLHLOConversionPattern(
477 MLIRContext* context, BufferizeTypeConverter* converter,
478 OwningRewritePatternList* patterns) {
479 // clang-format off
480 patterns->insert<HloToLhloOpConverter<mhlo::DynamicBroadcastInDimOp>,
481 HloToLhloOpConverter<mhlo::DynamicGatherOp>,
482 HloToLhloOpConverter<mhlo::DynamicIotaOp>,
483 HloToLhloOpConverter<mhlo::DynamicPadOp>,
484 HloToLhloOpConverter<mhlo::DynamicReshapeOp>,
485 HloToLhloOpConverter<mhlo::RealDynamicSliceOp>
486 >(*converter, context);
487 // clang-format on
488 }
489
populateHLOToLHLOConversionPattern(MLIRContext * context,BufferizeTypeConverter * converter,OwningRewritePatternList * patterns)490 void populateHLOToLHLOConversionPattern(MLIRContext* context,
491 BufferizeTypeConverter* converter,
492 OwningRewritePatternList* patterns) {
493 populateDynamicHLOToLHLOConversionPattern(context, converter, patterns);
494
495 // clang-format off
496 patterns->insert<
497 HloToLhloCustomCallOpConverter,
498 HloToLhloDotGeneralOpConverter,
499 HloToLhloOpConverter<mhlo::AbsOp>,
500 HloToLhloOpConverter<mhlo::AddOp>,
501 HloToLhloOpConverter<mhlo::AndOp>,
502 HloToLhloOpConverter<mhlo::Atan2Op>,
503 HloToLhloOpConverter<mhlo::BroadcastInDimOp>,
504 HloToLhloOpConverter<mhlo::CeilOp>,
505 HloToLhloOpConverter<mhlo::CompareOp>,
506 HloToLhloOpConverter<mhlo::ComplexOp>,
507 HloToLhloOpConverter<mhlo::ConcatenateOp>,
508 HloToLhloOpConverter<mhlo::ConstOp>,
509 HloToLhloOpConverter<mhlo::ConvOp>,
510 HloToLhloOpConverter<mhlo::ConvertOp>,
511 HloToLhloOpConverter<mhlo::CopyOp>,
512 HloToLhloOpConverter<mhlo::CosOp>,
513 HloToLhloOpConverter<mhlo::DivOp>,
514 HloToLhloOpConverter<mhlo::DotOp>,
515 HloToLhloOpConverter<mhlo::ExpOp>,
516 HloToLhloOpConverter<mhlo::Expm1Op>,
517 HloToLhloOpConverter<mhlo::FloorOp>,
518 HloToLhloOpConverter<mhlo::GatherOp>,
519 HloToLhloOpConverter<mhlo::ImagOp>,
520 HloToLhloOpConverter<mhlo::IotaOp>,
521 HloToLhloOpConverter<mhlo::IsFiniteOp>,
522 HloToLhloOpConverter<mhlo::LogOp>,
523 HloToLhloOpConverter<mhlo::LogisticOp>,
524 HloToLhloOpConverter<mhlo::MaxOp>,
525 HloToLhloOpConverter<mhlo::MinOp>,
526 HloToLhloOpConverter<mhlo::MulOp>,
527 HloToLhloOpConverter<mhlo::NegOp>,
528 HloToLhloOpConverter<mhlo::NotOp>,
529 HloToLhloOpConverter<mhlo::OrOp>,
530 HloToLhloOpConverter<mhlo::PowOp>,
531 HloToLhloOpConverter<mhlo::RealOp>,
532 HloToLhloOpConverter<mhlo::RemOp>,
533 HloToLhloOpConverter<mhlo::RsqrtOp>,
534 HloToLhloOpConverter<mhlo::ReshapeOp>,
535 HloToLhloOpConverter<mhlo::SelectOp>,
536 HloToLhloOpConverter<mhlo::ShiftLeftOp>,
537 HloToLhloOpConverter<mhlo::ShiftRightArithmeticOp>,
538 HloToLhloOpConverter<mhlo::ShiftRightLogicalOp>,
539 HloToLhloOpConverter<mhlo::SignOp>,
540 HloToLhloOpConverter<mhlo::SinOp>,
541 HloToLhloOpConverter<mhlo::SliceOp>,
542 HloToLhloOpConverter<mhlo::SqrtOp>,
543 HloToLhloOpConverter<mhlo::SubOp>,
544 HloToLhloOpConverter<mhlo::TanhOp>,
545 HloToLhloOpConverter<mhlo::TransposeOp>,
546 HloToLhloOpConverter<mhlo::XorOp>,
547 HloToLhloReduceOpConverter,
548 HloToLhloReturnOpConverter
549 >(*converter, context);
550 // clang-format on
551 }
552
createLegalizeToLhloPass()553 std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass() {
554 return std::make_unique<HloLegalizeToLhlo>();
555 }
556
557 } // namespace mhlo
558 } // namespace mlir
559