1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
17 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
18 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
19 #include "mlir/IR/TypeRange.h" // from @llvm-project
20 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
21 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
22 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
23
24 namespace mlir {
25 namespace kernel_gen {
26 namespace tf_framework {
27 namespace {
28
29 // Prepends argument type list of the function with an OpKernelContextType arg.
30 class FuncOpConverter : public OpConversionPattern<FuncOp> {
31 public:
32 using OpConversionPattern<FuncOp>::OpConversionPattern;
33
matchAndRewrite(FuncOp func,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const34 LogicalResult matchAndRewrite(
35 FuncOp func, ArrayRef<Value> operands,
36 ConversionPatternRewriter &rewriter) const override {
37 // Convert function arguments using the provided TypeConverter.
38 auto func_type = func.getType();
39 TypeConverter::SignatureConversion conversion(func_type.getNumInputs());
40
41 conversion.addInputs(OpKernelContextType::get(rewriter.getContext()));
42 for (auto arg_type : llvm::enumerate(func_type.getInputs())) {
43 conversion.addInputs(arg_type.index(), arg_type.value());
44 }
45
46 rewriter.applySignatureConversion(&func.getBody(), conversion);
47
48 // Update the signature of the function.
49 rewriter.updateRootInPlace(func, [&] {
50 func.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
51 func_type.getResults()));
52 });
53 return success();
54 }
55 };
56
57 // Converts std.alloc to tf_framework.alloc_raw using OpKernelContextType arg of
58 // the parent function.
59 class TFAllocOpConverter : public OpConversionPattern<AllocOp> {
60 public:
61 using OpConversionPattern<AllocOp>::OpConversionPattern;
62
matchAndRewrite(AllocOp alloc,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const63 LogicalResult matchAndRewrite(
64 AllocOp alloc, ArrayRef<Value> operands,
65 ConversionPatternRewriter &rewriter) const override {
66 auto func = alloc->getParentOfType<FuncOp>();
67 if (func.getNumArguments() == 0) {
68 return failure();
69 }
70 Value ctx = func.getArgument(0);
71 if (!ctx.getType().isa<OpKernelContextType>()) {
72 return failure();
73 }
74 // Symbolic operands that bind to the symbols of the memref's layout map are
75 // not supported by TFAllocOp.
76 if (!alloc.symbolOperands().empty()) {
77 return failure();
78 }
79 auto reuse_input_candidates = alloc->getAttrOfType<ArrayAttr>(
80 TFAllocOp::kReuseInputCandidatesAttrName);
81 auto reuse_output_index =
82 alloc->getAttrOfType<IntegerAttr>(TFAllocOp::kReuseOutputAttrName);
83 rewriter.replaceOpWithNewOp<TFAllocOp>(alloc, alloc.getType(), ctx,
84 operands, reuse_input_candidates,
85 reuse_output_index);
86 return success();
87 }
88 };
89
90 // Converts std.dealloc to tf_framework.dealloc_raw using OpKernelContextType
91 // arg of the parent function.
92 class TFDeallocOpConverter : public OpConversionPattern<DeallocOp> {
93 public:
94 using OpConversionPattern<DeallocOp>::OpConversionPattern;
95
matchAndRewrite(DeallocOp dealloc,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const96 LogicalResult matchAndRewrite(
97 DeallocOp dealloc, ArrayRef<Value> operands,
98 ConversionPatternRewriter &rewriter) const override {
99 auto func = dealloc->getParentOfType<FuncOp>();
100 if (func.getNumArguments() == 0) {
101 return failure();
102 }
103 Value ctx = func.getArgument(0);
104 if (!ctx.getType().isa<OpKernelContextType>()) {
105 return failure();
106 }
107 // Operand with no layout is expected.
108 auto operand_memref_type = dealloc.memref().getType().cast<MemRefType>();
109 if (!operand_memref_type.getAffineMaps().empty()) {
110 return failure();
111 }
112 DeallocOp::Adaptor transformed(operands);
113 rewriter.replaceOpWithNewOp<TFDeallocOp>(dealloc, ctx,
114 transformed.memref());
115 return success();
116 }
117 };
118
119 // Converts std.assert to tf_framework.assert with using OpKernelContextType
120 // arg of the parent function.
121 class TFAssertOpConverter : public OpConversionPattern<AssertOp> {
122 public:
123 using OpConversionPattern<AssertOp>::OpConversionPattern;
124
matchAndRewrite(AssertOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const125 LogicalResult matchAndRewrite(
126 AssertOp op, ArrayRef<Value> operands,
127 ConversionPatternRewriter &rewriter) const override {
128 auto func = op->getParentOfType<FuncOp>();
129 if (func.getNumArguments() == 0) {
130 return failure();
131 }
132 Value ctx = func.getArgument(0);
133 if (!ctx.getType().isa<OpKernelContextType>()) {
134 return failure();
135 }
136 Location loc = op.getLoc();
137 AssertOp::Adaptor transformed(operands, op->getAttrDictionary());
138
139 // Split the block to insert CondBr.
140 OpBuilder::InsertPoint ip = rewriter.saveInsertionPoint();
141 Block *split_block = rewriter.splitBlock(
142 rewriter.getInsertionBlock(), std::next(rewriter.getInsertionPoint()));
143
144 Block *error_reporting_block =
145 rewriter.createBlock(&func.getRegion(), {}, {});
146 rewriter.create<ReportErrorOp>(loc, ctx, ErrorCode::INVALID_ARGUMENT,
147 transformed.msg().getValue());
148
149 SmallVector<Value, 2> null_memrefs;
150 for (auto type : func.getType().getResults()) {
151 // This can be extended to support various result types if necessary.
152 if (!type.isa<UnrankedMemRefType>()) {
153 op.emitError("only UnrankedMemRefType results are supported");
154 return failure();
155 }
156 null_memrefs.push_back(rewriter.create<NullMemRefOp>(loc, type));
157 }
158 rewriter.create<ReturnOp>(loc, null_memrefs);
159
160 rewriter.restoreInsertionPoint(ip);
161 rewriter.replaceOpWithNewOp<CondBranchOp>(
162 op, transformed.arg(), split_block, llvm::None, error_reporting_block,
163 llvm::None);
164 return success();
165 }
166 };
167
168 } // namespace
169
PopulateEmbedTFFrameworkFunctionAndAllocConversionPatterns(MLIRContext * context,OwningRewritePatternList * patterns)170 void PopulateEmbedTFFrameworkFunctionAndAllocConversionPatterns(
171 MLIRContext *context, OwningRewritePatternList *patterns) {
172 patterns->insert<TFAllocOpConverter, TFDeallocOpConverter, FuncOpConverter>(
173 context);
174 }
175
PopulateEmbedTFFrameworkAssertConversionPatterns(MLIRContext * context,OwningRewritePatternList * patterns)176 void PopulateEmbedTFFrameworkAssertConversionPatterns(
177 MLIRContext *context, OwningRewritePatternList *patterns) {
178 patterns->insert<TFAssertOpConverter, FuncOpConverter>(context);
179 }
180
181 } // namespace tf_framework
182 } // namespace kernel_gen
183 } // namespace mlir
184