1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
17 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
18 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
19 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
20 #include "mlir/IR/TypeRange.h" // from @llvm-project
21 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
22 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
23 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
24
25 namespace mlir {
26 namespace kernel_gen {
27 namespace tf_framework {
28 namespace {
29
30 // Prepends argument type list of the function with an OpKernelContextType arg.
31 class FuncOpConverter : public OpConversionPattern<FuncOp> {
32 public:
33 using OpConversionPattern<FuncOp>::OpConversionPattern;
34
matchAndRewrite(FuncOp func,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const35 LogicalResult matchAndRewrite(
36 FuncOp func, ArrayRef<Value> operands,
37 ConversionPatternRewriter &rewriter) const override {
38 // Convert function arguments using the provided TypeConverter.
39 auto func_type = func.getType();
40 TypeConverter::SignatureConversion conversion(func_type.getNumInputs());
41
42 conversion.addInputs(OpKernelContextType::get(rewriter.getContext()));
43 for (auto arg_type : llvm::enumerate(func_type.getInputs())) {
44 conversion.addInputs(arg_type.index(), arg_type.value());
45 }
46
47 rewriter.applySignatureConversion(&func.getBody(), conversion);
48
49 // Update the signature of the function.
50 rewriter.updateRootInPlace(func, [&] {
51 func.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
52 func_type.getResults()));
53 });
54 return success();
55 }
56 };
57
FindOpKernelContext(Operation * op)58 llvm::Optional<Value> FindOpKernelContext(Operation *op) {
59 auto func = op->getParentOfType<FuncOp>();
60 if (func.getNumArguments() == 0) {
61 return llvm::None;
62 }
63 Value ctx = func.getArgument(0);
64 if (!ctx.getType().isa<OpKernelContextType>()) {
65 return llvm::None;
66 }
67 return ctx;
68 }
69
70 // Converts std.alloc to tf_framework.alloc_raw using OpKernelContextType arg of
71 // the parent function.
72 struct AllocOpConverter : public OpConversionPattern<memref::AllocOp> {
73 using OpConversionPattern<memref::AllocOp>::OpConversionPattern;
74
matchAndRewritemlir::kernel_gen::tf_framework::__anon6a1e402c0111::AllocOpConverter75 LogicalResult matchAndRewrite(
76 memref::AllocOp alloc, ArrayRef<Value> operands,
77 ConversionPatternRewriter &rewriter) const override {
78 llvm::Optional<Value> ctx = FindOpKernelContext(alloc);
79 if (!ctx) return failure();
80
81 // Symbolic operands that bind to the symbols of the memref's layout map are
82 // not supported by TFAllocOp.
83 if (!alloc.symbolOperands().empty()) {
84 return failure();
85 }
86 auto reuse_input_candidates = alloc->getAttrOfType<ArrayAttr>(
87 TFAllocOp::kReuseInputCandidatesAttrName);
88 auto reuse_output_index =
89 alloc->getAttrOfType<IntegerAttr>(TFAllocOp::kReuseOutputAttrName);
90 Value buffer = rewriter.replaceOpWithNewOp<TFAllocOp>(
91 alloc, alloc.getType(), *ctx, operands, reuse_input_candidates,
92 reuse_output_index);
93 Location loc = buffer.getLoc();
94 Value cond = rewriter.create<IsValidMemRefOp>(
95 loc, rewriter.getIntegerType(1), buffer);
96 rewriter.create<TFAssertOp>(loc, *ctx, cond, ErrorCode::RESOURCE_EXHAUSTED,
97 "failed to allocate memory");
98 return success();
99 }
100 };
101
102 // Converts std.dealloc to tf_framework.dealloc_raw using OpKernelContextType
103 // arg of the parent function.
104 struct DeallocOpConverter : public OpConversionPattern<memref::DeallocOp> {
105 using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;
106
matchAndRewritemlir::kernel_gen::tf_framework::__anon6a1e402c0111::DeallocOpConverter107 LogicalResult matchAndRewrite(
108 memref::DeallocOp dealloc, ArrayRef<Value> operands,
109 ConversionPatternRewriter &rewriter) const override {
110 llvm::Optional<Value> ctx = FindOpKernelContext(dealloc);
111 if (!ctx) return failure();
112
113 // Operand with no layout is expected.
114 auto operand_memref_type = dealloc.memref().getType().cast<MemRefType>();
115 if (!operand_memref_type.getAffineMaps().empty()) {
116 return failure();
117 }
118 memref::DeallocOp::Adaptor transformed(operands);
119 rewriter.replaceOpWithNewOp<TFDeallocOp>(dealloc, *ctx,
120 transformed.memref());
121 return success();
122 }
123 };
124
125 // Converts std.assert to tf_framework.assert with using OpKernelContextType
126 // arg of the parent function.
127 struct AssertOpConverter : public OpConversionPattern<AssertOp> {
128 public:
129 using OpConversionPattern<AssertOp>::OpConversionPattern;
130
matchAndRewritemlir::kernel_gen::tf_framework::__anon6a1e402c0111::AssertOpConverter131 LogicalResult matchAndRewrite(
132 AssertOp op, ArrayRef<Value> operands,
133 ConversionPatternRewriter &rewriter) const override {
134 llvm::Optional<Value> ctx = FindOpKernelContext(op);
135 if (!ctx) return failure();
136 AssertOp::Adaptor transformed(operands, op->getAttrDictionary());
137 rewriter.replaceOpWithNewOp<TFAssertOp>(op, *ctx, transformed.arg(),
138 ErrorCode::INVALID_ARGUMENT,
139 transformed.msg().getValue());
140 return success();
141 }
142 };
143
144 // Amends `tf_framework.jit_execute` with the newly introduced OpKernelContext.
145 struct JITExecuteOpConverter : public OpConversionPattern<JITExecuteOp> {
146 using OpConversionPattern<JITExecuteOp>::OpConversionPattern;
147
matchAndRewritemlir::kernel_gen::tf_framework::__anon6a1e402c0111::JITExecuteOpConverter148 LogicalResult matchAndRewrite(
149 JITExecuteOp op, ArrayRef<Value> operands,
150 ConversionPatternRewriter &rewriter) const override {
151 llvm::Optional<Value> ctx = FindOpKernelContext(op);
152 if (!ctx) return failure();
153 rewriter.replaceOpWithNewOp<JITExecuteOp>(op, op.getResultTypes(), *ctx,
154 op.callable(), op.operands());
155 return success();
156 }
157 };
158
159 // Amends `tf_framework.jit_compile_from_str` with the newly introduced
160 // OpKernelContext.
161 struct JITCompileFromStrOpConverter
162 : public OpConversionPattern<JITCompileFromStrOp> {
163 using OpConversionPattern<JITCompileFromStrOp>::OpConversionPattern;
164
matchAndRewritemlir::kernel_gen::tf_framework::__anon6a1e402c0111::JITCompileFromStrOpConverter165 LogicalResult matchAndRewrite(
166 JITCompileFromStrOp op, ArrayRef<Value> operands,
167 ConversionPatternRewriter &rewriter) const override {
168 llvm::Optional<Value> ctx = FindOpKernelContext(op);
169 if (!ctx) return failure();
170 rewriter.replaceOpWithNewOp<JITCompileFromStrOp>(
171 op, rewriter.getType<JITCallableType>(), *ctx, op->getAttrs());
172 return success();
173 }
174 };
175
176 } // namespace
177
PopulateEmbedTFFrameworkAssertPattern(RewritePatternSet * patterns)178 void PopulateEmbedTFFrameworkAssertPattern(RewritePatternSet *patterns) {
179 patterns->insert<AssertOpConverter>(patterns->getContext());
180 }
181
PopulateEmbedTFFrameworkPatterns(RewritePatternSet * patterns)182 void PopulateEmbedTFFrameworkPatterns(RewritePatternSet *patterns) {
183 // clang-format off
184 patterns->insert<
185 AllocOpConverter,
186 AssertOpConverter,
187 DeallocOpConverter,
188 FuncOpConverter,
189 JITCompileFromStrOpConverter,
190 JITExecuteOpConverter>(patterns->getContext());
191 // clang-format on
192 }
193
194 } // namespace tf_framework
195 } // namespace kernel_gen
196 } // namespace mlir
197