1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This file implements conversion patterns that translate tensor-typed
// standard-dialect ops (constant, dim, rank) to their buffer (memref) form.
17
18 #include "mlir/Transforms/Bufferize.h" // from @llvm-project
19
20 #include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
21 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
22 #include "mlir/IR/Attributes.h" // from @llvm-project
23 #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
24 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
25 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
26 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
27
28 namespace mlir {
29 namespace kernel_gen {
30 namespace transforms {
31 namespace {
32
33 class BufferizeConstantOp : public OpConversionPattern<ConstantOp> {
34 public:
35 using OpConversionPattern<ConstantOp>::OpConversionPattern;
36
matchAndRewrite(ConstantOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const37 LogicalResult matchAndRewrite(
38 ConstantOp op, ArrayRef<Value> operands,
39 ConversionPatternRewriter &rewriter) const final {
40 // We only need to bufferize tensor constants.
41 Location loc = op.getLoc();
42 auto result_type = op.getType().dyn_cast<RankedTensorType>();
43 int64_t result_rank = result_type.getRank();
44 if (!result_type || !result_type.hasStaticShape() || result_rank > 1)
45 return failure();
46
47 auto memref_type =
48 MemRefType::get(result_type.getShape(), result_type.getElementType());
49 auto elements_attr = op.value().cast<DenseElementsAttr>();
50
51 if (result_rank == 0) {
52 Value buffer = rewriter.create<AllocOp>(loc, memref_type);
53 Value constant =
54 rewriter.create<ConstantOp>(loc, elements_attr.getValue({}));
55 rewriter.create<StoreOp>(loc, constant, buffer);
56 rewriter.replaceOp(op, {buffer});
57 return success();
58 }
59
60 Value buffer = rewriter.create<AllocaOp>(loc, memref_type);
61
62 bool all_same_elems = elements_attr.isSplat();
63 Value value;
64 if (all_same_elems)
65 value = rewriter.create<ConstantOp>(loc, elements_attr.getSplatValue());
66 for (auto en : llvm::enumerate(elements_attr.getAttributeValues())) {
67 if (!all_same_elems) value = rewriter.create<ConstantOp>(loc, en.value());
68 Value index = rewriter.create<ConstantIndexOp>(loc, en.index());
69 rewriter.create<StoreOp>(loc, value, buffer, index);
70 }
71 rewriter.replaceOp(op, {buffer});
72 return success();
73 }
74 };
75
76 class BufferizeDimOp : public OpConversionPattern<DimOp> {
77 public:
78 using OpConversionPattern::OpConversionPattern;
matchAndRewrite(DimOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const79 LogicalResult matchAndRewrite(
80 DimOp op, ArrayRef<Value> operands,
81 ConversionPatternRewriter &rewriter) const override {
82 DimOp::Adaptor adaptor(operands);
83 rewriter.replaceOpWithNewOp<DimOp>(op, adaptor.memrefOrTensor(),
84 adaptor.index());
85 return success();
86 }
87 };
88
89 class BufferizeRankOp : public OpConversionPattern<RankOp> {
90 public:
91 using OpConversionPattern::OpConversionPattern;
matchAndRewrite(RankOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const92 LogicalResult matchAndRewrite(
93 RankOp op, ArrayRef<Value> operands,
94 ConversionPatternRewriter &rewriter) const override {
95 RankOp::Adaptor adaptor(operands);
96 rewriter.replaceOpWithNewOp<RankOp>(op, adaptor.memrefOrTensor());
97 return success();
98 }
99 };
100 } // namespace
101
populateExtraStdBufferizePattern(MLIRContext * context,BufferizeTypeConverter * converter,OwningRewritePatternList * patterns)102 void populateExtraStdBufferizePattern(MLIRContext *context,
103 BufferizeTypeConverter *converter,
104 OwningRewritePatternList *patterns) {
105 patterns->insert<BufferizeConstantOp, BufferizeDimOp, BufferizeRankOp>(
106 *converter, context);
107 }
108
109 } // namespace transforms
110 } // namespace kernel_gen
111 } // namespace mlir
112