1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
17 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
18 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
19 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
20 #include "mlir/IR/BuiltinOps.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23
24 namespace mlir {
25
26 namespace mhlo {
27 namespace {
28
29 struct GatherIsTorchIndexSelect : public OpRewritePattern<GatherOp> {
30 using OpRewritePattern<GatherOp>::OpRewritePattern;
31
matchAndRewritemlir::mhlo::__anon0d0b789a0111::GatherIsTorchIndexSelect32 LogicalResult matchAndRewrite(GatherOp gather,
33 PatternRewriter &rewriter) const override {
34 auto start_indices = gather.start_indices();
35 auto start_indices_ty = start_indices.getType().cast<ShapedType>();
36 if (!start_indices_ty.hasRank()) {
37 return failure();
38 }
39
40 auto operand = gather.operand();
41 auto operand_ty = operand.getType().cast<ShapedType>();
42 if (!operand_ty.hasRank()) {
43 return failure();
44 }
45
46 int64_t index_vector_dim =
47 std::max<int64_t>(0, start_indices_ty.getRank() - 1);
48
49 // We can use torch_index_select if the last dimension represents the
50 // gather indices.
51 auto dimension_numbers = gather.dimension_numbers();
52 if (dimension_numbers.index_vector_dim().getValue().getSExtValue() !=
53 index_vector_dim) {
54 return failure();
55 }
56
57 // Index select only works across a single dimension.
58 if (!start_indices_ty.getShape().empty() &&
59 start_indices_ty.getShape().back() != 1) {
60 return failure();
61 }
62
63 // Only support the default case for start_index_map.
64 if (dimension_numbers.start_index_map().getType().getRank() != 1 ||
65 dimension_numbers.start_index_map()
66 .getValue(0)
67 .cast<IntegerAttr>()
68 .getValue() != 0) {
69 return failure();
70 }
71
72 auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();
73 if (!result_ty) {
74 return failure();
75 }
76
77 // Offset dimensions should be the defaults.
78 if (dimension_numbers.offset_dims().getType().getNumElements() !=
79 result_ty.getRank() - index_vector_dim) {
80 return failure();
81 }
82
83 for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) {
84 if ((it.index() + index_vector_dim) != it.value()) {
85 return failure();
86 }
87 }
88
89 for (auto it : llvm::enumerate(gather.slice_sizes().getIntValues())) {
90 // First shape value must be 1.
91 if (it.index() == 0) {
92 if (it.value().getSExtValue() != 1) {
93 return failure();
94 }
95 continue;
96 }
97
98 // The gather needs to index the entire slice for each other dimension.
99 if (it.value().getSExtValue() != operand_ty.getDimSize(it.index())) {
100 return failure();
101 }
102 }
103
104 llvm::SmallVector<int64_t, 4> index_select_shape =
105 llvm::to_vector<4>(start_indices_ty.getShape());
106
107 for (auto dim : operand_ty.getShape().drop_front()) {
108 index_select_shape.push_back(dim);
109 }
110
111 if (!dimension_numbers.collapsed_slice_dims().getType().hasRank() ||
112 dimension_numbers.collapsed_slice_dims().getType().getNumElements() !=
113 1 ||
114 dimension_numbers.collapsed_slice_dims().getValue<int64_t>({0}) != 0) {
115 return failure();
116 }
117
118 auto torch_index_select = rewriter.create<TorchIndexSelectOp>(
119 gather.getLoc(),
120 RankedTensorType::get(index_select_shape, operand_ty.getElementType()),
121 operand, gather.start_indices(), rewriter.getI64IntegerAttr(0),
122 rewriter.getI64IntegerAttr(0));
123
124 rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(),
125 torch_index_select);
126
127 return success();
128 }
129 };
130
131 struct LegalizeGatherToTorchIndexSelectPass
132 : public LegalizeGatherToTorchIndexSelectPassBase<
133 LegalizeGatherToTorchIndexSelectPass> {
134 /// Perform the lowering of standard dialect operations to approximations.
runOnFunctionmlir::mhlo::__anon0d0b789a0111::LegalizeGatherToTorchIndexSelectPass135 void runOnFunction() override {
136 OwningRewritePatternList patterns(&getContext());
137 PopulateGatherToTorchIndexSelectPatterns(&getContext(), &patterns);
138 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
139 }
140 };
141 } // namespace
142
PopulateGatherToTorchIndexSelectPatterns(mlir::MLIRContext * context,OwningRewritePatternList * patterns)143 void PopulateGatherToTorchIndexSelectPatterns(
144 mlir::MLIRContext *context, OwningRewritePatternList *patterns) {
145 patterns->insert<GatherIsTorchIndexSelect>(context);
146 }
147
createLegalizeGatherToTorchIndexSelectPass()148 std::unique_ptr<FunctionPass> createLegalizeGatherToTorchIndexSelectPass() {
149 return std::make_unique<LegalizeGatherToTorchIndexSelectPass>();
150 }
151
152 } // namespace mhlo
153 } // namespace mlir
154