1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for lowering MHLO dialect to Standard dialect.
17
18 #include "llvm/ADT/StringSwitch.h"
19 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
20 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
21 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
22 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h"
24 #include "mlir/IR/BuiltinOps.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27
28 namespace mlir {
29 namespace {
30 #include "generated_legalize_to_standard.inc"
31 } // end anonymous namespace
32 namespace mhlo {
33 namespace {
34
35 class CompareIConvert : public OpRewritePattern<mhlo::CompareOp> {
36 public:
37 using OpRewritePattern::OpRewritePattern;
38
matchAndRewrite(mhlo::CompareOp op,PatternRewriter & rewriter) const39 LogicalResult matchAndRewrite(mhlo::CompareOp op,
40 PatternRewriter &rewriter) const override {
41 auto lhs = op.lhs();
42 auto rhs = op.rhs();
43 auto lhs_type = lhs.getType().cast<TensorType>();
44 auto rhs_type = rhs.getType().cast<TensorType>();
45
46 // Broadcasting not supported by this rewrite.
47 if (lhs_type.getShape() != rhs_type.getShape()) return failure();
48
49 if (!lhs_type.getElementType().isSignlessInteger() ||
50 !rhs_type.getElementType().isSignlessInteger())
51 return failure();
52
53 auto comparison_direction = op.comparison_direction();
54 auto compare_predicate =
55 llvm::StringSwitch<Optional<CmpIPredicate>>(comparison_direction)
56 .Case("EQ", CmpIPredicate::eq)
57 .Case("NE", CmpIPredicate::ne)
58 .Case("LT", CmpIPredicate::slt)
59 .Case("LE", CmpIPredicate::sle)
60 .Case("GT", CmpIPredicate::sgt)
61 .Case("GE", CmpIPredicate::sge)
62 .Default(llvm::None);
63
64 if (!compare_predicate.hasValue()) return failure();
65
66 rewriter.replaceOpWithNewOp<CmpIOp>(op, compare_predicate.getValue(), lhs,
67 rhs);
68 return success();
69 }
70 };
71
72 class CompareFConvert : public OpRewritePattern<mhlo::CompareOp> {
73 public:
74 using OpRewritePattern::OpRewritePattern;
75
matchAndRewrite(mhlo::CompareOp op,PatternRewriter & rewriter) const76 LogicalResult matchAndRewrite(mhlo::CompareOp op,
77 PatternRewriter &rewriter) const override {
78 auto lhs = op.lhs();
79 auto rhs = op.rhs();
80 auto lhs_type = lhs.getType().cast<TensorType>();
81 auto rhs_type = rhs.getType().cast<TensorType>();
82
83 // Broadcasting not supported by this rewrite.
84 if (lhs_type.getShape() != rhs_type.getShape()) return failure();
85
86 if (!lhs_type.getElementType().isa<FloatType>() ||
87 !rhs_type.getElementType().isa<FloatType>())
88 return failure();
89
90 auto comparison_direction = op.comparison_direction();
91 auto compare_predicate =
92 llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction)
93 .Case("EQ", CmpFPredicate::OEQ)
94 .Case("NE", CmpFPredicate::UNE)
95 .Case("LT", CmpFPredicate::OLT)
96 .Case("LE", CmpFPredicate::OLE)
97 .Case("GT", CmpFPredicate::OGT)
98 .Case("GE", CmpFPredicate::OGE)
99 .Default(llvm::None);
100
101 if (!compare_predicate.hasValue()) return failure();
102
103 rewriter.replaceOpWithNewOp<CmpFOp>(op, compare_predicate.getValue(), lhs,
104 rhs);
105 return success();
106 }
107 };
108
109 // Replace IotaOp with an integer constant. A ConvertOp is added to
110 // convert the integer constant to iota result type. For complex types, the real
111 // part is replaced with the generated constant and the imaginary part is
112 // replaced with zero tensor.
113 class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
114 public:
115 using OpRewritePattern::OpRewritePattern;
116
matchAndRewrite(mhlo::IotaOp op,PatternRewriter & rewriter) const117 LogicalResult matchAndRewrite(mhlo::IotaOp op,
118 PatternRewriter &rewriter) const override {
119 auto output_type = op.getType().cast<ShapedType>();
120 auto output_size = output_type.getNumElements();
121 auto dimension = op.iota_dimension();
122 auto max_dim_size = output_type.getDimSize(dimension);
123
124 auto element_type = output_type.getElementType();
125 int bitwidth;
126
127 auto complex_ty = element_type.dyn_cast<ComplexType>();
128 Type int_or_float_ty = element_type;
129 if (complex_ty) int_or_float_ty = complex_ty.getElementType();
130
131 bitwidth = int_or_float_ty.getIntOrFloatBitWidth();
132 llvm::SmallVector<APInt, 10> values;
133 values.reserve(output_size);
134
135 int64_t increase_stride = output_size;
136 for (int i = 0; i <= dimension; i++) {
137 increase_stride /= output_type.getDimSize(i);
138 }
139
140 int64_t current_value = 0;
141 for (int i = 0; i < output_size; i++) {
142 int64_t value = (current_value / increase_stride) % max_dim_size;
143 values.push_back(APInt(bitwidth, value));
144 ++current_value;
145 }
146
147 auto int_shape_type = RankedTensorType::get(
148 output_type.getShape(),
149 IntegerType::get(rewriter.getContext(), bitwidth));
150 auto loc = op.getLoc();
151 auto integer_const = rewriter.create<mlir::ConstantOp>(
152 loc, DenseIntElementsAttr::get(int_shape_type, values));
153
154 auto int_or_float_shape_ty =
155 RankedTensorType::get(output_type.getShape(), int_or_float_ty);
156
157 auto iota_const =
158 rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, integer_const);
159
160 // For int/float types we are done, replace op and return.
161 if (!complex_ty) {
162 rewriter.replaceOp(op, iota_const.getResult());
163 return success();
164 }
165
166 // For complex types, generate a constant tensor of zeroes for the imaginary
167 // part and use iota_const for real part.
168 auto zeroes = rewriter.create<mlir::ConstantOp>(
169 loc, DenseIntElementsAttr::get(int_shape_type, APInt(bitwidth, 0)));
170 auto imag_zeroes =
171 rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, zeroes);
172 rewriter.replaceOpWithNewOp<mhlo::ComplexOp>(op, iota_const, imag_zeroes);
173 return success();
174 }
175 };
176
177 } // end anonymous namespace
178
179 namespace {
180 struct LegalizeToStandardPass
181 : public LegalizeToStandardPassBase<LegalizeToStandardPass> {
getDependentDialectsmlir::mhlo::__anon6b7b0a7b0311::LegalizeToStandardPass182 void getDependentDialects(DialectRegistry ®istry) const override {
183 registry.insert<StandardOpsDialect>();
184 }
185
186 /// Perform the lowering to Standard dialect.
187 void runOnFunction() override;
188 };
189 } // end anonymous namespace
190
createLegalizeToStdPass()191 std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> createLegalizeToStdPass() {
192 return std::make_unique<LegalizeToStandardPass>();
193 }
194
PopulateMhloToStdPatterns(OwningRewritePatternList * patterns,mlir::MLIRContext * ctx)195 void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
196 mlir::MLIRContext *ctx) {
197 mlir::populateWithGenerated(*patterns);
198 patterns->insert<CompareFConvert, CompareIConvert, ConvertIotaOp>(ctx);
199 }
200
201 /// Perform the lowering to standard dialect.
runOnFunction()202 void LegalizeToStandardPass::runOnFunction() {
203 OwningRewritePatternList patterns(&getContext());
204 mlir::mhlo::PopulateMhloToStdPatterns(&patterns, &getContext());
205 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
206 }
207
208 } // end namespace mhlo
209 } // end namespace mlir
210