• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for lowering HLO dialect to LHLO dialect.
17 
18 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
19 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
20 #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
21 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
22 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
23 #include "mlir/Dialect/Shape/IR/Shape.h"
24 #include "mlir/Dialect/Shape/Transforms/Passes.h"
25 #include "mlir/Dialect/StandardOps/IR/Ops.h"
26 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
27 #include "mlir/Dialect/Tensor/IR/Tensor.h"
28 #include "mlir/IR/AffineMap.h"
29 #include "mlir/IR/Attributes.h"
30 #include "mlir/IR/BlockAndValueMapping.h"
31 #include "mlir/IR/Builders.h"
32 #include "mlir/IR/BuiltinOps.h"
33 #include "mlir/IR/BuiltinTypes.h"
34 #include "mlir/IR/Location.h"
35 #include "mlir/IR/MLIRContext.h"
36 #include "mlir/IR/Operation.h"
37 #include "mlir/IR/PatternMatch.h"
38 #include "mlir/Pass/Pass.h"
39 #include "mlir/Transforms/Bufferize.h"
40 #include "mlir/Transforms/DialectConversion.h"
41 
42 namespace mlir {
43 namespace mhlo {
44 namespace {
45 
46 template <typename T>
47 using BaseOpConversion = OpConversionPattern<T>;
48 
InsertDynamicAllocAndDealloc(Location loc,Value result,Value shape_operand,ConversionPatternRewriter * rewriter)49 Value InsertDynamicAllocAndDealloc(Location loc, Value result,
50                                    Value shape_operand,
51                                    ConversionPatternRewriter* rewriter) {
52   auto result_type = result.getType().dyn_cast<RankedTensorType>();
53   if (!result_type) {
54     result.getDefiningOp()->emitOpError()
55         << "tensor to buffer conversion expects ranked results";
56   }
57   auto memref_type =
58       MemRefType::get(result_type.getShape(), result_type.getElementType());
59 
60   // Extract the required element out of the vector.
61   SmallVector<Value, 4> dynamic_operands;
62   for (auto shape_element : llvm::enumerate(result_type.getShape())) {
63     if (shape_element.value() != ShapedType::kDynamicSize) continue;
64     Value index = rewriter->create<ConstantIndexOp>(loc, shape_element.index());
65     Value alloc_operand =
66         rewriter->create<tensor::ExtractOp>(loc, shape_operand, index);
67     if (!alloc_operand.getType().isIndex()) {
68       alloc_operand = rewriter->create<IndexCastOp>(loc, alloc_operand,
69                                                     rewriter->getIndexType());
70     }
71     dynamic_operands.push_back(alloc_operand);
72   }
73 
74   return rewriter->create<AllocOp>(loc, memref_type, dynamic_operands);
75 }
76 
InsertAlloc(Location loc,OpResult result,ConversionPatternRewriter * rewriter)77 Value InsertAlloc(Location loc, OpResult result,
78                   ConversionPatternRewriter* rewriter) {
79   auto result_type = result.getType().dyn_cast<RankedTensorType>();
80   if (!result_type || !result_type.hasStaticShape()) {
81     result.getDefiningOp()->emitOpError()
82         << "tensor to buffer conversion expects statically shaped results";
83   }
84   auto memref_type =
85       MemRefType::get(result_type.getShape(), result_type.getElementType());
86   OpBuilder::InsertionGuard guard(*rewriter);
87   rewriter->setInsertionPoint(result.getDefiningOp());
88   auto alloc = rewriter->create<AllocOp>(loc, memref_type);
89   return alloc;
90 }
91 
92 /// Converts the results of the operation `op` to memref types and append them
93 /// to the `results` vector.
ConvertResults(Operation * op,SmallVectorImpl<Value> & results,ConversionPatternRewriter & rewriter)94 LogicalResult ConvertResults(Operation* op, SmallVectorImpl<Value>& results,
95                              ConversionPatternRewriter& rewriter) {
96   for (auto result : llvm::enumerate(op->getResults())) {
97     RankedTensorType resultType =
98         result.value().getType().dyn_cast<RankedTensorType>();
99     if (!resultType) return failure();
100 
101     if (resultType.hasStaticShape()) {
102       results.push_back(InsertAlloc(op->getLoc(), result.value(), &rewriter));
103       continue;
104     }
105     auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
106     if (!shape_type_op) return failure();
107 
108     SmallVector<Value, 1> results_shape;
109     auto status = shape_type_op.reifyReturnTypeShapes(rewriter, results_shape);
110     if (failed(status)) return failure();
111     results.push_back(
112         InsertDynamicAllocAndDealloc(op->getLoc(), result.value(),
113                                      results_shape[result.index()], &rewriter));
114   }
115   return success();
116 }
117 
118 template <typename HloOpTy>
119 class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
120  public:
121   using BaseOpConversion<HloOpTy>::BaseOpConversion;
matchAndRewrite(HloOpTy hloOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const122   LogicalResult matchAndRewrite(
123       HloOpTy hloOp, ArrayRef<Value> operands,
124       ConversionPatternRewriter& rewriter) const final {
125     Operation* op = hloOp.getOperation();
126     SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
127     if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
128     rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
129                                                 buffer_args, op->getAttrs());
130     rewriter.replaceOp(
131         op, llvm::makeArrayRef(buffer_args).drop_front(operands.size()));
132     return success();
133   }
134 };
135 
136 // This specialization exists so that LMHLO's Dot can be given a specific set of
137 // dimension numbers, when lowering from MHLO's Dot, which does not have
138 // dimension numbers (it uses DotGeneral for this generalized notion of dot
139 // products). When these two dialects are in sync with respect to the
140 // Dot/DotGeneral issue, this specialization should be deleted.
141 template <>
142 class HloToLhloOpConverter<mhlo::DotOp> : public BaseOpConversion<mhlo::DotOp> {
143  public:
144   using BaseOpConversion<mhlo::DotOp>::BaseOpConversion;
matchAndRewrite(mhlo::DotOp hloOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const145   LogicalResult matchAndRewrite(
146       mhlo::DotOp hloOp, ArrayRef<Value> operands,
147       ConversionPatternRewriter& rewriter) const final {
148     Operation* op = hloOp.getOperation();
149     SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
150     if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
151 
152     // TODO(silvasean): Move this helper to MLIR core.
153     auto make_elements_attr = [&rewriter](ArrayRef<int64_t> integers) {
154       auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
155                                         rewriter.getIntegerType(64));
156       return DenseIntElementsAttr::get(type, integers);
157     };
158     auto dotOp = rewriter.create<lmhlo::DotOp>(op->getLoc(), llvm::None,
159                                                buffer_args, op->getAttrs());
160     // MHLO's Dot uses rank-2 operands, of the form ([N, M], [M, O]) -> [N, O].
161     auto dimension_numbers = mhlo::DotDimensionNumbers::get(
162         make_elements_attr({}), make_elements_attr({}), make_elements_attr({1}),
163         make_elements_attr({0}), rewriter.getContext());
164     dotOp.dot_dimension_numbersAttr(dimension_numbers);
165     rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
166     return success();
167   }
168 };
169 
170 struct HloToLhloCustomCallOpConverter
171     : public BaseOpConversion<mhlo::CustomCallOp> {
172  public:
173   using BaseOpConversion<mhlo::CustomCallOp>::BaseOpConversion;
174 
matchAndRewritemlir::mhlo::__anon03bb23560111::HloToLhloCustomCallOpConverter175   LogicalResult matchAndRewrite(
176       mhlo::CustomCallOp hloOp, ArrayRef<Value> operands,
177       ConversionPatternRewriter& rewriter) const final {
178     Operation* op = hloOp.getOperation();
179     SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
180     if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
181 
182     auto lhloOp = rewriter.create<lmhlo::CustomCallOp>(
183         op->getLoc(), llvm::None, buffer_args, op->getAttrs());
184     // Setup AttrSizedOperandSegments attribute to indicate number of operands
185     // for args and outputs.
186     const int32_t segments[2] = {static_cast<int32_t>(operands.size()),
187                                  static_cast<int32_t>(op->getNumResults())};
188     lhloOp->setAttr(lhloOp.getOperandSegmentSizeAttr(),
189                     rewriter.getI32VectorAttr(segments));
190 
191     rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
192     return success();
193   }
194 };
195 
196 class HloToLhloReshapeUnrankedConverter
197     : public BaseOpConversion<mhlo::ReshapeOp> {
198  public:
199   using BaseOpConversion<mhlo::ReshapeOp>::BaseOpConversion;
200 
matchAndRewrite(mhlo::ReshapeOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const201   LogicalResult matchAndRewrite(
202       mhlo::ReshapeOp op, ArrayRef<Value> operands,
203       ConversionPatternRewriter& rewriter) const final {
204     mhlo::ReshapeOp::Adaptor adaptor(operands);
205     auto unranked_operand_type =
206         adaptor.operand().getType().dyn_cast<UnrankedMemRefType>();
207     if (unranked_operand_type == nullptr) return failure();
208 
209     auto result_type = op.getType().cast<RankedTensorType>();
210     rewriter.replaceOpWithNewOp<MemRefCastOp>(
211         op, adaptor.operand(),
212         MemRefType::get(result_type.getShape(), result_type.getElementType()));
213     return success();
214   }
215 };
216 
217 // TODO(pifon): Consider inserting lhlo.copy as in
218 // HloToLhloDynamicBroadcastInDimOpConverter.
219 class HloToLhloDynamicReshapeConverter
220     : public BaseOpConversion<mhlo::DynamicReshapeOp> {
221  public:
222   using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
223 
matchAndRewrite(mhlo::DynamicReshapeOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const224   LogicalResult matchAndRewrite(
225       mhlo::DynamicReshapeOp op, ArrayRef<Value> operands,
226       ConversionPatternRewriter& rewriter) const final {
227     Type result_type;
228     if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>()) {
229       result_type =
230           MemRefType::get(ranked_type.getShape(), ranked_type.getElementType());
231     } else if (auto unranked_type =
232                    op.getType().dyn_cast<UnrankedTensorType>()) {
233       result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0);
234     } else {
235       return failure();
236     }
237     mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
238     rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
239         op, result_type, adaptor.operand(), adaptor.output_shape());
240     return success();
241   }
242 };
243 
244 // TODO(b/175670649) Fix this to no longer access original tensor operands.
245 class HloToLhloDynamicBroadcastInDimOpConverter
246     : public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
247  public:
HloToLhloDynamicBroadcastInDimOpConverter(TypeConverter & converter,MLIRContext * ctx,bool insert_copy=true)248   HloToLhloDynamicBroadcastInDimOpConverter(TypeConverter& converter,
249                                             MLIRContext* ctx,
250                                             bool insert_copy = true)
251       : BaseOpConversion<mhlo::DynamicBroadcastInDimOp>(converter, ctx),
252         insert_copy_(insert_copy) {}
253 
matchAndRewrite(mhlo::DynamicBroadcastInDimOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const254   LogicalResult matchAndRewrite(
255       mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
256       ConversionPatternRewriter& rewriter) const final {
257     if (!op.getType().isa<RankedTensorType>()) return failure();
258     Value result = InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
259 
260     if (insert_copy_) {
261       auto loc = op.getLoc();
262       Value result_buffer = InsertDynamicAllocAndDealloc(
263           loc, op.getResult(), op.output_dimensions(), &rewriter);
264 
265       rewriter.create<lmhlo::CopyOp>(loc, result, result_buffer);
266       result = result_buffer;
267     }
268     rewriter.replaceOp(op, {result});
269     return success();
270   }
271 
272  private:
273   // Inserts dynamic memref to change the layout of the memref to put 0-stride
274   // and size of the target dimension if size-1 dimension expansion is
275   // necessary.
InsertDynamicMemrefCastOp(mhlo::DynamicBroadcastInDimOp op,Value operand,OpBuilder * b) const276   MemRefReinterpretCastOp InsertDynamicMemrefCastOp(
277       mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
278     auto loc = op.getLoc();
279     auto operand_type = operand.getType().cast<MemRefType>();
280     auto operand_shape = operand_type.getShape();
281     auto operand_rank = operand_type.getRank();
282 
283     auto result_type = op.getType().cast<RankedTensorType>();
284     auto result_rank = result_type.getRank();
285 
286     Value zero = b->create<ConstantIndexOp>(loc, 0);
287     Value one = b->create<ConstantIndexOp>(loc, 1);
288 
289     // Compute a reversed scan product. Compute the stride for the dimensions so
290     // far, working from minor to major dimensions. Additionally, save the
291     // operand shape Values to use in the next loop.
292     SmallVector<Value, 2> operand_strides(operand_rank, one);
293     SmallVector<Value, 2> operand_sizes(operand_rank, one);
294     Value stride_so_far = one;
295     for (int i = operand_rank - 1; i >= 0; --i) {
296       Value operand_dim_size =
297           ShapedType::isDynamic(operand_shape[i])
298               ? b->create<DimOp>(loc, operand, i).getResult()
299               : b->create<ConstantIndexOp>(loc, operand_shape[i]).getResult();
300       operand_sizes[i] = operand_dim_size;
301 
302       operand_strides[i] = stride_so_far;
303       if (i > 0) {
304         stride_so_far = b->create<MulIOp>(loc, stride_so_far, operand_dim_size);
305       }
306     }
307 
308     SmallVector<OpFoldResult, 2> sizes, strides;
309     sizes.reserve(result_rank);
310     strides.reserve(result_rank);
311 
312     DenseMap<int, int> output_to_input_dim;
313     for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
314       output_to_input_dim[dim.value().getSExtValue()] = dim.index();
315     }
316     for (int i = 0; i < result_rank; ++i) {
317       Value i_val = b->create<ConstantIndexOp>(loc, i);
318       Value result_dim_size =
319           b->create<tensor::ExtractOp>(loc, op.output_dimensions(), i_val);
320       if (!result_dim_size.getType().isIndex()) {
321         result_dim_size =
322             b->create<IndexCastOp>(loc, result_dim_size, b->getIndexType());
323       }
324       sizes.push_back(result_dim_size);
325 
326       auto it = output_to_input_dim.find(i);
327       // If the rank of the output is greater than the rank of the input, i.e.
328       // there was no output dimension in the inverse broadcast_dimensions map
329       // we also set stride to 0 to emulate padding of the shape with 1s and the
330       // corresponding expansion.
331       if (it == output_to_input_dim.end()) {
332         strides.push_back(zero);
333         continue;
334       }
335 
336       // There can be two cases:
337       // 1) Operand dim == result dim => expansion is not needed
338       //    => stride flattened buffer stride
339       // 2) Operand dim < result dim => expansion is needed => stride := 0.
340       int dim = it->second;
341       Value is_expansion = b->create<CmpIOp>(
342           loc, CmpIPredicate::slt, operand_sizes[dim], result_dim_size);
343       Value select = b->create<mlir::SelectOp>(loc, is_expansion, zero,
344                                                operand_strides[dim]);
345       strides.push_back(select);
346     }
347 
348     // Type-erased memref type with static rank, dynamic sizes and strides.
349     SmallVector<int64_t, 2> dynamic_layout(result_rank,
350                                            MemRefType::kDynamicStrideOrOffset);
351     SmallVector<int64_t, 2> dynamic_shape(result_rank,
352                                           MemRefType::kDynamicSize);
353     auto type_erased_memref_type = MemRefType::get(
354         dynamic_shape, operand_type.getElementType(),
355         makeStridedLinearLayoutMap(dynamic_layout,
356                                    /*offset=*/0, b->getContext()));
357 
358     auto transformed_operand = b->create<MemRefReinterpretCastOp>(
359         loc, type_erased_memref_type, operand,
360         /*offset=*/b->getI64IntegerAttr(0), sizes, strides);
361     return transformed_operand;
362   }
363 
364   // Keep the copy semantics and allocate a buffer for the result of the memref
365   // cast.
366   bool insert_copy_;
367 };
368 
369 struct HloToLhloDotGeneralOpConverter
370     : public BaseOpConversion<mhlo::DotGeneralOp> {
371   using BaseOpConversion<mhlo::DotGeneralOp>::BaseOpConversion;
matchAndRewritemlir::mhlo::__anon03bb23560111::HloToLhloDotGeneralOpConverter372   LogicalResult matchAndRewrite(
373       mhlo::DotGeneralOp dotGeneralOp, ArrayRef<Value> operands,
374       ConversionPatternRewriter& rewriter) const final {
375     Operation* op = dotGeneralOp.getOperation();
376 
377     if (op->getResults().empty()) return failure();
378     OpResult result = op->getResults()[0];
379     RankedTensorType resultType = result.getType().dyn_cast<RankedTensorType>();
380     if (!resultType) return failure();
381 
382     // The third buffer argument will be filled with what used to be the return
383     // type of the DotGeneral.
384     if (operands.size() != 2) return failure();
385     std::array<Value, 3> bufferArgs = {operands[0], operands[1], {}};
386 
387     if (resultType.hasStaticShape()) {
388       bufferArgs[2] = InsertAlloc(op->getLoc(), result, &rewriter);
389     } else {
390       SmallVector<Value, 1> results_shape;
391       auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
392       if (failed(shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
393         return failure();
394 
395       bufferArgs[2] = InsertDynamicAllocAndDealloc(
396           op->getLoc(), result, results_shape.front(), &rewriter);
397     }
398 
399     rewriter.create<lmhlo::DotOp>(op->getLoc(), llvm::None, bufferArgs,
400                                   op->getAttrs());
401     rewriter.replaceOp(op, bufferArgs[2]);
402     return success();
403   }
404 };
405 
406 struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
407  public:
408   using BaseOpConversion<mhlo::ReduceOp>::BaseOpConversion;
409 
matchAndRewritemlir::mhlo::__anon03bb23560111::HloToLhloReduceOpConverter410   LogicalResult matchAndRewrite(
411       mhlo::ReduceOp op, ArrayRef<Value> operands,
412       ConversionPatternRewriter& rewriter) const final {
413     auto loc = op.getLoc();
414     // TODO(b/137624192) Implement variadic reduce.
415     if (op.getNumResults() != 1) return failure();
416     if (!llvm::hasSingleElement(op.body())) {
417       return op.emitOpError()
418              << "tensor to buffer conversion expects a single block "
419                 "in the region containing the operation";
420     }
421     const auto& original_results = op.getResults();
422     SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
423     for (auto result : original_results) {
424       buffer_args.push_back(InsertAlloc(loc, result, &rewriter));
425     }
426     auto new_op = rewriter.create<lmhlo::ReduceOp>(loc, llvm::None, buffer_args,
427                                                    op.getAttrs());
428 
429     // Copy over the operations inside the region.
430     rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end());
431 
432     // Convert the region signature to memref and add extra result.
433     auto& entry_block = new_op.body().front();
434     TypeConverter::SignatureConversion sig_conversion(
435         entry_block.getNumArguments() + 1);
436     for (auto arg : entry_block.getArguments()) {
437       auto old_type = arg.getType().cast<TensorType>();
438       auto new_type =
439           MemRefType::get(old_type.getShape(), old_type.getElementType());
440       sig_conversion.addInputs(arg.getArgNumber(), new_type);
441     }
442     auto return_op = cast<mhlo::ReturnOp>(entry_block.getTerminator());
443     auto result_type = return_op.results().front().getType().cast<TensorType>();
444     sig_conversion.addInputs({MemRefType::get(result_type.getShape(),
445                                               result_type.getElementType())});
446     rewriter.applySignatureConversion(&new_op.body(), sig_conversion);
447 
448     rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
449 
450     return success();
451   }
452 };
453 
454 // Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator.
455 struct HloToLhloReturnOpConverter : public BaseOpConversion<mhlo::ReturnOp> {
456  public:
457   using BaseOpConversion<mhlo::ReturnOp>::BaseOpConversion;
458 
matchAndRewritemlir::mhlo::__anon03bb23560111::HloToLhloReturnOpConverter459   LogicalResult matchAndRewrite(
460       mhlo::ReturnOp op, ArrayRef<Value> operands,
461       ConversionPatternRewriter& rewriter) const final {
462     auto loc = op.getLoc();
463     auto& entry_block = op->getParentRegion()->front();
464     auto num_arguments = entry_block.getNumArguments();
465     if (operands.size() > num_arguments) {
466       return op.emitError(
467           "The number of operands that need Copy operations is more "
468           "than the number of target function arguments.");
469     }
470 
471     // The index of the first output block argument.
472     auto dest_arg_idx = num_arguments - operands.size();
473 
474     // Create a lmhlo.copy for each operand of mhlo.return.
475     for (Value operand : operands) {
476       rewriter.create<lmhlo::CopyOp>(loc, operand,
477                                      entry_block.getArgument(dest_arg_idx));
478       ++dest_arg_idx;
479     }
480     rewriter.replaceOpWithNewOp<lmhlo::TerminatorOp>(op);
481     return success();
482   }
483 };
484 
485 // TODO(b/175789537) Remove this pattern.
486 class HloToLhloTensorStoreOpLegacyConverter
487     : public BaseOpConversion<mlir::TensorStoreOp> {
488  public:
489   using BaseOpConversion<mlir::TensorStoreOp>::BaseOpConversion;
490 
matchAndRewrite(mlir::TensorStoreOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const491   LogicalResult matchAndRewrite(
492       mlir::TensorStoreOp op, ArrayRef<Value> operands,
493       ConversionPatternRewriter& rewriter) const final {
494     rewriter.replaceOpWithNewOp<lmhlo::CopyOp>(op, llvm::None, operands.front(),
495                                                operands.back());
496     return success();
497   }
498 };
499 
500 // Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
501 // buffers if necessary.
502 //
503 // Example fusion with HLO ops.
504 //
505 // func @fusion(%arg0: memref<2x2xf32>,
506 //              %arg1: memref<2x2xf32>,
507 //              %arg2: memref<2x2xf32>,
508 //              %arg3: memref<2x2xf32>) {
509 //   "lmhlo.fusion"() ({
510 //     %0 = tensor_load %arg1 : memref<2x2xf32>
511 //     %1 = tensor_load %arg2 : memref<2x2xf32>
512 //     %2 = "mhlo.add"(%0, %1) :
513 //         (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
514 //     %3 = tensor_load %arg0 : memref<2x2xf32>
515 //     %4 = "mhlo.multiply"(%2, %3) :
516 //         (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
517 //     tensor_store %4, %arg3 : memref<2x2xf32>
518 //     "lmhlo.terminator"() : () -> ()
519 //   }) : () -> ()
520 //   return
521 // }
522 //
523 // Transformed fusion with LHLO ops.
524 // func @fusion(%arg0: memref<2x2xf32>,
525 //              %arg1: memref<2x2xf32>,
526 //              %arg2: memref<2x2xf32>,
527 //              %arg3: memref<2x2xf32>) {
528 //   "lmhlo.fusion"() ( {
529 //     %0 = alloc() : memref<2x2xf32>
530 //     "lmhlo.add"(%arg1, %arg2, %0) :
531 //         (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
532 //     "lmhlo.multiply"(%0, %arg0, %arg3) :
533 //         (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
534 //     "lmhlo.terminator"() : () -> ()
535 //   }) : () -> ()
536 //   return
537 // }
538 //
539 // FuncOp signature conversion example:
540 //
// func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
//   %0 = "mhlo.maximum"(%arg0, %arg1) :
//       (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
//   %1 = "mhlo.add"(%arg0, %0) :
//       (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
//   return %1 : tensor<4xf32>
// }
546 //
547 // Transformed function with an extra argument for the result. The types have
548 // been converted from tensor to memref.
549 //
550 // func @func_op(%arg0: memref<4xf32>,
551 //               %arg1: memref<4xf32>,
552 //               %arg2: memref<4xf32>) {
//   %0 = alloc() : memref<4xf32>
//
//   "lmhlo.maximum"(%arg0, %arg1, %0) :
556 //         (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
557 //   %1 = alloc() : memref<4xf32>
558 //   "lmhlo.add"(%arg0, %0, %1) :
559 //         (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
560 //   "lmhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
561 //   "lmhlo.terminator"() : () -> ()
562 // }
563 
564 struct HloLegalizeToLhlo
565     : public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
getDependentDialectsmlir::mhlo::__anon03bb23560111::HloLegalizeToLhlo566   void getDependentDialects(DialectRegistry& registry) const override {
567     registry.insert<lmhlo::LmhloDialect>();
568   }
569 
570  public:
571   HloLegalizeToLhlo() = default;
HloLegalizeToLhlomlir::mhlo::__anon03bb23560111::HloLegalizeToLhlo572   HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {}
573 
runOnOperationmlir::mhlo::__anon03bb23560111::HloLegalizeToLhlo574   void runOnOperation() override {
575     OwningRewritePatternList patterns;
576     auto& context = getContext();
577     ConversionTarget target(context);
578     target.addLegalDialect<lmhlo::LmhloDialect>();
579     target.addLegalDialect<StandardOpsDialect>();
580     target.addLegalDialect<tensor::TensorDialect>();
581     target.addIllegalDialect<mhlo::MhloDialect>();
582     // Declare tensor_load and tensor_store illegal.
583     target.addIllegalOp<mlir::TensorLoadOp, mlir::TensorStoreOp>();
584     // tensor_to_memref is illegal if it has uses.
585     // TODO(b/175670649) Make tensor_to_memref illegal.
586     target.addDynamicallyLegalOp<mlir::TensorToMemrefOp>(
587         [](auto op) { return op->use_empty(); });
588 
589     BufferizeTypeConverter converter;
590     auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); };
591     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
592       auto inputs = op.getType().getInputs();
593       return llvm::all_of(inputs, isMemRefType) &&
594              converter.isLegal(&op.getBody());
595     });
596     target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
597       return std::all_of(op.operand_type_begin(), op.operand_type_end(),
598                          isMemRefType) &&
599              std::all_of(op.result_type_begin(), op.result_type_end(),
600                          isMemRefType);
601     });
602     target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp op) {
603       return std::all_of(op.operand_type_begin(), op.operand_type_end(),
604                          isMemRefType);
605     });
606 
607     populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
608     populateFuncOpTypeConversionPattern(patterns, &context, converter);
609     populateCallOpTypeConversionPattern(patterns, &context, converter);
610     populateBranchOpInterfaceAndReturnOpTypeConversionPattern(
611         patterns, &context, converter);
612     populateEliminateBufferizeMaterializationsPatterns(&context, converter,
613                                                        patterns);
614 
615     populateShapeStructuralTypeConversionsAndLegality(&context, converter,
616                                                       patterns, target);
617 
618     // TODO(b/175789537) Remove this pattern.
619     patterns.insert<HloToLhloTensorStoreOpLegacyConverter>(&context);
620 
621     if (failed(applyPartialConversion(getOperation(), target,
622                                       std::move(patterns))))
623       signalPassFailure();
624   }
625 };
626 }  // namespace
627 
populateDynamicHLOToLHLOConversionPattern(MLIRContext * context,BufferizeTypeConverter * converter,OwningRewritePatternList * patterns,bool insert_copy)628 void populateDynamicHLOToLHLOConversionPattern(
629     MLIRContext* context, BufferizeTypeConverter* converter,
630     OwningRewritePatternList* patterns, bool insert_copy) {
631   patterns->insert<HloToLhloDynamicBroadcastInDimOpConverter>(
632       *converter, context, insert_copy);
633   patterns->insert<HloToLhloDynamicReshapeConverter,
634                    HloToLhloReshapeUnrankedConverter>(*converter, context);
635 }
636 
populateHLOToLHLOConversionPattern(MLIRContext * context,BufferizeTypeConverter * converter,OwningRewritePatternList * patterns)637 void populateHLOToLHLOConversionPattern(MLIRContext* context,
638                                         BufferizeTypeConverter* converter,
639                                         OwningRewritePatternList* patterns) {
640   populateDynamicHLOToLHLOConversionPattern(context, converter, patterns);
641   // clang-format off
642   patterns->insert<
643       HloToLhloCustomCallOpConverter,
644       HloToLhloDotGeneralOpConverter,
645       HloToLhloOpConverter<mhlo::AbsOp>,
646       HloToLhloOpConverter<mhlo::AddOp>,
647       HloToLhloOpConverter<mhlo::AndOp>,
648       HloToLhloOpConverter<mhlo::Atan2Op>,
649       HloToLhloOpConverter<mhlo::BroadcastInDimOp>,
650       HloToLhloOpConverter<mhlo::CeilOp>,
651       HloToLhloOpConverter<mhlo::CompareOp>,
652       HloToLhloOpConverter<mhlo::ComplexOp>,
653       HloToLhloOpConverter<mhlo::ConstOp>,
654       HloToLhloOpConverter<mhlo::ConvOp>,
655       HloToLhloOpConverter<mhlo::ConvertOp>,
656       HloToLhloOpConverter<mhlo::CopyOp>,
657       HloToLhloOpConverter<mhlo::CosOp>,
658       HloToLhloOpConverter<mhlo::DivOp>,
659       HloToLhloOpConverter<mhlo::DotOp>,
660       HloToLhloOpConverter<mhlo::ExpOp>,
661       HloToLhloOpConverter<mhlo::Expm1Op>,
662       HloToLhloOpConverter<mhlo::FloorOp>,
663       HloToLhloOpConverter<mhlo::GatherOp>,
664       HloToLhloOpConverter<mhlo::ImagOp>,
665       HloToLhloOpConverter<mhlo::IotaOp>,
666       HloToLhloOpConverter<mhlo::IsFiniteOp>,
667       HloToLhloOpConverter<mhlo::LogOp>,
668       HloToLhloOpConverter<mhlo::MaxOp>,
669       HloToLhloOpConverter<mhlo::MinOp>,
670       HloToLhloOpConverter<mhlo::MulOp>,
671       HloToLhloOpConverter<mhlo::NegOp>,
672       HloToLhloOpConverter<mhlo::NotOp>,
673       HloToLhloOpConverter<mhlo::OrOp>,
674       HloToLhloOpConverter<mhlo::PowOp>,
675       HloToLhloOpConverter<mhlo::RealOp>,
676       HloToLhloOpConverter<mhlo::RemOp>,
677       HloToLhloOpConverter<mhlo::RsqrtOp>,
678       HloToLhloOpConverter<mhlo::ReshapeOp>,
679       HloToLhloOpConverter<mhlo::SelectOp>,
680       HloToLhloOpConverter<mhlo::ShiftLeftOp>,
681       HloToLhloOpConverter<mhlo::ShiftRightArithmeticOp>,
682       HloToLhloOpConverter<mhlo::ShiftRightLogicalOp>,
683       HloToLhloOpConverter<mhlo::SignOp>,
684       HloToLhloOpConverter<mhlo::SinOp>,
685       HloToLhloOpConverter<mhlo::SliceOp>,
686       HloToLhloOpConverter<mhlo::SqrtOp>,
687       HloToLhloOpConverter<mhlo::SubOp>,
688       HloToLhloOpConverter<mhlo::TanhOp>,
689       HloToLhloOpConverter<mhlo::TransposeOp>,
690       HloToLhloOpConverter<mhlo::XorOp>,
691       HloToLhloReduceOpConverter,
692       HloToLhloReturnOpConverter
693   >(*converter, context);
694   // clang-format on
695 }
696 
createLegalizeToLhloPass()697 std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass() {
698   return std::make_unique<HloLegalizeToLhlo>();
699 }
700 
701 }  // namespace mhlo
702 }  // namespace mlir
703