• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for translating mixed IR to buffer form.
17 
18 #include "mlir/Transforms/Bufferize.h"  // from @llvm-project
19 
20 #include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
21 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
22 #include "mlir/IR/Attributes.h"  // from @llvm-project
23 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
25 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
26 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
27 
28 namespace mlir {
29 namespace kernel_gen {
30 namespace transforms {
31 namespace {
32 
33 class BufferizeConstantOp : public OpConversionPattern<ConstantOp> {
34  public:
35   using OpConversionPattern<ConstantOp>::OpConversionPattern;
36 
matchAndRewrite(ConstantOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const37   LogicalResult matchAndRewrite(
38       ConstantOp op, ArrayRef<Value> operands,
39       ConversionPatternRewriter &rewriter) const final {
40     // We only need to bufferize tensor constants.
41     Location loc = op.getLoc();
42     auto result_type = op.getType().dyn_cast<RankedTensorType>();
43     int64_t result_rank = result_type.getRank();
44     if (!result_type || !result_type.hasStaticShape() || result_rank > 1)
45       return failure();
46 
47     auto memref_type =
48         MemRefType::get(result_type.getShape(), result_type.getElementType());
49     auto elements_attr = op.value().cast<DenseElementsAttr>();
50 
51     if (result_rank == 0) {
52       Value buffer = rewriter.create<AllocOp>(loc, memref_type);
53       Value constant =
54           rewriter.create<ConstantOp>(loc, elements_attr.getValue({}));
55       rewriter.create<StoreOp>(loc, constant, buffer);
56       rewriter.replaceOp(op, {buffer});
57       return success();
58     }
59 
60     Value buffer = rewriter.create<AllocaOp>(loc, memref_type);
61 
62     bool all_same_elems = elements_attr.isSplat();
63     Value value;
64     if (all_same_elems)
65       value = rewriter.create<ConstantOp>(loc, elements_attr.getSplatValue());
66     for (auto en : llvm::enumerate(elements_attr.getAttributeValues())) {
67       if (!all_same_elems) value = rewriter.create<ConstantOp>(loc, en.value());
68       Value index = rewriter.create<ConstantIndexOp>(loc, en.index());
69       rewriter.create<StoreOp>(loc, value, buffer, index);
70     }
71     rewriter.replaceOp(op, {buffer});
72     return success();
73   }
74 };
75 
76 class BufferizeDimOp : public OpConversionPattern<DimOp> {
77  public:
78   using OpConversionPattern::OpConversionPattern;
matchAndRewrite(DimOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const79   LogicalResult matchAndRewrite(
80       DimOp op, ArrayRef<Value> operands,
81       ConversionPatternRewriter &rewriter) const override {
82     DimOp::Adaptor adaptor(operands);
83     rewriter.replaceOpWithNewOp<DimOp>(op, adaptor.memrefOrTensor(),
84                                        adaptor.index());
85     return success();
86   }
87 };
88 
89 class BufferizeRankOp : public OpConversionPattern<RankOp> {
90  public:
91   using OpConversionPattern::OpConversionPattern;
matchAndRewrite(RankOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const92   LogicalResult matchAndRewrite(
93       RankOp op, ArrayRef<Value> operands,
94       ConversionPatternRewriter &rewriter) const override {
95     RankOp::Adaptor adaptor(operands);
96     rewriter.replaceOpWithNewOp<RankOp>(op, adaptor.memrefOrTensor());
97     return success();
98   }
99 };
100 }  // namespace
101 
populateExtraStdBufferizePattern(MLIRContext * context,BufferizeTypeConverter * converter,OwningRewritePatternList * patterns)102 void populateExtraStdBufferizePattern(MLIRContext *context,
103                                       BufferizeTypeConverter *converter,
104                                       OwningRewritePatternList *patterns) {
105   patterns->insert<BufferizeConstantOp, BufferizeDimOp, BufferizeRankOp>(
106       *converter, context);
107 }
108 
109 }  // namespace transforms
110 }  // namespace kernel_gen
111 }  // namespace mlir
112