/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering HLO ops to the memref dialect.

#include <utility>

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace mhlo {
namespace {

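// Base conversion pattern that strips signedness from integer operands before
// calling the derived pattern's signlessRewrite(), then casts the result back
// to the op's original element type if the element types differ.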
template <typename T>
class SignlessOpConversion : public OpConversionPattern<T> {
 public:
  SignlessOpConversion(TypeConverter& type_converter,
                       RemoveSignTypeConverter* remove_sign_converter,
                       MLIRContext* ctx)
      : OpConversionPattern<T>(type_converter, ctx),
        remove_sign_converter_(remove_sign_converter) {}

  LogicalResult matchAndRewrite(
      T op, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = op.getLoc();
    // Sign-convert operands and result type.
    SmallVector<Value> converted_operands;
    for (auto operand : operands) {
      Type original = operand.getType();
      Type converted = remove_sign_converter_->convertType(original);
      if (converted == original) {
        converted_operands.push_back(operand);
      } else {
        converted_operands.push_back(
            rewriter
                .create<UnrealizedConversionCastOp>(loc, converted, operand)
                ->getResult(0));
      }
    }
    Type op_result_type = remove_sign_converter_->convertType(op.getType());
    // Perform actual rewrite.
    Value result =
        signlessRewrite(op, converted_operands, op_result_type, rewriter);
    if (!result) return failure();

    // If the element type of the original op and the returned value differ,
    // do a conversion cast to fix it up.
    auto expected_element_type =
        op.getType().template cast<ShapedType>().getElementType();
    auto result_type = result.getType().cast<BaseMemRefType>();
    auto actual_element_type = result_type.getElementType();
    if (expected_element_type != actual_element_type) {
      assert(remove_sign_converter_->convertType(expected_element_type) ==
             actual_element_type);
      Type new_type;
      if (auto ranked = result_type.dyn_cast<MemRefType>()) {
        new_type =
            MemRefType::get(ranked.getShape(), expected_element_type,
                            ranked.getAffineMaps(), ranked.getMemorySpace());
      } else {
        new_type = UnrankedMemRefType::get(expected_element_type,
                                           result_type.getMemorySpace());
      }
      result =
          rewriter.create<UnrealizedConversionCastOp>(loc, new_type, result)
              .getResult(0);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

 protected:
  virtual Value signlessRewrite(T op, ArrayRef<Value> operands,
                                Type result_type,
                                ConversionPatternRewriter& rewriter) const = 0;

 private:
  RemoveSignTypeConverter* remove_sign_converter_;
};

template <typename T>
using BaseOpConversion = SignlessOpConversion<T>;

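// Converts mhlo.reshape with an unranked (already bufferized) operand into a
// memref.cast to the ranked result memref type.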
class HloToMemrefReshapeUnrankedConverter
    : public BaseOpConversion<mhlo::ReshapeOp> {
 public:
  using BaseOpConversion<mhlo::ReshapeOp>::BaseOpConversion;

  Value signlessRewrite(mhlo::ReshapeOp op, ArrayRef<Value> operands,
                        Type op_result_type,
                        ConversionPatternRewriter& rewriter) const final {
    mhlo::ReshapeOp::Adaptor adaptor(operands);
    auto unranked_operand_type =
        adaptor.operand().getType().dyn_cast<UnrankedMemRefType>();
    if (unranked_operand_type == nullptr) return {};
    auto loc = op->getLoc();
    auto result_type = op_result_type.cast<RankedTensorType>();
    auto cast = rewriter.create<memref::CastOp>(
        loc, adaptor.operand(),
        MemRefType::get(result_type.getShape(), result_type.getElementType()));

    return cast;
  }
};

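// Converts mhlo.dynamic_reshape into memref.reshape, mapping the ranked or
// unranked tensor result type to the corresponding memref type.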
class HloToMemrefDynamicReshapeConverter
    : public BaseOpConversion<mhlo::DynamicReshapeOp> {
 public:
  using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;

  Value signlessRewrite(mhlo::DynamicReshapeOp op, ArrayRef<Value> operands,
                        Type op_result_type,
                        ConversionPatternRewriter& rewriter) const final {
    ShapedType result_type;
    if (auto ranked_type = op_result_type.dyn_cast<RankedTensorType>()) {
      result_type =
          MemRefType::get(ranked_type.getShape(), ranked_type.getElementType());
    } else if (auto unranked_type =
                   op_result_type.dyn_cast<UnrankedTensorType>()) {
      result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0);
    } else {
      return {};
    }
    mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
    auto reshape = rewriter.create<memref::ReshapeOp>(
        op.getLoc(), result_type, adaptor.operand(), adaptor.output_shape());
    return reshape;
  }
};

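// Converts mhlo.dynamic_broadcast_in_dim by reinterpreting the operand buffer
// with broadcast-aware strides (stride 0 for expanded dimensions). If identity
// maps are enforced, the result is additionally copied into an identity-layout
// buffer.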
// TODO(b/175670649) Fix this to no longer access original tensor operands.
class HloToMemrefDynamicBroadcastInDimOpConverter
    : public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
 public:
  HloToMemrefDynamicBroadcastInDimOpConverter(
      TypeConverter& converter, RemoveSignTypeConverter* sign_converter,
      MLIRContext* ctx, bool enforce_identity_maps)
      : BaseOpConversion<mhlo::DynamicBroadcastInDimOp>(converter,
                                                        sign_converter, ctx),
        enforce_identity_maps_(enforce_identity_maps) {}

  Value signlessRewrite(mhlo::DynamicBroadcastInDimOp op,
                        ArrayRef<Value> operands, Type op_result_type,
                        ConversionPatternRewriter& rewriter) const final {
    auto result_type = op_result_type.dyn_cast<RankedTensorType>();
    if (!result_type) return {};
    Value result = InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);

    if (enforce_identity_maps_) {
      result = CreateCopy(op, result, &rewriter);
    }

    return result;
  }

 private:
  // Inserts a dynamic memref cast that changes the layout of the memref,
  // using a 0-stride and the size of the target dimension wherever a size-1
  // dimension expansion is necessary.
  memref::ReinterpretCastOp InsertDynamicMemrefCastOp(
      mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
    auto loc = op.getLoc();
    auto operand_type = operand.getType().cast<MemRefType>();
    auto operand_shape = operand_type.getShape();
    auto operand_rank = operand_type.getRank();

    auto result_type = op.getType().cast<RankedTensorType>();
    auto result_rank = result_type.getRank();

    Value zero = b->create<ConstantIndexOp>(loc, 0);
    Value one = b->create<ConstantIndexOp>(loc, 1);

    // Compute a reversed scan product. Compute the stride for the dimensions
    // so far, working from minor to major dimensions. Additionally, save the
    // operand shape Values to use in the next loop.
    SmallVector<Value, 2> operand_strides(operand_rank, one);
    SmallVector<Value, 2> operand_sizes(operand_rank, one);
    Value stride_so_far = one;
    for (int i = operand_rank - 1; i >= 0; --i) {
      Value operand_dim_size =
          ShapedType::isDynamic(operand_shape[i])
              ? b->create<memref::DimOp>(loc, operand, i).getResult()
              : b->create<ConstantIndexOp>(loc, operand_shape[i]).getResult();
      operand_sizes[i] = operand_dim_size;

      operand_strides[i] = stride_so_far;
      if (i > 0) {
        stride_so_far = b->create<MulIOp>(loc, stride_so_far, operand_dim_size);
      }
    }

    SmallVector<OpFoldResult, 2> sizes, strides;
    sizes.reserve(result_rank);
    strides.reserve(result_rank);

    DenseMap<int, int> output_to_input_dim;
    for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
      output_to_input_dim[dim.value().getSExtValue()] = dim.index();
    }
    for (int i = 0; i < result_rank; ++i) {
      Value i_val = b->create<ConstantIndexOp>(loc, i);
      Value result_dim_size =
          b->create<tensor::ExtractOp>(loc, op.output_dimensions(), i_val);
      if (!result_dim_size.getType().isIndex()) {
        result_dim_size =
            b->create<IndexCastOp>(loc, result_dim_size, b->getIndexType());
      }
      sizes.push_back(result_dim_size);

      auto it = output_to_input_dim.find(i);
      // If the rank of the output is greater than the rank of the input, i.e.
      // there is no entry for this output dimension in the inverted
      // broadcast_dimensions map, set the stride to 0 to emulate padding the
      // shape with 1s and the corresponding expansion.
      if (it == output_to_input_dim.end()) {
        strides.push_back(zero);
        continue;
      }

      // There can be two cases:
      // 1) Operand dim == result dim => expansion is not needed
      //    => stride := flattened buffer stride
      // 2) Operand dim < result dim => expansion is needed => stride := 0.
      int dim = it->second;
      Value is_expansion = b->create<CmpIOp>(
          loc, CmpIPredicate::slt, operand_sizes[dim], result_dim_size);
      Value select = b->create<mlir::SelectOp>(loc, is_expansion, zero,
                                               operand_strides[dim]);
      strides.push_back(select);
    }

    // Type-erased memref type with static rank, dynamic sizes and strides.
    SmallVector<int64_t, 2> dynamic_layout(result_rank,
                                           MemRefType::kDynamicStrideOrOffset);
    SmallVector<int64_t, 2> dynamic_shape(result_rank,
                                          MemRefType::kDynamicSize);
    auto type_erased_memref_type = MemRefType::get(
        dynamic_shape, operand_type.getElementType(),
        makeStridedLinearLayoutMap(dynamic_layout,
                                   /*offset=*/0, b->getContext()));

    auto transformed_operand = b->create<memref::ReinterpretCastOp>(
        loc, type_erased_memref_type, operand,
        /*offset=*/b->getI64IntegerAttr(0), sizes, strides);
    return transformed_operand;
  }

  Value CreateCopy(mhlo::DynamicBroadcastInDimOp op, Value broadcasted,
                   OpBuilder* b) const {
    MemRefType result_type = broadcasted.getType().cast<MemRefType>();
    auto loc = op.getLoc();
    SmallVector<Value, 4> dynamic_operands;
    for (int i = 0; i < result_type.getRank(); ++i) {
      auto index = b->createOrFold<ConstantIndexOp>(loc, i);
      Value size =
          b->create<tensor::ExtractOp>(loc, op.output_dimensions(), index);
      if (!size.getType().isIndex()) {
        size = b->create<IndexCastOp>(loc, size, b->getIndexType());
      }
      dynamic_operands.push_back(size);
    }
    auto identity_map_memref =
        MemRefType::get(result_type.getShape(), result_type.getElementType());
    auto copy = b->create<memref::AllocOp>(op.getLoc(), identity_map_memref,
                                           dynamic_operands);
    b->create<memref::CopyOp>(loc, broadcasted, copy);

    return copy;
  }

  bool enforce_identity_maps_;
};

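// Function pass that applies the patterns above via a partial conversion,
// bufferizing the handled mhlo ops into memref-based code.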
struct HloLegalizeToMemrefPass
    : public HloLegalizeToMemrefPassBase<HloLegalizeToMemrefPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<memref::MemRefDialect, tensor::TensorDialect>();
  }

 public:
  void runOnFunction() override {
    auto& context = getContext();
    OwningRewritePatternList patterns(&context);
    ConversionTarget target(context);

    BufferizeTypeConverter converter;
    RemoveSignTypeConverter sign_converter;

    populateHLOToMemrefConversionPattern(&converter, &sign_converter,
                                         &patterns);

    target.addIllegalOp<DynamicReshapeOp, DynamicBroadcastInDimOp>();
    target.addLegalDialect<BuiltinDialect, memref::MemRefDialect,
                           StandardOpsDialect, tensor::TensorDialect>();

    auto func = getFunction();
    if (failed(applyPartialConversion(func, target, std::move(patterns))))
      signalPassFailure();
  }
};

}  // namespace

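// Populates the given pattern list with the HLO-to-memref conversion patterns
// defined above.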
void populateHLOToMemrefConversionPattern(
    BufferizeTypeConverter* converter, RemoveSignTypeConverter* sign_converter,
    OwningRewritePatternList* patterns, bool enforce_identity_maps) {
  MLIRContext* context = patterns->getContext();
  patterns->insert<HloToMemrefDynamicBroadcastInDimOpConverter>(
      *converter, sign_converter, context, enforce_identity_maps);
  patterns->insert<HloToMemrefDynamicReshapeConverter,
                   HloToMemrefReshapeUnrankedConverter>(
      *converter, sign_converter, context);
}

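// Factory for the HLO-to-memref legalization pass.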
std::unique_ptr<FunctionPass> createLegalizeToMemrefPass() {
  return std::make_unique<HloLegalizeToMemrefPass>();
}

}  // namespace mhlo
}  // namespace mlir