1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for lowering HLO dialect to LHLO dialect.
17
#include <memory>
#include <utility>
19
20 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
21 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
22 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
23 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
24 #include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h"
25 #include "mlir/Dialect/MemRef/IR/MemRef.h"
26 #include "mlir/Dialect/StandardOps/IR/Ops.h"
27 #include "mlir/Dialect/Tensor/IR/Tensor.h"
28 #include "mlir/IR/BuiltinDialect.h"
29 #include "mlir/IR/BuiltinOps.h"
30 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/Pass/Pass.h"
32 #include "mlir/Transforms/DialectConversion.h"
33
34 namespace mlir {
35 namespace mhlo {
36 namespace {
37
38 template <typename T>
39 class SignlessOpConversion : public OpConversionPattern<T> {
40 public:
SignlessOpConversion(TypeConverter & type_converter,RemoveSignTypeConverter * remove_sign_converter,MLIRContext * ctx)41 SignlessOpConversion(TypeConverter& type_converter,
42 RemoveSignTypeConverter* remove_sign_converter,
43 MLIRContext* ctx)
44 : OpConversionPattern<T>(type_converter, ctx),
45 remove_sign_converter_(remove_sign_converter) {}
46
matchAndRewrite(T op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const47 LogicalResult matchAndRewrite(
48 T op, ArrayRef<Value> operands,
49 ConversionPatternRewriter& rewriter) const final {
50 auto loc = op.getLoc();
51 // Sign-convert operands and result type.
52 SmallVector<Value> converted_operands;
53 for (auto operand : operands) {
54 Type original = operand.getType();
55 Type converted = remove_sign_converter_->convertType(original);
56 if (converted == original) {
57 converted_operands.push_back(operand);
58 } else {
59 converted_operands.push_back(
60 rewriter
61 .create<UnrealizedConversionCastOp>(loc, converted, operand)
62 ->getResult(0));
63 }
64 }
65 Type op_result_type = remove_sign_converter_->convertType(op.getType());
66 // Perform actual rewrite.
67 Value result =
68 signlessRewrite(op, converted_operands, op_result_type, rewriter);
69 if (!result) return failure();
70
71 // If the element type of the original op and the returned value differ,
72 // do a conversion cast to fix it up.
73 auto expected_element_type =
74 op.getType().template cast<ShapedType>().getElementType();
75 auto result_type = result.getType().cast<BaseMemRefType>();
76 auto actual_element_type = result_type.getElementType();
77 if (expected_element_type != actual_element_type) {
78 assert(remove_sign_converter_->convertType(expected_element_type) ==
79 actual_element_type);
80 Type new_type;
81 if (auto ranked = result_type.dyn_cast<MemRefType>()) {
82 new_type =
83 MemRefType::get(ranked.getShape(), expected_element_type,
84 ranked.getAffineMaps(), ranked.getMemorySpace());
85 } else {
86 new_type = UnrankedMemRefType::get(expected_element_type,
87 result_type.getMemorySpace());
88 }
89 result =
90 rewriter.create<UnrealizedConversionCastOp>(loc, new_type, result)
91 .getResult(0);
92 }
93 rewriter.replaceOp(op, result);
94 return success();
95 }
96
97 protected:
98 virtual Value signlessRewrite(T op, ArrayRef<Value> operands,
99 Type result_type,
100 ConversionPatternRewriter& rewriter) const = 0;
101
102 private:
103 RemoveSignTypeConverter* remove_sign_converter_;
104 };
105
106 template <typename T>
107 using BaseOpConversion = SignlessOpConversion<T>;
108
109 class HloToMemrefReshapeUnrankedConverter
110 : public BaseOpConversion<mhlo::ReshapeOp> {
111 public:
112 using BaseOpConversion<mhlo::ReshapeOp>::BaseOpConversion;
113
signlessRewrite(mhlo::ReshapeOp op,ArrayRef<Value> operands,Type op_result_type,ConversionPatternRewriter & rewriter) const114 Value signlessRewrite(mhlo::ReshapeOp op, ArrayRef<Value> operands,
115 Type op_result_type,
116 ConversionPatternRewriter& rewriter) const final {
117 mhlo::ReshapeOp::Adaptor adaptor(operands);
118 auto unranked_operand_type =
119 adaptor.operand().getType().dyn_cast<UnrankedMemRefType>();
120 if (unranked_operand_type == nullptr) return {};
121 auto loc = op->getLoc();
122 auto result_type = op_result_type.cast<RankedTensorType>();
123 auto cast = rewriter.create<memref::CastOp>(
124 loc, adaptor.operand(),
125 MemRefType::get(result_type.getShape(), result_type.getElementType()));
126
127 return cast;
128 }
129 };
130
131 class HloToMemrefDynamicReshapeConverter
132 : public BaseOpConversion<mhlo::DynamicReshapeOp> {
133 public:
134 using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
135
signlessRewrite(mhlo::DynamicReshapeOp op,ArrayRef<Value> operands,Type op_result_type,ConversionPatternRewriter & rewriter) const136 Value signlessRewrite(mhlo::DynamicReshapeOp op, ArrayRef<Value> operands,
137 Type op_result_type,
138 ConversionPatternRewriter& rewriter) const final {
139 ShapedType result_type;
140 if (auto ranked_type = op_result_type.dyn_cast<RankedTensorType>()) {
141 result_type =
142 MemRefType::get(ranked_type.getShape(), ranked_type.getElementType());
143 } else if (auto unranked_type =
144 op_result_type.dyn_cast<UnrankedTensorType>()) {
145 result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0);
146 } else {
147 return {};
148 }
149 mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
150 auto reshape = rewriter.create<memref::ReshapeOp>(
151 op.getLoc(), result_type, adaptor.operand(), adaptor.output_shape());
152 return reshape;
153 }
154 };
155
156 // TODO(b/175670649) Fix this to no longer access original tensor operands.
157 class HloToMemrefDynamicBroadcastInDimOpConverter
158 : public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
159 public:
HloToMemrefDynamicBroadcastInDimOpConverter(TypeConverter & converter,RemoveSignTypeConverter * sign_converter,MLIRContext * ctx,bool enforce_identity_maps)160 HloToMemrefDynamicBroadcastInDimOpConverter(
161 TypeConverter& converter, RemoveSignTypeConverter* sign_converter,
162 MLIRContext* ctx, bool enforce_identity_maps)
163 : BaseOpConversion<mhlo::DynamicBroadcastInDimOp>(converter,
164 sign_converter, ctx),
165 enforce_identity_maps_(enforce_identity_maps) {}
166
signlessRewrite(mhlo::DynamicBroadcastInDimOp op,ArrayRef<Value> operands,Type op_result_type,ConversionPatternRewriter & rewriter) const167 Value signlessRewrite(mhlo::DynamicBroadcastInDimOp op,
168 ArrayRef<Value> operands, Type op_result_type,
169 ConversionPatternRewriter& rewriter) const final {
170 auto result_type = op_result_type.dyn_cast<RankedTensorType>();
171 if (!result_type) return {};
172 Value result = InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
173
174 if (enforce_identity_maps_) {
175 result = CreateCopy(op, result, &rewriter);
176 }
177
178 return result;
179 }
180
181 private:
182 // Inserts dynamic memref to change the layout of the memref to put 0-stride
183 // and size of the target dimension if size-1 dimension expansion is
184 // necessary.
InsertDynamicMemrefCastOp(mhlo::DynamicBroadcastInDimOp op,Value operand,OpBuilder * b) const185 memref::ReinterpretCastOp InsertDynamicMemrefCastOp(
186 mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
187 auto loc = op.getLoc();
188 auto operand_type = operand.getType().cast<MemRefType>();
189 auto operand_shape = operand_type.getShape();
190 auto operand_rank = operand_type.getRank();
191
192 auto result_type = op.getType().cast<RankedTensorType>();
193 auto result_rank = result_type.getRank();
194
195 Value zero = b->create<ConstantIndexOp>(loc, 0);
196 Value one = b->create<ConstantIndexOp>(loc, 1);
197
198 // Compute a reversed scan product. Compute the stride for the dimensions so
199 // far, working from minor to major dimensions. Additionally, save the
200 // operand shape Values to use in the next loop.
201 SmallVector<Value, 2> operand_strides(operand_rank, one);
202 SmallVector<Value, 2> operand_sizes(operand_rank, one);
203 Value stride_so_far = one;
204 for (int i = operand_rank - 1; i >= 0; --i) {
205 Value operand_dim_size =
206 ShapedType::isDynamic(operand_shape[i])
207 ? b->create<memref::DimOp>(loc, operand, i).getResult()
208 : b->create<ConstantIndexOp>(loc, operand_shape[i]).getResult();
209 operand_sizes[i] = operand_dim_size;
210
211 operand_strides[i] = stride_so_far;
212 if (i > 0) {
213 stride_so_far = b->create<MulIOp>(loc, stride_so_far, operand_dim_size);
214 }
215 }
216
217 SmallVector<OpFoldResult, 2> sizes, strides;
218 sizes.reserve(result_rank);
219 strides.reserve(result_rank);
220
221 DenseMap<int, int> output_to_input_dim;
222 for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
223 output_to_input_dim[dim.value().getSExtValue()] = dim.index();
224 }
225 for (int i = 0; i < result_rank; ++i) {
226 Value i_val = b->create<ConstantIndexOp>(loc, i);
227 Value result_dim_size =
228 b->create<tensor::ExtractOp>(loc, op.output_dimensions(), i_val);
229 if (!result_dim_size.getType().isIndex()) {
230 result_dim_size =
231 b->create<IndexCastOp>(loc, result_dim_size, b->getIndexType());
232 }
233 sizes.push_back(result_dim_size);
234
235 auto it = output_to_input_dim.find(i);
236 // If the rank of the output is greater than the rank of the input, i.e.
237 // there was no output dimension in the inverse broadcast_dimensions map
238 // we also set stride to 0 to emulate padding of the shape with 1s and the
239 // corresponding expansion.
240 if (it == output_to_input_dim.end()) {
241 strides.push_back(zero);
242 continue;
243 }
244
245 // There can be two cases:
246 // 1) Operand dim == result dim => expansion is not needed
247 // => stride flattened buffer stride
248 // 2) Operand dim < result dim => expansion is needed => stride := 0.
249 int dim = it->second;
250 Value is_expansion = b->create<CmpIOp>(
251 loc, CmpIPredicate::slt, operand_sizes[dim], result_dim_size);
252 Value select = b->create<mlir::SelectOp>(loc, is_expansion, zero,
253 operand_strides[dim]);
254 strides.push_back(select);
255 }
256
257 // Type-erased memref type with static rank, dynamic sizes and strides.
258 SmallVector<int64_t, 2> dynamic_layout(result_rank,
259 MemRefType::kDynamicStrideOrOffset);
260 SmallVector<int64_t, 2> dynamic_shape(result_rank,
261 MemRefType::kDynamicSize);
262 auto type_erased_memref_type = MemRefType::get(
263 dynamic_shape, operand_type.getElementType(),
264 makeStridedLinearLayoutMap(dynamic_layout,
265 /*offset=*/0, b->getContext()));
266
267 auto transformed_operand = b->create<memref::ReinterpretCastOp>(
268 loc, type_erased_memref_type, operand,
269 /*offset=*/b->getI64IntegerAttr(0), sizes, strides);
270 return transformed_operand;
271 }
272
CreateCopy(mhlo::DynamicBroadcastInDimOp op,Value broadcasted,OpBuilder * b) const273 Value CreateCopy(mhlo::DynamicBroadcastInDimOp op, Value broadcasted,
274 OpBuilder* b) const {
275 MemRefType result_type = broadcasted.getType().cast<MemRefType>();
276 auto loc = op.getLoc();
277 SmallVector<Value, 4> dynamic_operands;
278 for (int i = 0; i < result_type.getRank(); ++i) {
279 auto index = b->createOrFold<ConstantIndexOp>(loc, i);
280 Value size =
281 b->create<tensor::ExtractOp>(loc, op.output_dimensions(), index);
282 if (!size.getType().isIndex()) {
283 size = b->create<IndexCastOp>(loc, size, b->getIndexType());
284 }
285 dynamic_operands.push_back(size);
286 }
287 auto identity_map_memref =
288 MemRefType::get(result_type.getShape(), result_type.getElementType());
289 auto copy = b->create<memref::AllocOp>(op.getLoc(), identity_map_memref,
290 dynamic_operands);
291 b->create<memref::CopyOp>(loc, broadcasted, copy);
292
293 return copy;
294 }
295
296 bool enforce_identity_maps_;
297 };
298
299 struct HloLegalizeToMemrefPass
300 : public HloLegalizeToMemrefPassBase<HloLegalizeToMemrefPass> {
getDependentDialectsmlir::mhlo::__anon200410680111::HloLegalizeToMemrefPass301 void getDependentDialects(DialectRegistry& registry) const override {
302 registry.insert<memref::MemRefDialect, tensor::TensorDialect>();
303 }
304
305 public:
runOnFunctionmlir::mhlo::__anon200410680111::HloLegalizeToMemrefPass306 void runOnFunction() override {
307 auto& context = getContext();
308 OwningRewritePatternList patterns(&context);
309 ConversionTarget target(context);
310
311 BufferizeTypeConverter converter;
312 RemoveSignTypeConverter sign_converter;
313
314 populateHLOToMemrefConversionPattern(&converter, &sign_converter,
315 &patterns);
316
317 target.addIllegalOp<DynamicReshapeOp, DynamicBroadcastInDimOp>();
318 target.addLegalDialect<BuiltinDialect, memref::MemRefDialect,
319 StandardOpsDialect, tensor::TensorDialect>();
320
321 auto func = getFunction();
322 if (failed(applyPartialConversion(func, target, std::move(patterns))))
323 signalPassFailure();
324 }
325 };
326
327 } // namespace
328
populateHLOToMemrefConversionPattern(BufferizeTypeConverter * converter,RemoveSignTypeConverter * sign_converter,OwningRewritePatternList * patterns,bool enforce_identity_maps)329 void populateHLOToMemrefConversionPattern(
330 BufferizeTypeConverter* converter, RemoveSignTypeConverter* sign_converter,
331 OwningRewritePatternList* patterns, bool enforce_identity_maps) {
332 MLIRContext* context = patterns->getContext();
333 patterns->insert<HloToMemrefDynamicBroadcastInDimOpConverter>(
334 *converter, sign_converter, context, enforce_identity_maps);
335 patterns->insert<HloToMemrefDynamicReshapeConverter,
336 HloToMemrefReshapeUnrankedConverter>(
337 *converter, sign_converter, context);
338 }
339
createLegalizeToMemrefPass()340 std::unique_ptr<FunctionPass> createLegalizeToMemrefPass() {
341 return std::make_unique<HloLegalizeToMemrefPass>();
342 }
343
344 } // namespace mhlo
345 } // namespace mlir
346