/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {

namespace mhlo {
namespace {

struct GatherIsTorchIndexSelect : public OpRewritePattern<GatherOp> {
  using OpRewritePattern<GatherOp>::OpRewritePattern;

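  // Worked example (a sketch of one accepted case, not from the original
  // source): an operand of shape [5, 4] gathered with start_indices of shape
  // [3, 1], index_vector_dim = 1, start_index_map = [0],
  // collapsed_slice_dims = [0], offset_dims = [1], and slice_sizes = [1, 4]
  // produces a [3, 4] result. It is rewritten to
  // torch_index_select(operand, start_indices, dim = 0, batch_dims = 0),
  // which has shape [3, 1, 4], followed by a reshape back to the gather's
  // [3, 4] result type.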
  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    auto start_indices = gather.start_indices();
    auto start_indices_ty = start_indices.getType().cast<ShapedType>();
    if (!start_indices_ty.hasRank()) {
      return failure();
    }

    auto operand = gather.operand();
    auto operand_ty = operand.getType().cast<ShapedType>();
    if (!operand_ty.hasRank()) {
      return failure();
    }

    int64_t index_vector_dim =
        std::max<int64_t>(0, start_indices_ty.getRank() - 1);

    // We can use torch_index_select if the last dimension represents the
    // gather indices.
    auto dimension_numbers = gather.dimension_numbers();
    if (dimension_numbers.index_vector_dim().getValue().getSExtValue() !=
        index_vector_dim) {
      return failure();
    }

    // Index select only works across a single dimension.
    if (!start_indices_ty.getShape().empty() &&
        start_indices_ty.getShape().back() != 1) {
      return failure();
    }

    // Only support the default case for start_index_map.
    if (dimension_numbers.start_index_map().getType().getRank() != 1 ||
        dimension_numbers.start_index_map()
                .getValue(0)
                .cast<IntegerAttr>()
                .getValue() != 0) {
      return failure();
    }

    auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();
    if (!result_ty) {
      return failure();
    }

    // Offset dimensions should be the defaults.
    if (dimension_numbers.offset_dims().getType().getNumElements() !=
        result_ty.getRank() - index_vector_dim) {
      return failure();
    }

    for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) {
      if ((it.index() + index_vector_dim) != it.value()) {
        return failure();
      }
    }

    for (auto it : llvm::enumerate(gather.slice_sizes().getIntValues())) {
      // First shape value must be 1.
      if (it.index() == 0) {
        if (it.value().getSExtValue() != 1) {
          return failure();
        }
        continue;
      }

      // The gather needs to index the entire slice for each other dimension.
      if (it.value().getSExtValue() != operand_ty.getDimSize(it.index())) {
        return failure();
      }
    }

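    // The torch_index_select result shape is the start_indices shape followed
    // by every operand dimension except the (collapsed) first one.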
    llvm::SmallVector<int64_t, 4> index_select_shape =
        llvm::to_vector<4>(start_indices_ty.getShape());

    for (auto dim : operand_ty.getShape().drop_front()) {
      index_select_shape.push_back(dim);
    }

    if (!dimension_numbers.collapsed_slice_dims().getType().hasRank() ||
        dimension_numbers.collapsed_slice_dims().getType().getNumElements() !=
            1 ||
        dimension_numbers.collapsed_slice_dims().getValue<int64_t>({0}) != 0) {
      return failure();
    }

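    // Select along dimension 0 with no batch dimensions, then reshape to the
    // gather's original result type; the checks above guarantee the two
    // shapes hold the same number of elements.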
    auto torch_index_select = rewriter.create<TorchIndexSelectOp>(
        gather.getLoc(),
        RankedTensorType::get(index_select_shape, operand_ty.getElementType()),
        operand, gather.start_indices(), rewriter.getI64IntegerAttr(0),
        rewriter.getI64IntegerAttr(0));

    rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(),
                                           torch_index_select);

    return success();
  }
};

struct LegalizeGatherToTorchIndexSelectPass
    : public LegalizeGatherToTorchIndexSelectPassBase<
          LegalizeGatherToTorchIndexSelectPass> {
  /// Rewrite mhlo.gather ops as mhlo.torch_index_select followed by a reshape
  /// wherever the pattern above applies.
  void runOnFunction() override {
    OwningRewritePatternList patterns(&getContext());
    PopulateGatherToTorchIndexSelectPatterns(&getContext(), &patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};
}  // namespace

void PopulateGatherToTorchIndexSelectPatterns(
    mlir::MLIRContext *context, OwningRewritePatternList *patterns) {
  patterns->insert<GatherIsTorchIndexSelect>(context);
}

std::unique_ptr<FunctionPass> createLegalizeGatherToTorchIndexSelectPass() {
  return std::make_unique<LegalizeGatherToTorchIndexSelectPass>();
}
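
// A minimal usage sketch (not part of the original file): running this pass
// over a module from C++, assuming a pass-manager API matching this MLIR
// revision and pre-existing `context` and `module` values:
//   PassManager pm(&context);
//   pm.addNestedPass<FuncOp>(
//       mhlo::createLegalizeGatherToTorchIndexSelectPass());
//   if (failed(pm.run(module))) { /* handle the error */ }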

}  // namespace mhlo
}  // namespace mlir