1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for lowering HLO dialect to LHLO dialect.
17 
18 #include <algorithm>
19 #include <utility>
20 
21 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
22 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
23 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
24 #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
25 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
26 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
27 #include "mlir/Dialect/MemRef/IR/MemRef.h"
28 #include "mlir/Dialect/Shape/IR/Shape.h"
29 #include "mlir/Dialect/Shape/Transforms/Passes.h"
30 #include "mlir/Dialect/StandardOps/IR/Ops.h"
31 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
32 #include "mlir/Dialect/Tensor/IR/Tensor.h"
33 #include "mlir/IR/AffineMap.h"
34 #include "mlir/IR/Attributes.h"
35 #include "mlir/IR/BlockAndValueMapping.h"
36 #include "mlir/IR/Builders.h"
37 #include "mlir/IR/BuiltinOps.h"
38 #include "mlir/IR/BuiltinTypes.h"
39 #include "mlir/IR/Location.h"
40 #include "mlir/IR/MLIRContext.h"
41 #include "mlir/IR/Operation.h"
42 #include "mlir/IR/PatternMatch.h"
43 #include "mlir/Pass/Pass.h"
44 #include "mlir/Transforms/Bufferize.h"
45 #include "mlir/Transforms/DialectConversion.h"
46 
47 namespace mlir {
48 namespace mhlo {
49 namespace {
50 
51 template <typename T>
52 using BaseOpConversion = OpConversionPattern<T>;
53 
InsertDynamicAlloc(Location loc,Value result,Value shape_operand,ConversionPatternRewriter * rewriter)54 Value InsertDynamicAlloc(Location loc, Value result, Value shape_operand,
55                          ConversionPatternRewriter* rewriter) {
56   auto result_type = result.getType().dyn_cast<RankedTensorType>();
57   if (!result_type) {
58     result.getDefiningOp()->emitOpError()
59         << "tensor to buffer conversion expects ranked results";
60   }
61   auto memref_type =
62       MemRefType::get(result_type.getShape(), result_type.getElementType());
63 
64   // Extract the required element out of the vector.
65   SmallVector<Value, 4> dynamic_operands;
66   for (auto shape_element : llvm::enumerate(result_type.getShape())) {
67     if (shape_element.value() != ShapedType::kDynamicSize) continue;
68     Value index = rewriter->create<ConstantIndexOp>(loc, shape_element.index());
69     Value alloc_operand =
70         rewriter->create<tensor::ExtractOp>(loc, shape_operand, index);
71     if (!alloc_operand.getType().isIndex()) {
72       alloc_operand = rewriter->create<IndexCastOp>(loc, alloc_operand,
73                                                     rewriter->getIndexType());
74     }
75     dynamic_operands.push_back(alloc_operand);
76   }
77 
78   return rewriter->create<memref::AllocOp>(loc, memref_type, dynamic_operands);
79 }
80 
InsertAlloc(Location loc,OpResult result,ConversionPatternRewriter * rewriter)81 Value InsertAlloc(Location loc, OpResult result,
82                   ConversionPatternRewriter* rewriter) {
83   auto result_type = result.getType().dyn_cast<RankedTensorType>();
84   if (!result_type || !result_type.hasStaticShape()) {
85     result.getDefiningOp()->emitOpError()
86         << "tensor to buffer conversion expects statically shaped results";
87   }
88   auto memref_type =
89       MemRefType::get(result_type.getShape(), result_type.getElementType());
90   OpBuilder::InsertionGuard guard(*rewriter);
91   rewriter->setInsertionPoint(result.getDefiningOp());
92   auto alloc = rewriter->create<memref::AllocOp>(loc, memref_type);
93   return alloc;
94 }
95 
96 /// Converts the results of the operation `op` to memref types and append them
97 /// to the `results` vector.
ConvertResults(Operation * op,SmallVectorImpl<Value> & results,ConversionPatternRewriter & rewriter)98 LogicalResult ConvertResults(Operation* op, SmallVectorImpl<Value>& results,
99                              ConversionPatternRewriter& rewriter) {
100   size_t num_operands = results.size();
101   SmallVector<Value, 2> tensor_operands;
102   for (auto result : llvm::enumerate(op->getResults())) {
103     RankedTensorType resultType =
104         result.value().getType().dyn_cast<RankedTensorType>();
105     if (!resultType) return failure();
106 
107     if (resultType.hasStaticShape()) {
108       results.push_back(InsertAlloc(op->getLoc(), result.value(), &rewriter));
109       continue;
110     }
111     auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
112     if (!shape_type_op) return failure();
113 
114     if (tensor_operands.empty()) {
115       for (auto operand : ArrayRef<Value>(results).take_front(num_operands)) {
116         auto operand_type = operand.getType().dyn_cast<MemRefType>();
117         if (!operand_type) return failure();
118         tensor_operands.push_back(rewriter.create<memref::TensorLoadOp>(
119             op->getLoc(),
120             RankedTensorType::get(operand_type.getShape(),
121                                   operand_type.getElementType()),
122             operand));
123       }
124     }
125 
126     SmallVector<Value, 1> results_shape;
127     auto status = shape_type_op.reifyReturnTypeShapes(rewriter, tensor_operands,
128                                                       results_shape);
129     if (failed(status)) return failure();
130     results.push_back(InsertDynamicAlloc(op->getLoc(), result.value(),
131                                          results_shape[result.index()],
132                                          &rewriter));
133   }
134   return success();
135 }
136 
137 template <typename HloOpTy>
138 class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
139  public:
140   using BaseOpConversion<HloOpTy>::BaseOpConversion;
matchAndRewrite(HloOpTy hloOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const141   LogicalResult matchAndRewrite(
142       HloOpTy hloOp, ArrayRef<Value> operands,
143       ConversionPatternRewriter& rewriter) const final {
144     Operation* op = hloOp.getOperation();
145     SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
146     if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
147     rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
148                                                 buffer_args, op->getAttrs());
149     rewriter.replaceOp(
150         op, llvm::makeArrayRef(buffer_args).drop_front(operands.size()));
151     return success();
152   }
153 };
154 
155 // This specialization exists so that LMHLO's Dot can be given a specific set of
156 // dimension numbers, when lowering from MHLO's Dot, which does not have
157 // dimension numbers (it uses DotGeneral for this generalized notion of dot
158 // products). When these two dialects are in sync with respect to the
159 // Dot/DotGeneral issue, this specialization should be deleted.
160 template <>
161 class HloToLhloOpConverter<mhlo::DotOp> : public BaseOpConversion<mhlo::DotOp> {
162  public:
163   using BaseOpConversion<mhlo::DotOp>::BaseOpConversion;
matchAndRewrite(mhlo::DotOp hloOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const164   LogicalResult matchAndRewrite(
165       mhlo::DotOp hloOp, ArrayRef<Value> operands,
166       ConversionPatternRewriter& rewriter) const final {
167     Operation* op = hloOp.getOperation();
168     SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
169     if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
170 
171     // TODO(silvasean): Move this helper to MLIR core.
172     auto make_elements_attr = [&rewriter](ArrayRef<int64_t> integers) {
173       auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
174                                         rewriter.getIntegerType(64));
175       return DenseIntElementsAttr::get(type, integers);
176     };
177     auto dotOp = rewriter.create<lmhlo::DotOp>(op->getLoc(), llvm::None,
178                                                buffer_args, op->getAttrs());
179     // MHLO's Dot uses rank-2 operands, of the form ([N, M], [M, O]) -> [N, O].
180     auto dimension_numbers = mhlo::DotDimensionNumbers::get(
181         make_elements_attr({}), make_elements_attr({}), make_elements_attr({1}),
182         make_elements_attr({0}), rewriter.getContext());
183     dotOp.dot_dimension_numbersAttr(dimension_numbers);
184     rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
185     return success();
186   }
187 };
188 
189 struct HloToLhloCustomCallOpConverter
190     : public BaseOpConversion<mhlo::CustomCallOp> {
191  public:
192   using BaseOpConversion<mhlo::CustomCallOp>::BaseOpConversion;
193 
matchAndRewritemlir::mhlo::__anon6f36e01b0111::HloToLhloCustomCallOpConverter194   LogicalResult matchAndRewrite(
195       mhlo::CustomCallOp hloOp, ArrayRef<Value> operands,
196       ConversionPatternRewriter& rewriter) const final {
197     Operation* op = hloOp.getOperation();
198     SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
199     if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
200 
201     auto lhloOp = rewriter.create<lmhlo::CustomCallOp>(
202         op->getLoc(), llvm::None, buffer_args, op->getAttrs());
203     // Setup AttrSizedOperandSegments attribute to indicate number of operands
204     // for args and outputs.
205     const int32_t segments[2] = {static_cast<int32_t>(operands.size()),
206                                  static_cast<int32_t>(op->getNumResults())};
207     lhloOp->setAttr(lhloOp.getOperandSegmentSizeAttr(),
208                     rewriter.getI32VectorAttr(segments));
209 
210     rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
211     return success();
212   }
213 };
214 
215 struct HloToLhloDotGeneralOpConverter
216     : public BaseOpConversion<mhlo::DotGeneralOp> {
217   using BaseOpConversion<mhlo::DotGeneralOp>::BaseOpConversion;
matchAndRewritemlir::mhlo::__anon6f36e01b0111::HloToLhloDotGeneralOpConverter218   LogicalResult matchAndRewrite(
219       mhlo::DotGeneralOp dotGeneralOp, ArrayRef<Value> operands,
220       ConversionPatternRewriter& rewriter) const final {
221     Operation* op = dotGeneralOp.getOperation();
222 
223     if (op->getResults().empty()) return failure();
224     OpResult result = op->getResults()[0];
225     RankedTensorType resultType = result.getType().dyn_cast<RankedTensorType>();
226     if (!resultType) return failure();
227 
228     // The third buffer argument will be filled with what used to be the return
229     // type of the DotGeneral.
230     if (operands.size() != 2) return failure();
231     std::array<Value, 3> bufferArgs = {operands[0], operands[1], {}};
232 
233     if (resultType.hasStaticShape()) {
234       bufferArgs[2] = InsertAlloc(op->getLoc(), result, &rewriter);
235     } else {
236       SmallVector<Value, 1> results_shape;
237       auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
238       if (failed(shape_type_op.reifyReturnTypeShapes(rewriter, operands,
239                                                      results_shape)))
240         return failure();
241 
242       bufferArgs[2] = InsertDynamicAlloc(op->getLoc(), result,
243                                          results_shape.front(), &rewriter);
244     }
245 
246     rewriter.create<lmhlo::DotOp>(op->getLoc(), llvm::None, bufferArgs,
247                                   op->getAttrs());
248     rewriter.replaceOp(op, bufferArgs[2]);
249     return success();
250   }
251 };
252 
253 struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
254  public:
255   using BaseOpConversion<mhlo::ReduceOp>::BaseOpConversion;
256 
matchAndRewritemlir::mhlo::__anon6f36e01b0111::HloToLhloReduceOpConverter257   LogicalResult matchAndRewrite(
258       mhlo::ReduceOp op, ArrayRef<Value> operands,
259       ConversionPatternRewriter& rewriter) const final {
260     auto loc = op.getLoc();
261     // TODO(b/137624192) Implement variadic reduce.
262     if (op.getNumResults() != 1) return failure();
263     if (!llvm::hasSingleElement(op.body())) {
264       return op.emitOpError()
265              << "tensor to buffer conversion expects a single block "
266                 "in the region containing the operation";
267     }
268     SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
269     if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
270     auto new_op = rewriter.create<lmhlo::ReduceOp>(loc, llvm::None, buffer_args,
271                                                    op->getAttrs());
272 
273     // Copy over the operations inside the region.
274     rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end());
275 
276     // Convert the region signature to memref and add extra result.
277     auto& entry_block = new_op.body().front();
278     TypeConverter::SignatureConversion sig_conversion(
279         entry_block.getNumArguments() + 1);
280     for (auto arg : entry_block.getArguments()) {
281       auto old_type = arg.getType().cast<TensorType>();
282       auto new_type =
283           MemRefType::get(old_type.getShape(), old_type.getElementType());
284       sig_conversion.addInputs(arg.getArgNumber(), new_type);
285     }
286     auto return_op = cast<mhlo::ReturnOp>(entry_block.getTerminator());
287     auto result_type = return_op.results().front().getType().cast<TensorType>();
288     sig_conversion.addInputs({MemRefType::get(result_type.getShape(),
289                                               result_type.getElementType())});
290     rewriter.applySignatureConversion(&new_op.body(), sig_conversion);
291 
292     rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
293 
294     return success();
295   }
296 };
297 
298 // Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator.
299 struct HloToLhloReturnOpConverter : public BaseOpConversion<mhlo::ReturnOp> {
300  public:
301   using BaseOpConversion<mhlo::ReturnOp>::BaseOpConversion;
302 
matchAndRewritemlir::mhlo::__anon6f36e01b0111::HloToLhloReturnOpConverter303   LogicalResult matchAndRewrite(
304       mhlo::ReturnOp op, ArrayRef<Value> operands,
305       ConversionPatternRewriter& rewriter) const final {
306     auto loc = op.getLoc();
307     auto& entry_block = op->getParentRegion()->front();
308     auto num_arguments = entry_block.getNumArguments();
309     if (operands.size() > num_arguments) {
310       return op.emitError(
311           "The number of operands that need Copy operations is more "
312           "than the number of target function arguments.");
313     }
314 
315     // The index of the first output block argument.
316     auto dest_arg_idx = num_arguments - operands.size();
317 
318     // Create a lmhlo.copy for each operand of mhlo.return.
319     for (Value operand : operands) {
320       rewriter.create<lmhlo::CopyOp>(loc, operand,
321                                      entry_block.getArgument(dest_arg_idx));
322       ++dest_arg_idx;
323     }
324     rewriter.replaceOpWithNewOp<lmhlo::TerminatorOp>(op);
325     return success();
326   }
327 };
328 
329 // TODO(b/175789537) Remove this pattern.
330 class HloToLhloTensorStoreOpLegacyConverter
331     : public BaseOpConversion<mlir::memref::TensorStoreOp> {
332  public:
333   using BaseOpConversion<mlir::memref::TensorStoreOp>::BaseOpConversion;
334 
matchAndRewrite(mlir::memref::TensorStoreOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const335   LogicalResult matchAndRewrite(
336       mlir::memref::TensorStoreOp op, ArrayRef<Value> operands,
337       ConversionPatternRewriter& rewriter) const final {
338     rewriter.replaceOpWithNewOp<lmhlo::CopyOp>(op, llvm::None, operands.front(),
339                                                operands.back());
340     return success();
341   }
342 };
343 
// Lowers from the HLO dialect to the LHLO dialect, allocating temporary
// buffers where necessary. The patterns defined in this file insert
// allocations (memref.alloc) but no deallocations.
346 //
347 // Example fusion with HLO ops.
348 //
349 // func @fusion(%arg0: memref<2x2xf32>,
350 //              %arg1: memref<2x2xf32>,
351 //              %arg2: memref<2x2xf32>,
352 //              %arg3: memref<2x2xf32>) {
353 //   "lmhlo.fusion"() ({
354 //     %0 = tensor_load %arg1 : memref<2x2xf32>
355 //     %1 = tensor_load %arg2 : memref<2x2xf32>
356 //     %2 = "mhlo.add"(%0, %1) :
357 //         (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
358 //     %3 = tensor_load %arg0 : memref<2x2xf32>
359 //     %4 = "mhlo.multiply"(%2, %3) :
360 //         (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
361 //     tensor_store %4, %arg3 : memref<2x2xf32>
362 //     "lmhlo.terminator"() : () -> ()
363 //   }) : () -> ()
364 //   return
365 // }
366 //
367 // Transformed fusion with LHLO ops.
368 // func @fusion(%arg0: memref<2x2xf32>,
369 //              %arg1: memref<2x2xf32>,
370 //              %arg2: memref<2x2xf32>,
371 //              %arg3: memref<2x2xf32>) {
372 //   "lmhlo.fusion"() ( {
373 //     %0 = alloc() : memref<2x2xf32>
374 //     "lmhlo.add"(%arg1, %arg2, %0) :
375 //         (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
376 //     "lmhlo.multiply"(%0, %arg0, %arg3) :
377 //         (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
378 //     "lmhlo.terminator"() : () -> ()
379 //   }) : () -> ()
380 //   return
381 // }
382 //
383 // FuncOp signature conversion example:
384 //
// func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
//   %0 = "mhlo.maximum"(%arg0, %arg1) :
//       (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
//   %1 = "mhlo.add"(%arg0, %0) :
//       (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
//   return %1 : tensor<4xf32>
// }
390 //
391 // Transformed function with an extra argument for the result. The types have
392 // been converted from tensor to memref.
393 //
394 // func @func_op(%arg0: memref<4xf32>,
395 //               %arg1: memref<4xf32>,
396 //               %arg2: memref<4xf32>) {
//   %0 = alloc() : memref<4xf32>
//
//   "lmhlo.maximum"(%arg0, %arg1, %0) :
400 //         (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
401 //   %1 = alloc() : memref<4xf32>
402 //   "lmhlo.add"(%arg0, %0, %1) :
403 //         (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
404 //   "lmhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
405 //   "lmhlo.terminator"() : () -> ()
406 // }
407 
408 struct HloLegalizeToLhlo : public HloLegalizeToLhloPassBase<HloLegalizeToLhlo> {
409   using HloLegalizeToLhloPassBase<HloLegalizeToLhlo>::HloLegalizeToLhloPassBase;
410 
getDependentDialectsmlir::mhlo::__anon6f36e01b0111::HloLegalizeToLhlo411   void getDependentDialects(DialectRegistry& registry) const override {
412     registry.insert<lmhlo::LmhloDialect, memref::MemRefDialect,
413                     shape::ShapeDialect>();
414   }
415 
416  public:
417   HloLegalizeToLhlo() = default;
HloLegalizeToLhlomlir::mhlo::__anon6f36e01b0111::HloLegalizeToLhlo418   HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {}
419 
runOnOperationmlir::mhlo::__anon6f36e01b0111::HloLegalizeToLhlo420   void runOnOperation() override {
421     auto& context = getContext();
422     OwningRewritePatternList patterns(&context);
423     ConversionTarget target(context);
424     target.addLegalDialect<lmhlo::LmhloDialect>();
425     target.addLegalDialect<StandardOpsDialect>();
426     target.addLegalDialect<memref::MemRefDialect>();
427     target.addLegalDialect<shape::ShapeDialect>();
428     target.addLegalDialect<tensor::TensorDialect>();
429     target.addIllegalDialect<mhlo::MhloDialect>();
430     // Declare tensor_store illegal. tensor_load may be used to reify output
431     // shape computation during dialect conversion and will be handled later.
432     target.addIllegalOp<mlir::memref::TensorStoreOp>();
433     // buffer_cast is illegal if it has uses.
434     // TODO(b/175670649) Make buffer_cast illegal.
435     target.addDynamicallyLegalOp<mlir::memref::BufferCastOp>(
436         [](auto op) { return op->use_empty(); });
437 
438     BufferizeTypeConverter converter;
439     auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); };
440     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
441       return converter.isSignatureLegal(op.getType()) &&
442              converter.isLegal(&op.getBody());
443     });
444     target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
445       return std::all_of(op.operand_type_begin(), op.operand_type_end(),
446                          isMemRefType) &&
447              std::all_of(op.result_type_begin(), op.result_type_end(),
448                          isMemRefType);
449     });
450     target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp op) {
451       return std::all_of(op.operand_type_begin(), op.operand_type_end(),
452                          isMemRefType);
453     });
454 
455     populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
456     populateFuncOpTypeConversionPattern(patterns, converter);
457     populateCallOpTypeConversionPattern(patterns, converter);
458     populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
459     populateReturnOpTypeConversionPattern(patterns, converter);
460     populateEliminateBufferizeMaterializationsPatterns(converter, patterns);
461 
462     populateShapeStructuralTypeConversionsAndLegality(converter, patterns,
463                                                       target);
464 
465     // TODO(b/175789537) Remove this pattern.
466     patterns.insert<HloToLhloTensorStoreOpLegacyConverter>(&context);
467 
468     if (failed(applyPartialConversion(getOperation(), target,
469                                       std::move(patterns))))
470       signalPassFailure();
471   }
472 };
473 }  // namespace
474 
475 // Simply lowers all mhlo ops to their lmhlo counterparts.
populateDynamicHLOToLHLOConversionPattern(MLIRContext * context,BufferizeTypeConverter * converter,OwningRewritePatternList * patterns)476 void populateDynamicHLOToLHLOConversionPattern(
477     MLIRContext* context, BufferizeTypeConverter* converter,
478     OwningRewritePatternList* patterns) {
479   // clang-format off
480   patterns->insert<HloToLhloOpConverter<mhlo::DynamicBroadcastInDimOp>,
481                    HloToLhloOpConverter<mhlo::DynamicGatherOp>,
482                    HloToLhloOpConverter<mhlo::DynamicIotaOp>,
483                    HloToLhloOpConverter<mhlo::DynamicPadOp>,
484                    HloToLhloOpConverter<mhlo::DynamicReshapeOp>,
485                    HloToLhloOpConverter<mhlo::RealDynamicSliceOp>
486   >(*converter, context);
487   // clang-format on
488 }
489 
populateHLOToLHLOConversionPattern(MLIRContext * context,BufferizeTypeConverter * converter,OwningRewritePatternList * patterns)490 void populateHLOToLHLOConversionPattern(MLIRContext* context,
491                                         BufferizeTypeConverter* converter,
492                                         OwningRewritePatternList* patterns) {
493   populateDynamicHLOToLHLOConversionPattern(context, converter, patterns);
494 
495   // clang-format off
496   patterns->insert<
497       HloToLhloCustomCallOpConverter,
498       HloToLhloDotGeneralOpConverter,
499       HloToLhloOpConverter<mhlo::AbsOp>,
500       HloToLhloOpConverter<mhlo::AddOp>,
501       HloToLhloOpConverter<mhlo::AndOp>,
502       HloToLhloOpConverter<mhlo::Atan2Op>,
503       HloToLhloOpConverter<mhlo::BroadcastInDimOp>,
504       HloToLhloOpConverter<mhlo::CeilOp>,
505       HloToLhloOpConverter<mhlo::CompareOp>,
506       HloToLhloOpConverter<mhlo::ComplexOp>,
507       HloToLhloOpConverter<mhlo::ConcatenateOp>,
508       HloToLhloOpConverter<mhlo::ConstOp>,
509       HloToLhloOpConverter<mhlo::ConvOp>,
510       HloToLhloOpConverter<mhlo::ConvertOp>,
511       HloToLhloOpConverter<mhlo::CopyOp>,
512       HloToLhloOpConverter<mhlo::CosOp>,
513       HloToLhloOpConverter<mhlo::DivOp>,
514       HloToLhloOpConverter<mhlo::DotOp>,
515       HloToLhloOpConverter<mhlo::ExpOp>,
516       HloToLhloOpConverter<mhlo::Expm1Op>,
517       HloToLhloOpConverter<mhlo::FloorOp>,
518       HloToLhloOpConverter<mhlo::GatherOp>,
519       HloToLhloOpConverter<mhlo::ImagOp>,
520       HloToLhloOpConverter<mhlo::IotaOp>,
521       HloToLhloOpConverter<mhlo::IsFiniteOp>,
522       HloToLhloOpConverter<mhlo::LogOp>,
523       HloToLhloOpConverter<mhlo::LogisticOp>,
524       HloToLhloOpConverter<mhlo::MaxOp>,
525       HloToLhloOpConverter<mhlo::MinOp>,
526       HloToLhloOpConverter<mhlo::MulOp>,
527       HloToLhloOpConverter<mhlo::NegOp>,
528       HloToLhloOpConverter<mhlo::NotOp>,
529       HloToLhloOpConverter<mhlo::OrOp>,
530       HloToLhloOpConverter<mhlo::PowOp>,
531       HloToLhloOpConverter<mhlo::RealOp>,
532       HloToLhloOpConverter<mhlo::RemOp>,
533       HloToLhloOpConverter<mhlo::RsqrtOp>,
534       HloToLhloOpConverter<mhlo::ReshapeOp>,
535       HloToLhloOpConverter<mhlo::SelectOp>,
536       HloToLhloOpConverter<mhlo::ShiftLeftOp>,
537       HloToLhloOpConverter<mhlo::ShiftRightArithmeticOp>,
538       HloToLhloOpConverter<mhlo::ShiftRightLogicalOp>,
539       HloToLhloOpConverter<mhlo::SignOp>,
540       HloToLhloOpConverter<mhlo::SinOp>,
541       HloToLhloOpConverter<mhlo::SliceOp>,
542       HloToLhloOpConverter<mhlo::SqrtOp>,
543       HloToLhloOpConverter<mhlo::SubOp>,
544       HloToLhloOpConverter<mhlo::TanhOp>,
545       HloToLhloOpConverter<mhlo::TransposeOp>,
546       HloToLhloOpConverter<mhlo::XorOp>,
547       HloToLhloReduceOpConverter,
548       HloToLhloReturnOpConverter
549   >(*converter, context);
550   // clang-format on
551 }
552 
createLegalizeToLhloPass()553 std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass() {
554   return std::make_unique<HloLegalizeToLhlo>();
555 }
556 
557 }  // namespace mhlo
558 }  // namespace mlir
559