/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering the MHLO dialect to the Standard
// dialect.
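//
// As a rough sketch (IR syntax approximated here, not taken verbatim from the
// tests), an integer comparison such as
//   %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"}
//          : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// is rewritten by the patterns below into the Standard dialect's cmpi,
// roughly `%0 = cmpi eq, %lhs, %rhs : tensor<4xi32>`.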

#include "llvm/ADT/StringSwitch.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace {
#include "generated_legalize_to_standard.inc"
}  // end anonymous namespace
namespace mhlo {
namespace {

class CompareIConvert : public OpRewritePattern<mhlo::CompareOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::CompareOp op,
                                PatternRewriter &rewriter) const override {
    auto lhs = op.lhs();
    auto rhs = op.rhs();
    auto lhs_type = lhs.getType().cast<TensorType>();
    auto rhs_type = rhs.getType().cast<TensorType>();

    // Broadcasting not supported by this rewrite.
    if (lhs_type.getShape() != rhs_type.getShape()) return failure();

    if (!lhs_type.getElementType().isSignlessInteger() ||
        !rhs_type.getElementType().isSignlessInteger())
      return failure();

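    // MHLO encodes the comparison direction as a string attribute; for the
    // signless integer operands matched above, the ordering directions
    // (LT/LE/GT/GE) are lowered to *signed* cmpi predicates.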
    auto comparison_direction = op.comparison_direction();
    auto compare_predicate =
        llvm::StringSwitch<Optional<CmpIPredicate>>(comparison_direction)
            .Case("EQ", CmpIPredicate::eq)
            .Case("NE", CmpIPredicate::ne)
            .Case("LT", CmpIPredicate::slt)
            .Case("LE", CmpIPredicate::sle)
            .Case("GT", CmpIPredicate::sgt)
            .Case("GE", CmpIPredicate::sge)
            .Default(llvm::None);

    if (!compare_predicate.hasValue()) return failure();

    rewriter.replaceOpWithNewOp<CmpIOp>(op, compare_predicate.getValue(), lhs,
                                        rhs);
    return success();
  }
};

class CompareFConvert : public OpRewritePattern<mhlo::CompareOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::CompareOp op,
                                PatternRewriter &rewriter) const override {
    auto lhs = op.lhs();
    auto rhs = op.rhs();
    auto lhs_type = lhs.getType().cast<TensorType>();
    auto rhs_type = rhs.getType().cast<TensorType>();

    // Broadcasting not supported by this rewrite.
    if (lhs_type.getShape() != rhs_type.getShape()) return failure();

    if (!lhs_type.getElementType().isa<FloatType>() ||
        !rhs_type.getElementType().isa<FloatType>())
      return failure();

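    // Note the asymmetry in the float mapping below: EQ/LT/LE/GT/GE use
    // *ordered* cmpf predicates (false if either operand is NaN), whereas NE
    // uses the *unordered* UNE predicate (true if either operand is NaN),
    // which keeps NE the exact negation of EQ.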
    auto comparison_direction = op.comparison_direction();
    auto compare_predicate =
        llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction)
            .Case("EQ", CmpFPredicate::OEQ)
            .Case("NE", CmpFPredicate::UNE)
            .Case("LT", CmpFPredicate::OLT)
            .Case("LE", CmpFPredicate::OLE)
            .Case("GT", CmpFPredicate::OGT)
            .Case("GE", CmpFPredicate::OGE)
            .Default(llvm::None);

    if (!compare_predicate.hasValue()) return failure();

    rewriter.replaceOpWithNewOp<CmpFOp>(op, compare_predicate.getValue(), lhs,
                                        rhs);
    return success();
  }
};

// Replace IotaOp with an integer constant. A ConvertOp is added to convert
// the integer constant to the iota result type. For complex types, the real
// part is replaced with the generated constant and the imaginary part is
// replaced with a zero tensor.
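//
// For example, iota over a 2x3 result with iota_dimension = 1 produces the
// constant [[0, 1, 2], [0, 1, 2]], while iota_dimension = 0 produces
// [[0, 0, 0], [1, 1, 1]].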
class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::IotaOp op,
                                PatternRewriter &rewriter) const override {
    auto output_type = op.getType().cast<ShapedType>();
    auto output_size = output_type.getNumElements();
    auto dimension = op.iota_dimension();
    auto max_dim_size = output_type.getDimSize(dimension);

    auto element_type = output_type.getElementType();
    int bitwidth;

    auto complex_ty = element_type.dyn_cast<ComplexType>();
    Type int_or_float_ty = element_type;
    if (complex_ty) int_or_float_ty = complex_ty.getElementType();

    bitwidth = int_or_float_ty.getIntOrFloatBitWidth();
    llvm::SmallVector<APInt, 10> values;
    values.reserve(output_size);

    int64_t increase_stride = output_size;
    for (int i = 0; i <= dimension; i++) {
      increase_stride /= output_type.getDimSize(i);
    }

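    // Walking the flattened output in row-major order, the index along
    // `dimension` advances once every `increase_stride` elements and wraps
    // after `max_dim_size` steps. E.g. for a 2x3 result with
    // iota_dimension = 1, increase_stride is 1 and the values generated
    // below are 0, 1, 2, 0, 1, 2.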
    int64_t current_value = 0;
    for (int i = 0; i < output_size; i++) {
      int64_t value = (current_value / increase_stride) % max_dim_size;
      values.push_back(APInt(bitwidth, value));
      ++current_value;
    }

    auto int_shape_type = RankedTensorType::get(
        output_type.getShape(),
        IntegerType::get(rewriter.getContext(), bitwidth));
    auto loc = op.getLoc();
    auto integer_const = rewriter.create<mlir::ConstantOp>(
        loc, DenseIntElementsAttr::get(int_shape_type, values));

    auto int_or_float_shape_ty =
        RankedTensorType::get(output_type.getShape(), int_or_float_ty);

    auto iota_const =
        rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, integer_const);

    // For int/float types we are done; replace the op and return.
    if (!complex_ty) {
      rewriter.replaceOp(op, iota_const.getResult());
      return success();
    }

    // For complex types, generate a constant tensor of zeroes for the
    // imaginary part and use iota_const for the real part.
    auto zeroes = rewriter.create<mlir::ConstantOp>(
        loc, DenseIntElementsAttr::get(int_shape_type, APInt(bitwidth, 0)));
    auto imag_zeroes =
        rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, zeroes);
    rewriter.replaceOpWithNewOp<mhlo::ComplexOp>(op, iota_const, imag_zeroes);
    return success();
  }
};

}  // end anonymous namespace

namespace {
struct LegalizeToStandardPass
    : public LegalizeToStandardPassBase<LegalizeToStandardPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<StandardOpsDialect>();
  }

  /// Perform the lowering to Standard dialect.
  void runOnFunction() override;
};
}  // end anonymous namespace

std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> createLegalizeToStdPass() {
  return std::make_unique<LegalizeToStandardPass>();
}

void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
                               mlir::MLIRContext *ctx) {
  mlir::populateWithGenerated(*patterns);
  patterns->insert<CompareFConvert, CompareIConvert, ConvertIotaOp>(ctx);
}

/// Perform the lowering to standard dialect.
void LegalizeToStandardPass::runOnFunction() {
  OwningRewritePatternList patterns(&getContext());
  mlir::mhlo::PopulateMhloToStdPatterns(&patterns, &getContext());
  (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}

}  // end namespace mhlo
}  // end namespace mlir