1 //===- StructuralTypeConversions.cpp - Shape structural type conversions --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "PassDetail.h"
10 #include "mlir/Dialect/Shape/IR/Shape.h"
11 #include "mlir/Dialect/Shape/Transforms/Passes.h"
12 #include "mlir/Transforms/DialectConversion.h"
13
14 using namespace mlir;
15 using namespace mlir::shape;
16
17 namespace {
18 class ConvertAssumingOpTypes : public OpConversionPattern<AssumingOp> {
19 public:
20 using OpConversionPattern::OpConversionPattern;
21
22 LogicalResult
matchAndRewrite(AssumingOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const23 matchAndRewrite(AssumingOp op, ArrayRef<Value> operands,
24 ConversionPatternRewriter &rewriter) const final {
25 SmallVector<Type, 2> newResultTypes;
26 newResultTypes.reserve(op.getNumResults());
27 for (auto result : op.getResults()) {
28 auto originalType = result.getType();
29 Type convertedType = getTypeConverter()->convertType(originalType);
30 newResultTypes.push_back(convertedType);
31 }
32
33 auto newAssumingOp =
34 rewriter.create<AssumingOp>(op.getLoc(), newResultTypes, op.witness());
35 rewriter.inlineRegionBefore(op.doRegion(), newAssumingOp.doRegion(),
36 newAssumingOp.doRegion().end());
37 rewriter.replaceOp(op, newAssumingOp.getResults());
38
39 return success();
40 }
41 };
42 } // namespace
43
44 namespace {
45 class ConvertAssumingYieldOpTypes
46 : public OpConversionPattern<AssumingYieldOp> {
47 public:
48 using OpConversionPattern::OpConversionPattern;
49
50 LogicalResult
matchAndRewrite(AssumingYieldOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const51 matchAndRewrite(AssumingYieldOp op, ArrayRef<Value> operands,
52 ConversionPatternRewriter &rewriter) const final {
53 rewriter.replaceOpWithNewOp<AssumingYieldOp>(op, operands);
54 return success();
55 }
56 };
57 } // namespace
58
populateShapeStructuralTypeConversionsAndLegality(MLIRContext * context,TypeConverter & typeConverter,OwningRewritePatternList & patterns,ConversionTarget & target)59 void mlir::populateShapeStructuralTypeConversionsAndLegality(
60 MLIRContext *context, TypeConverter &typeConverter,
61 OwningRewritePatternList &patterns, ConversionTarget &target) {
62 patterns.insert<ConvertAssumingOpTypes, ConvertAssumingYieldOpTypes>(
63 typeConverter, context);
64 target.addDynamicallyLegalOp<AssumingOp>([&](AssumingOp op) {
65 return typeConverter.isLegal(op.getResultTypes());
66 });
67 target.addDynamicallyLegalOp<AssumingYieldOp>([&](AssumingYieldOp op) {
68 return typeConverter.isLegal(op.getOperandTypes());
69 });
70 }
71