1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <cctype>

#include "llvm/ADT/DenseMap.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23
24 namespace mlir {
25 namespace mhlo {
26 namespace {
27
Make1DElementsAttr(OpBuilder & b,ArrayRef<int64_t> integers)28 DenseIntElementsAttr Make1DElementsAttr(OpBuilder &b,
29 ArrayRef<int64_t> integers) {
30 auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
31 b.getI64Type());
32 return DenseIntElementsAttr::get(type, integers);
33 }
34
35 struct EinsumToDotGeneralPattern : public OpRewritePattern<EinsumOp> {
36 using OpRewritePattern<EinsumOp>::OpRewritePattern;
37
matchAndRewritemlir::mhlo::__anon938bf37e0111::EinsumToDotGeneralPattern38 LogicalResult matchAndRewrite(EinsumOp einsum,
39 PatternRewriter &rewriter) const override {
40 StringRef equation = einsum.einsum_config();
41 SmallVector<char> lhs_tokens, rhs_tokens;
42 llvm::SmallDenseSet<char> result_tokens;
43 size_t index = 0;
44 enum EquationVariable { kIsLhs, kIsRhs, kIsResult };
45 EquationVariable current_variable = kIsLhs;
46 while (index < equation.size()) {
47 if (std::isalpha(equation[index])) {
48 if (current_variable == kIsLhs) {
49 lhs_tokens.push_back(equation[index]);
50 } else if (current_variable == kIsRhs) {
51 rhs_tokens.push_back(equation[index]);
52 } else {
53 result_tokens.insert(equation[index]);
54 }
55 } else if (equation.substr(index, 1).contains(",")) {
56 current_variable = kIsRhs;
57 } else if ((index < (equation.size() - 1)) &&
58 (equation.substr(index, 2).contains("->"))) {
59 current_variable = kIsResult;
60 index++;
61 } else {
62 return einsum.emitError("unexpected character ")
63 << equation.substr(index, 1) << " encountered";
64 }
65 index++;
66 }
67 assert(lhs_tokens.size() ==
68 einsum.lhs().getType().cast<RankedTensorType>().getRank());
69 assert(rhs_tokens.size() ==
70 einsum.rhs().getType().cast<RankedTensorType>().getRank());
71
72 auto collect_contracting_batching_dims =
73 [&](SmallVector<char> tokens, SmallVector<char> others,
74 SmallVectorImpl<int64_t> &contracting_dims,
75 SmallVectorImpl<int64_t> &batching_dims) {
76 llvm::SmallDenseSet<char> others_set(others.begin(), others.end());
77 for (auto en : llvm::enumerate(tokens)) {
78 if (!result_tokens.contains(en.value())) {
79 contracting_dims.emplace_back(en.index());
80 }
81 if (others_set.contains(en.value()) &&
82 result_tokens.contains(en.value())) {
83 batching_dims.emplace_back(en.index());
84 }
85 }
86 };
87 SmallVector<int64_t> lhs_contracting_dims, lhs_batching_dims,
88 rhs_contracting_dims, rhs_batching_dims;
89 collect_contracting_batching_dims(lhs_tokens, rhs_tokens,
90 lhs_contracting_dims, lhs_batching_dims);
91 collect_contracting_batching_dims(rhs_tokens, lhs_tokens,
92 rhs_contracting_dims, rhs_batching_dims);
93
94 auto dim_numbers = mhlo::DotDimensionNumbers::get(
95 Make1DElementsAttr(rewriter, lhs_batching_dims),
96 Make1DElementsAttr(rewriter, rhs_batching_dims),
97 Make1DElementsAttr(rewriter, lhs_contracting_dims),
98 Make1DElementsAttr(rewriter, rhs_contracting_dims),
99 rewriter.getContext());
100 rewriter.replaceOpWithNewOp<DotGeneralOp>(
101 einsum, einsum.getType(), einsum.lhs(), einsum.rhs(), dim_numbers,
102 /*precision_config=*/ArrayAttr{});
103 return success();
104 }
105 };
106
107 struct LegalizeEinsumToDotGeneralPass
108 : public LegalizeEinsumToDotGeneralPassBase<
109 LegalizeEinsumToDotGeneralPass> {
runOnFunctionmlir::mhlo::__anon938bf37e0111::LegalizeEinsumToDotGeneralPass110 void runOnFunction() override {
111 OwningRewritePatternList patterns(&getContext());
112 PopulateEinsumToDotGeneralPatterns(&getContext(), &patterns);
113 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
114 }
115 };
116 } // namespace
117
PopulateEinsumToDotGeneralPatterns(mlir::MLIRContext * context,OwningRewritePatternList * patterns)118 void PopulateEinsumToDotGeneralPatterns(mlir::MLIRContext *context,
119 OwningRewritePatternList *patterns) {
120 patterns->insert<EinsumToDotGeneralPattern>(context);
121 }
122
createLegalizeEinsumToDotGeneralPass()123 std::unique_ptr<FunctionPass> createLegalizeEinsumToDotGeneralPass() {
124 return std::make_unique<LegalizeEinsumToDotGeneralPass>();
125 }
126
127 } // namespace mhlo
128 } // namespace mlir
129