/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering HLO dialect to LHLO dialect.

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Bufferize.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace mhlo {
namespace {
template <typename T>
using BaseOpConversion = OpConversionPattern<T>;

Value InsertDynamicAllocAndDealloc(Location loc, Value result,
                                   Value shape_operand,
                                   ConversionPatternRewriter* rewriter) {
  auto result_type = result.getType().dyn_cast<RankedTensorType>();
  if (!result_type) {
    result.getDefiningOp()->emitOpError()
        << "tensor to buffer conversion expects ranked results";
  }
  auto memref_type =
      MemRefType::get(result_type.getShape(), result_type.getElementType());
  // Collect the sizes of all dynamic result dimensions from `shape_operand`.
  SmallVector<Value, 4> dynamic_operands;
  for (auto shape_element : llvm::enumerate(result_type.getShape())) {
    if (shape_element.value() != ShapedType::kDynamicSize) continue;
    Value index = rewriter->create<ConstantIndexOp>(loc, shape_element.index());
    Value alloc_operand =
        rewriter->create<tensor::ExtractOp>(loc, shape_operand, index);
    if (!alloc_operand.getType().isIndex()) {
      alloc_operand = rewriter->create<IndexCastOp>(loc, alloc_operand,
                                                    rewriter->getIndexType());
    }
    dynamic_operands.push_back(alloc_operand);
  }

  return rewriter->create<AllocOp>(loc, memref_type, dynamic_operands);
}
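
// As an illustrative sketch (SSA names are hypothetical), allocating a buffer
// for a `tensor<?x4xf32>` result with shape operand %shape of type
// tensor<2xindex> emits roughly:
//
//   %c0 = constant 0 : index
//   %size = tensor.extract %shape[%c0] : tensor<2xindex>
//   %buf = alloc(%size) : memref<?x4xf32>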

Value InsertAlloc(Location loc, OpResult result,
                  ConversionPatternRewriter* rewriter) {
  auto result_type = result.getType().dyn_cast<RankedTensorType>();
  if (!result_type || !result_type.hasStaticShape()) {
    result.getDefiningOp()->emitOpError()
        << "tensor to buffer conversion expects statically shaped results";
  }
  auto memref_type =
      MemRefType::get(result_type.getShape(), result_type.getElementType());
  OpBuilder::InsertionGuard guard(*rewriter);
  rewriter->setInsertionPoint(result.getDefiningOp());
  auto alloc = rewriter->create<AllocOp>(loc, memref_type);
  return alloc;
}
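
// Note: the insertion guard above restores the rewriter's previous insertion
// point on scope exit; the alloc itself is emitted directly before the
// defining op, so the buffer dominates all uses of the original result.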

/// Converts the results of the operation `op` to memref types and appends them
/// to the `results` vector.
LogicalResult ConvertResults(Operation* op, SmallVectorImpl<Value>& results,
                             ConversionPatternRewriter& rewriter) {
  for (auto result : llvm::enumerate(op->getResults())) {
    RankedTensorType resultType =
        result.value().getType().dyn_cast<RankedTensorType>();
    if (!resultType) return failure();

    if (resultType.hasStaticShape()) {
      results.push_back(InsertAlloc(op->getLoc(), result.value(), &rewriter));
      continue;
    }
    auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
    if (!shape_type_op) return failure();

    SmallVector<Value, 1> results_shape;
    auto status = shape_type_op.reifyReturnTypeShapes(rewriter, results_shape);
    if (failed(status)) return failure();
    results.push_back(
        InsertDynamicAllocAndDealloc(op->getLoc(), result.value(),
                                     results_shape[result.index()], &rewriter));
  }
  return success();
}

template <typename HloOpTy>
class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
 public:
  using BaseOpConversion<HloOpTy>::BaseOpConversion;
  LogicalResult matchAndRewrite(
      HloOpTy hloOp, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    Operation* op = hloOp.getOperation();
    SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
    if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
    rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
                                                buffer_args, op->getAttrs());
    rewriter.replaceOp(
        op, llvm::makeArrayRef(buffer_args).drop_front(operands.size()));
    return success();
  }
};
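
// For example (a sketch; buffer names are illustrative), this generic pattern
// rewrites
//
//   %0 = "mhlo.add"(%a, %b)
//       : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
//
// into approximately
//
//   %buf = alloc() : memref<2x2xf32>
//   "lmhlo.add"(%a_buf, %b_buf, %buf)
//       : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
//
// and replaces all uses of %0 with %buf.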

// This specialization exists so that LMHLO's Dot can be given a specific set of
// dimension numbers, when lowering from MHLO's Dot, which does not have
// dimension numbers (it uses DotGeneral for this generalized notion of dot
// products). When these two dialects are in sync with respect to the
// Dot/DotGeneral issue, this specialization should be deleted.
template <>
class HloToLhloOpConverter<mhlo::DotOp> : public BaseOpConversion<mhlo::DotOp> {
 public:
  using BaseOpConversion<mhlo::DotOp>::BaseOpConversion;
  LogicalResult matchAndRewrite(
      mhlo::DotOp hloOp, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    Operation* op = hloOp.getOperation();
    SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
    if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();

    // TODO(silvasean): Move this helper to MLIR core.
    auto make_elements_attr = [&rewriter](ArrayRef<int64_t> integers) {
      auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
                                        rewriter.getIntegerType(64));
      return DenseIntElementsAttr::get(type, integers);
    };
    auto dotOp = rewriter.create<lmhlo::DotOp>(op->getLoc(), llvm::None,
                                               buffer_args, op->getAttrs());
    // MHLO's Dot uses rank-2 operands, of the form ([N, M], [M, O]) -> [N, O].
    auto dimension_numbers = mhlo::DotDimensionNumbers::get(
        make_elements_attr({}), make_elements_attr({}), make_elements_attr({1}),
        make_elements_attr({0}), rewriter.getContext());
    dotOp.dot_dimension_numbersAttr(dimension_numbers);
    rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
    return success();
  }
};
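
// In other words, the specialization attaches the canonical matmul dimension
// numbers for ([N, M], [M, O]) -> [N, O]: no batching dimensions, LHS
// contracting dimension 1, and RHS contracting dimension 0.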

struct HloToLhloCustomCallOpConverter
    : public BaseOpConversion<mhlo::CustomCallOp> {
 public:
  using BaseOpConversion<mhlo::CustomCallOp>::BaseOpConversion;

  LogicalResult matchAndRewrite(
      mhlo::CustomCallOp hloOp, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    Operation* op = hloOp.getOperation();
    SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
    if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();

    auto lhloOp = rewriter.create<lmhlo::CustomCallOp>(
        op->getLoc(), llvm::None, buffer_args, op->getAttrs());
    // Set up the AttrSizedOperandSegments attribute to indicate the number of
    // operands for args and outputs.
    const int32_t segments[2] = {static_cast<int32_t>(operands.size()),
                                 static_cast<int32_t>(op->getNumResults())};
    lhloOp->setAttr(lhloOp.getOperandSegmentSizeAttr(),
                    rewriter.getI32VectorAttr(segments));

    rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
    return success();
  }
};
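
// E.g., for a custom call with two operands and one result this sets
// operand_segment_sizes = dense<[2, 1]> : vector<2xi32>, telling LMHLO which
// trailing buffer arguments are outputs.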

class HloToLhloReshapeUnrankedConverter
    : public BaseOpConversion<mhlo::ReshapeOp> {
 public:
  using BaseOpConversion<mhlo::ReshapeOp>::BaseOpConversion;

  LogicalResult matchAndRewrite(
      mhlo::ReshapeOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    mhlo::ReshapeOp::Adaptor adaptor(operands);
    auto unranked_operand_type =
        adaptor.operand().getType().dyn_cast<UnrankedMemRefType>();
    if (unranked_operand_type == nullptr) return failure();

    auto result_type = op.getType().cast<RankedTensorType>();
    rewriter.replaceOpWithNewOp<MemRefCastOp>(
        op, adaptor.operand(),
        MemRefType::get(result_type.getShape(), result_type.getElementType()));
    return success();
  }
};
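
// E.g. (a sketch), a reshape of an unranked operand to a statically shaped
// result lowers to a single cast on the buffer:
//
//   %ranked = memref_cast %operand : memref<*xf32> to memref<2x2xf32>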

// TODO(pifon): Consider inserting lhlo.copy as in
// HloToLhloDynamicBroadcastInDimOpConverter.
class HloToLhloDynamicReshapeConverter
    : public BaseOpConversion<mhlo::DynamicReshapeOp> {
 public:
  using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;

  LogicalResult matchAndRewrite(
      mhlo::DynamicReshapeOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    Type result_type;
    if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>()) {
      result_type =
          MemRefType::get(ranked_type.getShape(), ranked_type.getElementType());
    } else if (auto unranked_type =
                   op.getType().dyn_cast<UnrankedTensorType>()) {
      result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0);
    } else {
      return failure();
    }
    mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
    rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
        op, result_type, adaptor.operand(), adaptor.output_shape());
    return success();
  }
};
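
// E.g. (a sketch), with a runtime shape buffer %shape the reshape lowers to
//
//   %out = memref_reshape %operand(%shape)
//       : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>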

// TODO(b/175670649) Fix this to no longer access original tensor operands.
class HloToLhloDynamicBroadcastInDimOpConverter
    : public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
 public:
  HloToLhloDynamicBroadcastInDimOpConverter(TypeConverter& converter,
                                            MLIRContext* ctx,
                                            bool insert_copy = true)
      : BaseOpConversion<mhlo::DynamicBroadcastInDimOp>(converter, ctx),
        insert_copy_(insert_copy) {}

  LogicalResult matchAndRewrite(
      mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    if (!op.getType().isa<RankedTensorType>()) return failure();
    Value result = InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);

    if (insert_copy_) {
      auto loc = op.getLoc();
      Value result_buffer = InsertDynamicAllocAndDealloc(
          loc, op.getResult(), op.output_dimensions(), &rewriter);

      rewriter.create<lmhlo::CopyOp>(loc, result, result_buffer);
      result = result_buffer;
    }
    rewriter.replaceOp(op, {result});
    return success();
  }

 private:
  // Inserts a memref_reinterpret_cast that changes the layout of `operand`:
  // wherever size-1 expansion is necessary, the dimension gets stride 0 and
  // the size of the corresponding target dimension.
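  //
  // For example (a sketch; SSA names and the exact assembly are illustrative),
  // broadcasting a memref<?xf32> operand to a rank-2 result with
  // broadcast_dimensions = [1] produces roughly
  //
  //   %cast = memref_reinterpret_cast %operand to
  //       offset: [0], sizes: [%size0, %size1], strides: [0, %stride1]
  //       : memref<?xf32> to memref<?x?xf32, #strided_map>
  //
  // where %stride1 is 0 if the operand dimension is expanded (its size is 1)
  // and the operand's own stride otherwise.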
  MemRefReinterpretCastOp InsertDynamicMemrefCastOp(
      mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
    auto loc = op.getLoc();
    auto operand_type = operand.getType().cast<MemRefType>();
    auto operand_shape = operand_type.getShape();
    auto operand_rank = operand_type.getRank();

    auto result_type = op.getType().cast<RankedTensorType>();
    auto result_rank = result_type.getRank();

    Value zero = b->create<ConstantIndexOp>(loc, 0);
    Value one = b->create<ConstantIndexOp>(loc, 1);

    // Compute a reversed scan product. Compute the stride for the dimensions
    // so far, working from minor to major dimensions. Additionally, save the
    // operand shape Values to use in the next loop.
    SmallVector<Value, 2> operand_strides(operand_rank, one);
    SmallVector<Value, 2> operand_sizes(operand_rank, one);
    Value stride_so_far = one;
    for (int i = operand_rank - 1; i >= 0; --i) {
      Value operand_dim_size =
          ShapedType::isDynamic(operand_shape[i])
              ? b->create<DimOp>(loc, operand, i).getResult()
              : b->create<ConstantIndexOp>(loc, operand_shape[i]).getResult();
      operand_sizes[i] = operand_dim_size;

      operand_strides[i] = stride_so_far;
      if (i > 0) {
        stride_so_far = b->create<MulIOp>(loc, stride_so_far, operand_dim_size);
      }
    }

    SmallVector<OpFoldResult, 2> sizes, strides;
    sizes.reserve(result_rank);
    strides.reserve(result_rank);

    DenseMap<int, int> output_to_input_dim;
    for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
      output_to_input_dim[dim.value().getSExtValue()] = dim.index();
    }
    for (int i = 0; i < result_rank; ++i) {
      Value i_val = b->create<ConstantIndexOp>(loc, i);
      Value result_dim_size =
          b->create<tensor::ExtractOp>(loc, op.output_dimensions(), i_val);
      if (!result_dim_size.getType().isIndex()) {
        result_dim_size =
            b->create<IndexCastOp>(loc, result_dim_size, b->getIndexType());
      }
      sizes.push_back(result_dim_size);

      auto it = output_to_input_dim.find(i);
      // If this output dimension has no entry in the inverse
      // broadcast_dimensions map (the output rank exceeds the input rank),
      // set the stride to 0 as well, to emulate padding the operand shape
      // with size-1 dimensions and the corresponding expansion.
      if (it == output_to_input_dim.end()) {
        strides.push_back(zero);
        continue;
      }

      // There can be two cases:
      // 1) Operand dim == result dim => expansion is not needed
      //    => stride := flattened buffer stride
      // 2) Operand dim < result dim => expansion is needed => stride := 0.
      int dim = it->second;
      Value is_expansion = b->create<CmpIOp>(
          loc, CmpIPredicate::slt, operand_sizes[dim], result_dim_size);
      Value select = b->create<mlir::SelectOp>(loc, is_expansion, zero,
                                               operand_strides[dim]);
      strides.push_back(select);
    }

    // Type-erased memref type with static rank, dynamic sizes and strides.
    SmallVector<int64_t, 2> dynamic_layout(result_rank,
                                           MemRefType::kDynamicStrideOrOffset);
    SmallVector<int64_t, 2> dynamic_shape(result_rank,
                                          MemRefType::kDynamicSize);
    auto type_erased_memref_type = MemRefType::get(
        dynamic_shape, operand_type.getElementType(),
        makeStridedLinearLayoutMap(dynamic_layout,
                                   /*offset=*/0, b->getContext()));

    auto transformed_operand = b->create<MemRefReinterpretCastOp>(
        loc, type_erased_memref_type, operand,
        /*offset=*/b->getI64IntegerAttr(0), sizes, strides);
    return transformed_operand;
  }

  // Keep the copy semantics and allocate a buffer for the result of the
  // memref cast.
  bool insert_copy_;
};

struct HloToLhloDotGeneralOpConverter
    : public BaseOpConversion<mhlo::DotGeneralOp> {
  using BaseOpConversion<mhlo::DotGeneralOp>::BaseOpConversion;
  LogicalResult matchAndRewrite(
      mhlo::DotGeneralOp dotGeneralOp, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    Operation* op = dotGeneralOp.getOperation();

    if (op->getResults().empty()) return failure();
    OpResult result = op->getResults()[0];
    RankedTensorType resultType = result.getType().dyn_cast<RankedTensorType>();
    if (!resultType) return failure();

    // The third buffer argument will be filled with what used to be the return
    // type of the DotGeneral.
    if (operands.size() != 2) return failure();
    std::array<Value, 3> bufferArgs = {operands[0], operands[1], {}};

    if (resultType.hasStaticShape()) {
      bufferArgs[2] = InsertAlloc(op->getLoc(), result, &rewriter);
    } else {
      SmallVector<Value, 1> results_shape;
      auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
      if (failed(shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
        return failure();

      bufferArgs[2] = InsertDynamicAllocAndDealloc(
          op->getLoc(), result, results_shape.front(), &rewriter);
    }

    rewriter.create<lmhlo::DotOp>(op->getLoc(), llvm::None, bufferArgs,
                                  op->getAttrs());
    rewriter.replaceOp(op, bufferArgs[2]);
    return success();
  }
};

struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
 public:
  using BaseOpConversion<mhlo::ReduceOp>::BaseOpConversion;

  LogicalResult matchAndRewrite(
      mhlo::ReduceOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = op.getLoc();
    // TODO(b/137624192) Implement variadic reduce.
    if (op.getNumResults() != 1) return failure();
    if (!llvm::hasSingleElement(op.body())) {
      return op.emitOpError()
             << "tensor to buffer conversion expects a single block "
                "in the region containing the operation";
    }
    const auto& original_results = op.getResults();
    SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
    for (auto result : original_results) {
      buffer_args.push_back(InsertAlloc(loc, result, &rewriter));
    }
    auto new_op = rewriter.create<lmhlo::ReduceOp>(loc, llvm::None, buffer_args,
                                                   op.getAttrs());

    // Copy over the operations inside the region.
    rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end());

    // Convert the region signature to memref and add an extra block argument
    // for the result.
    auto& entry_block = new_op.body().front();
    TypeConverter::SignatureConversion sig_conversion(
        entry_block.getNumArguments() + 1);
    for (auto arg : entry_block.getArguments()) {
      auto old_type = arg.getType().cast<TensorType>();
      auto new_type =
          MemRefType::get(old_type.getShape(), old_type.getElementType());
      sig_conversion.addInputs(arg.getArgNumber(), new_type);
    }
    auto return_op = cast<mhlo::ReturnOp>(entry_block.getTerminator());
    auto result_type = return_op.results().front().getType().cast<TensorType>();
    sig_conversion.addInputs({MemRefType::get(result_type.getShape(),
                                              result_type.getElementType())});
    rewriter.applySignatureConversion(&new_op.body(), sig_conversion);

    rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));

    return success();
  }
};
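
// E.g. (a sketch), the reduction body's entry block
//
//   ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
//
// is rewritten to take buffers, with one extra argument appended for the
// result:
//
//   ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %out: memref<f32>):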

// Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator.
struct HloToLhloReturnOpConverter : public BaseOpConversion<mhlo::ReturnOp> {
 public:
  using BaseOpConversion<mhlo::ReturnOp>::BaseOpConversion;

  LogicalResult matchAndRewrite(
      mhlo::ReturnOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = op.getLoc();
    auto& entry_block = op->getParentRegion()->front();
    auto num_arguments = entry_block.getNumArguments();
    if (operands.size() > num_arguments) {
      return op.emitError(
          "The number of operands that need Copy operations is more "
          "than the number of target function arguments.");
    }

    // The index of the first output block argument.
    auto dest_arg_idx = num_arguments - operands.size();

    // Create a lmhlo.copy for each operand of mhlo.return.
    for (Value operand : operands) {
      rewriter.create<lmhlo::CopyOp>(loc, operand,
                                     entry_block.getArgument(dest_arg_idx));
      ++dest_arg_idx;
    }
    rewriter.replaceOpWithNewOp<lmhlo::TerminatorOp>(op);
    return success();
  }
};
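
// E.g. (a sketch), `"mhlo.return"(%v)` inside a converted region becomes
//
//   "lmhlo.copy"(%v_buf, %out_buf) : (memref<f32>, memref<f32>) -> ()
//   "lmhlo.terminator"() : () -> ()
//
// where %out_buf is the trailing block argument reserved for the result.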

// TODO(b/175789537) Remove this pattern.
class HloToLhloTensorStoreOpLegacyConverter
    : public BaseOpConversion<mlir::TensorStoreOp> {
 public:
  using BaseOpConversion<mlir::TensorStoreOp>::BaseOpConversion;

  LogicalResult matchAndRewrite(
      mlir::TensorStoreOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    rewriter.replaceOpWithNewOp<lmhlo::CopyOp>(op, llvm::None, operands.front(),
                                               operands.back());
    return success();
  }
};

// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
// buffers if necessary.
//
// Example fusion with HLO ops.
//
// func @fusion(%arg0: memref<2x2xf32>,
//              %arg1: memref<2x2xf32>,
//              %arg2: memref<2x2xf32>,
//              %arg3: memref<2x2xf32>) {
//   "lmhlo.fusion"() ({
//     %0 = tensor_load %arg1 : memref<2x2xf32>
//     %1 = tensor_load %arg2 : memref<2x2xf32>
//     %2 = "mhlo.add"(%0, %1) :
//         (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
//     %3 = tensor_load %arg0 : memref<2x2xf32>
//     %4 = "mhlo.multiply"(%2, %3) :
//         (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
//     tensor_store %4, %arg3 : memref<2x2xf32>
//     "lmhlo.terminator"() : () -> ()
//   }) : () -> ()
//   return
// }
//
// Transformed fusion with LHLO ops.
//
// func @fusion(%arg0: memref<2x2xf32>,
//              %arg1: memref<2x2xf32>,
//              %arg2: memref<2x2xf32>,
//              %arg3: memref<2x2xf32>) {
//   "lmhlo.fusion"() ({
//     %0 = alloc() : memref<2x2xf32>
//     "lmhlo.add"(%arg1, %arg2, %0) :
//         (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
//     "lmhlo.multiply"(%0, %arg0, %arg3) :
//         (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
//     "lmhlo.terminator"() : () -> ()
//   }) : () -> ()
//   return
// }
//
// FuncOp signature conversion example:
//
// func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
//   %0 = "mhlo.maximum"(%arg0, %arg1)
//       : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
//   %1 = "mhlo.add"(%arg0, %0)
//       : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
//   return %1 : tensor<4xf32>
// }
//
// Transformed function with an extra argument for the result. The types have
// been converted from tensor to memref.
//
// func @func_op(%arg0: memref<4xf32>,
//               %arg1: memref<4xf32>,
//               %arg2: memref<4xf32>) {
//   %0 = alloc() : memref<4xf32>
//   "lmhlo.maximum"(%arg0, %arg1, %0) :
//       (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
//   %1 = alloc() : memref<4xf32>
//   "lmhlo.add"(%arg0, %0, %1) :
//       (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
//   "lmhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
//   "lmhlo.terminator"() : () -> ()
// }

struct HloLegalizeToLhlo
    : public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<lmhlo::LmhloDialect>();
  }

 public:
  HloLegalizeToLhlo() = default;
  HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {}

  void runOnOperation() override {
    OwningRewritePatternList patterns;
    auto& context = getContext();
    ConversionTarget target(context);
    target.addLegalDialect<lmhlo::LmhloDialect>();
    target.addLegalDialect<StandardOpsDialect>();
    target.addLegalDialect<tensor::TensorDialect>();
    target.addIllegalDialect<mhlo::MhloDialect>();
    // Declare tensor_load and tensor_store illegal.
    target.addIllegalOp<mlir::TensorLoadOp, mlir::TensorStoreOp>();
    // tensor_to_memref is illegal if it has uses.
    // TODO(b/175670649) Make tensor_to_memref illegal.
    target.addDynamicallyLegalOp<mlir::TensorToMemrefOp>(
        [](auto op) { return op->use_empty(); });

    BufferizeTypeConverter converter;
    auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); };
    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
      auto inputs = op.getType().getInputs();
      return llvm::all_of(inputs, isMemRefType) &&
             converter.isLegal(&op.getBody());
    });
    target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
      return std::all_of(op.operand_type_begin(), op.operand_type_end(),
                         isMemRefType) &&
             std::all_of(op.result_type_begin(), op.result_type_end(),
                         isMemRefType);
    });
    target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp op) {
      return std::all_of(op.operand_type_begin(), op.operand_type_end(),
                         isMemRefType);
    });

    populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
    populateFuncOpTypeConversionPattern(patterns, &context, converter);
    populateCallOpTypeConversionPattern(patterns, &context, converter);
    populateBranchOpInterfaceAndReturnOpTypeConversionPattern(
        patterns, &context, converter);
    populateEliminateBufferizeMaterializationsPatterns(&context, converter,
                                                       patterns);

    populateShapeStructuralTypeConversionsAndLegality(&context, converter,
                                                      patterns, target);

    // TODO(b/175789537) Remove this pattern.
    patterns.insert<HloToLhloTensorStoreOpLegacyConverter>(&context);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
}  // namespace

void populateDynamicHLOToLHLOConversionPattern(
    MLIRContext* context, BufferizeTypeConverter* converter,
    OwningRewritePatternList* patterns, bool insert_copy) {
  patterns->insert<HloToLhloDynamicBroadcastInDimOpConverter>(
      *converter, context, insert_copy);
  patterns->insert<HloToLhloDynamicReshapeConverter,
                   HloToLhloReshapeUnrankedConverter>(*converter, context);
}

void populateHLOToLHLOConversionPattern(MLIRContext* context,
                                        BufferizeTypeConverter* converter,
                                        OwningRewritePatternList* patterns) {
  populateDynamicHLOToLHLOConversionPattern(context, converter, patterns);
  // clang-format off
  patterns->insert<
      HloToLhloCustomCallOpConverter,
      HloToLhloDotGeneralOpConverter,
      HloToLhloOpConverter<mhlo::AbsOp>,
      HloToLhloOpConverter<mhlo::AddOp>,
      HloToLhloOpConverter<mhlo::AndOp>,
      HloToLhloOpConverter<mhlo::Atan2Op>,
      HloToLhloOpConverter<mhlo::BroadcastInDimOp>,
      HloToLhloOpConverter<mhlo::CeilOp>,
      HloToLhloOpConverter<mhlo::CompareOp>,
      HloToLhloOpConverter<mhlo::ComplexOp>,
      HloToLhloOpConverter<mhlo::ConstOp>,
      HloToLhloOpConverter<mhlo::ConvOp>,
      HloToLhloOpConverter<mhlo::ConvertOp>,
      HloToLhloOpConverter<mhlo::CopyOp>,
      HloToLhloOpConverter<mhlo::CosOp>,
      HloToLhloOpConverter<mhlo::DivOp>,
      HloToLhloOpConverter<mhlo::DotOp>,
      HloToLhloOpConverter<mhlo::ExpOp>,
      HloToLhloOpConverter<mhlo::Expm1Op>,
      HloToLhloOpConverter<mhlo::FloorOp>,
      HloToLhloOpConverter<mhlo::GatherOp>,
      HloToLhloOpConverter<mhlo::ImagOp>,
      HloToLhloOpConverter<mhlo::IotaOp>,
      HloToLhloOpConverter<mhlo::IsFiniteOp>,
      HloToLhloOpConverter<mhlo::LogOp>,
      HloToLhloOpConverter<mhlo::MaxOp>,
      HloToLhloOpConverter<mhlo::MinOp>,
      HloToLhloOpConverter<mhlo::MulOp>,
      HloToLhloOpConverter<mhlo::NegOp>,
      HloToLhloOpConverter<mhlo::NotOp>,
      HloToLhloOpConverter<mhlo::OrOp>,
      HloToLhloOpConverter<mhlo::PowOp>,
      HloToLhloOpConverter<mhlo::RealOp>,
      HloToLhloOpConverter<mhlo::RemOp>,
      HloToLhloOpConverter<mhlo::RsqrtOp>,
      HloToLhloOpConverter<mhlo::ReshapeOp>,
      HloToLhloOpConverter<mhlo::SelectOp>,
      HloToLhloOpConverter<mhlo::ShiftLeftOp>,
      HloToLhloOpConverter<mhlo::ShiftRightArithmeticOp>,
      HloToLhloOpConverter<mhlo::ShiftRightLogicalOp>,
      HloToLhloOpConverter<mhlo::SignOp>,
      HloToLhloOpConverter<mhlo::SinOp>,
      HloToLhloOpConverter<mhlo::SliceOp>,
      HloToLhloOpConverter<mhlo::SqrtOp>,
      HloToLhloOpConverter<mhlo::SubOp>,
      HloToLhloOpConverter<mhlo::TanhOp>,
      HloToLhloOpConverter<mhlo::TransposeOp>,
      HloToLhloOpConverter<mhlo::XorOp>,
      HloToLhloReduceOpConverter,
      HloToLhloReturnOpConverter
  >(*converter, context);
  // clang-format on
}
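
// Typical usage of the pass (a sketch; surrounding pipeline setup not shown):
//
//   mlir::PassManager pm(module.getContext());
//   pm.addPass(mlir::mhlo::createLegalizeToLhloPass());
//   if (failed(pm.run(module))) { /* handle the error */ }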
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass() {
  return std::make_unique<HloLegalizeToLhlo>();
}

}  // namespace mhlo
}  // namespace mlir