• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
17 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
18 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
19 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
20 #include "mlir/IR/BuiltinOps.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 
24 namespace mlir {
25 namespace mhlo {
26 namespace {
27 
Make1DElementsAttr(OpBuilder & b,ArrayRef<int64_t> integers)28 DenseIntElementsAttr Make1DElementsAttr(OpBuilder &b,
29                                         ArrayRef<int64_t> integers) {
30   auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
31                                     b.getI64Type());
32   return DenseIntElementsAttr::get(type, integers);
33 }
34 
35 struct EinsumToDotGeneralPattern : public OpRewritePattern<EinsumOp> {
36   using OpRewritePattern<EinsumOp>::OpRewritePattern;
37 
matchAndRewritemlir::mhlo::__anon938bf37e0111::EinsumToDotGeneralPattern38   LogicalResult matchAndRewrite(EinsumOp einsum,
39                                 PatternRewriter &rewriter) const override {
40     StringRef equation = einsum.einsum_config();
41     SmallVector<char> lhs_tokens, rhs_tokens;
42     llvm::SmallDenseSet<char> result_tokens;
43     size_t index = 0;
44     enum EquationVariable { kIsLhs, kIsRhs, kIsResult };
45     EquationVariable current_variable = kIsLhs;
46     while (index < equation.size()) {
47       if (std::isalpha(equation[index])) {
48         if (current_variable == kIsLhs) {
49           lhs_tokens.push_back(equation[index]);
50         } else if (current_variable == kIsRhs) {
51           rhs_tokens.push_back(equation[index]);
52         } else {
53           result_tokens.insert(equation[index]);
54         }
55       } else if (equation.substr(index, 1).contains(",")) {
56         current_variable = kIsRhs;
57       } else if ((index < (equation.size() - 1)) &&
58                  (equation.substr(index, 2).contains("->"))) {
59         current_variable = kIsResult;
60         index++;
61       } else {
62         return einsum.emitError("unexpected character ")
63                << equation.substr(index, 1) << " encountered";
64       }
65       index++;
66     }
67     assert(lhs_tokens.size() ==
68            einsum.lhs().getType().cast<RankedTensorType>().getRank());
69     assert(rhs_tokens.size() ==
70            einsum.rhs().getType().cast<RankedTensorType>().getRank());
71 
72     auto collect_contracting_batching_dims =
73         [&](SmallVector<char> tokens, SmallVector<char> others,
74             SmallVectorImpl<int64_t> &contracting_dims,
75             SmallVectorImpl<int64_t> &batching_dims) {
76           llvm::SmallDenseSet<char> others_set(others.begin(), others.end());
77           for (auto en : llvm::enumerate(tokens)) {
78             if (!result_tokens.contains(en.value())) {
79               contracting_dims.emplace_back(en.index());
80             }
81             if (others_set.contains(en.value()) &&
82                 result_tokens.contains(en.value())) {
83               batching_dims.emplace_back(en.index());
84             }
85           }
86         };
87     SmallVector<int64_t> lhs_contracting_dims, lhs_batching_dims,
88         rhs_contracting_dims, rhs_batching_dims;
89     collect_contracting_batching_dims(lhs_tokens, rhs_tokens,
90                                       lhs_contracting_dims, lhs_batching_dims);
91     collect_contracting_batching_dims(rhs_tokens, lhs_tokens,
92                                       rhs_contracting_dims, rhs_batching_dims);
93 
94     auto dim_numbers = mhlo::DotDimensionNumbers::get(
95         Make1DElementsAttr(rewriter, lhs_batching_dims),
96         Make1DElementsAttr(rewriter, rhs_batching_dims),
97         Make1DElementsAttr(rewriter, lhs_contracting_dims),
98         Make1DElementsAttr(rewriter, rhs_contracting_dims),
99         rewriter.getContext());
100     rewriter.replaceOpWithNewOp<DotGeneralOp>(
101         einsum, einsum.getType(), einsum.lhs(), einsum.rhs(), dim_numbers,
102         /*precision_config=*/ArrayAttr{});
103     return success();
104   }
105 };
106 
107 struct LegalizeEinsumToDotGeneralPass
108     : public LegalizeEinsumToDotGeneralPassBase<
109           LegalizeEinsumToDotGeneralPass> {
runOnFunctionmlir::mhlo::__anon938bf37e0111::LegalizeEinsumToDotGeneralPass110   void runOnFunction() override {
111     OwningRewritePatternList patterns(&getContext());
112     PopulateEinsumToDotGeneralPatterns(&getContext(), &patterns);
113     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
114   }
115 };
116 }  // namespace
117 
PopulateEinsumToDotGeneralPatterns(mlir::MLIRContext * context,OwningRewritePatternList * patterns)118 void PopulateEinsumToDotGeneralPatterns(mlir::MLIRContext *context,
119                                         OwningRewritePatternList *patterns) {
120   patterns->insert<EinsumToDotGeneralPattern>(context);
121 }
122 
createLegalizeEinsumToDotGeneralPass()123 std::unique_ptr<FunctionPass> createLegalizeEinsumToDotGeneralPass() {
124   return std::make_unique<LegalizeEinsumToDotGeneralPass>();
125 }
126 
127 }  // namespace mhlo
128 }  // namespace mlir
129