• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
18 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
19 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
20 #include "mlir/IR/TypeRange.h"  // from @llvm-project
21 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
22 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
23 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
24 
25 namespace mlir {
26 namespace kernel_gen {
27 namespace tf_framework {
28 namespace {
29 
30 // Prepends argument type list of the function with an OpKernelContextType arg.
31 class FuncOpConverter : public OpConversionPattern<FuncOp> {
32  public:
33   using OpConversionPattern<FuncOp>::OpConversionPattern;
34 
matchAndRewrite(FuncOp func,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const35   LogicalResult matchAndRewrite(
36       FuncOp func, ArrayRef<Value> operands,
37       ConversionPatternRewriter &rewriter) const override {
38     // Convert function arguments using the provided TypeConverter.
39     auto func_type = func.getType();
40     TypeConverter::SignatureConversion conversion(func_type.getNumInputs());
41 
42     conversion.addInputs(OpKernelContextType::get(rewriter.getContext()));
43     for (auto arg_type : llvm::enumerate(func_type.getInputs())) {
44       conversion.addInputs(arg_type.index(), arg_type.value());
45     }
46 
47     rewriter.applySignatureConversion(&func.getBody(), conversion);
48 
49     // Update the signature of the function.
50     rewriter.updateRootInPlace(func, [&] {
51       func.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
52                                             func_type.getResults()));
53     });
54     return success();
55   }
56 };
57 
FindOpKernelContext(Operation * op)58 llvm::Optional<Value> FindOpKernelContext(Operation *op) {
59   auto func = op->getParentOfType<FuncOp>();
60   if (func.getNumArguments() == 0) {
61     return llvm::None;
62   }
63   Value ctx = func.getArgument(0);
64   if (!ctx.getType().isa<OpKernelContextType>()) {
65     return llvm::None;
66   }
67   return ctx;
68 }
69 
70 // Converts std.alloc to tf_framework.alloc_raw using OpKernelContextType arg of
71 // the parent function.
72 struct AllocOpConverter : public OpConversionPattern<memref::AllocOp> {
73   using OpConversionPattern<memref::AllocOp>::OpConversionPattern;
74 
matchAndRewritemlir::kernel_gen::tf_framework::__anon6a1e402c0111::AllocOpConverter75   LogicalResult matchAndRewrite(
76       memref::AllocOp alloc, ArrayRef<Value> operands,
77       ConversionPatternRewriter &rewriter) const override {
78     llvm::Optional<Value> ctx = FindOpKernelContext(alloc);
79     if (!ctx) return failure();
80 
81     // Symbolic operands that bind to the symbols of the memref's layout map are
82     // not supported by TFAllocOp.
83     if (!alloc.symbolOperands().empty()) {
84       return failure();
85     }
86     auto reuse_input_candidates = alloc->getAttrOfType<ArrayAttr>(
87         TFAllocOp::kReuseInputCandidatesAttrName);
88     auto reuse_output_index =
89         alloc->getAttrOfType<IntegerAttr>(TFAllocOp::kReuseOutputAttrName);
90     Value buffer = rewriter.replaceOpWithNewOp<TFAllocOp>(
91         alloc, alloc.getType(), *ctx, operands, reuse_input_candidates,
92         reuse_output_index);
93     Location loc = buffer.getLoc();
94     Value cond = rewriter.create<IsValidMemRefOp>(
95         loc, rewriter.getIntegerType(1), buffer);
96     rewriter.create<TFAssertOp>(loc, *ctx, cond, ErrorCode::RESOURCE_EXHAUSTED,
97                                 "failed to allocate memory");
98     return success();
99   }
100 };
101 
102 // Converts std.dealloc to tf_framework.dealloc_raw using OpKernelContextType
103 // arg of the parent function.
104 struct DeallocOpConverter : public OpConversionPattern<memref::DeallocOp> {
105   using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;
106 
matchAndRewritemlir::kernel_gen::tf_framework::__anon6a1e402c0111::DeallocOpConverter107   LogicalResult matchAndRewrite(
108       memref::DeallocOp dealloc, ArrayRef<Value> operands,
109       ConversionPatternRewriter &rewriter) const override {
110     llvm::Optional<Value> ctx = FindOpKernelContext(dealloc);
111     if (!ctx) return failure();
112 
113     // Operand with no layout is expected.
114     auto operand_memref_type = dealloc.memref().getType().cast<MemRefType>();
115     if (!operand_memref_type.getAffineMaps().empty()) {
116       return failure();
117     }
118     memref::DeallocOp::Adaptor transformed(operands);
119     rewriter.replaceOpWithNewOp<TFDeallocOp>(dealloc, *ctx,
120                                              transformed.memref());
121     return success();
122   }
123 };
124 
125 // Converts std.assert to tf_framework.assert with using OpKernelContextType
126 // arg of the parent function.
127 struct AssertOpConverter : public OpConversionPattern<AssertOp> {
128  public:
129   using OpConversionPattern<AssertOp>::OpConversionPattern;
130 
matchAndRewritemlir::kernel_gen::tf_framework::__anon6a1e402c0111::AssertOpConverter131   LogicalResult matchAndRewrite(
132       AssertOp op, ArrayRef<Value> operands,
133       ConversionPatternRewriter &rewriter) const override {
134     llvm::Optional<Value> ctx = FindOpKernelContext(op);
135     if (!ctx) return failure();
136     AssertOp::Adaptor transformed(operands, op->getAttrDictionary());
137     rewriter.replaceOpWithNewOp<TFAssertOp>(op, *ctx, transformed.arg(),
138                                             ErrorCode::INVALID_ARGUMENT,
139                                             transformed.msg().getValue());
140     return success();
141   }
142 };
143 
144 // Amends `tf_framework.jit_execute` with the newly introduced OpKernelContext.
145 struct JITExecuteOpConverter : public OpConversionPattern<JITExecuteOp> {
146   using OpConversionPattern<JITExecuteOp>::OpConversionPattern;
147 
matchAndRewritemlir::kernel_gen::tf_framework::__anon6a1e402c0111::JITExecuteOpConverter148   LogicalResult matchAndRewrite(
149       JITExecuteOp op, ArrayRef<Value> operands,
150       ConversionPatternRewriter &rewriter) const override {
151     llvm::Optional<Value> ctx = FindOpKernelContext(op);
152     if (!ctx) return failure();
153     rewriter.replaceOpWithNewOp<JITExecuteOp>(op, op.getResultTypes(), *ctx,
154                                               op.callable(), op.operands());
155     return success();
156   }
157 };
158 
159 // Amends `tf_framework.jit_compile_from_str` with the newly introduced
160 // OpKernelContext.
161 struct JITCompileFromStrOpConverter
162     : public OpConversionPattern<JITCompileFromStrOp> {
163   using OpConversionPattern<JITCompileFromStrOp>::OpConversionPattern;
164 
matchAndRewritemlir::kernel_gen::tf_framework::__anon6a1e402c0111::JITCompileFromStrOpConverter165   LogicalResult matchAndRewrite(
166       JITCompileFromStrOp op, ArrayRef<Value> operands,
167       ConversionPatternRewriter &rewriter) const override {
168     llvm::Optional<Value> ctx = FindOpKernelContext(op);
169     if (!ctx) return failure();
170     rewriter.replaceOpWithNewOp<JITCompileFromStrOp>(
171         op, rewriter.getType<JITCallableType>(), *ctx, op->getAttrs());
172     return success();
173   }
174 };
175 
176 }  // namespace
177 
PopulateEmbedTFFrameworkAssertPattern(RewritePatternSet * patterns)178 void PopulateEmbedTFFrameworkAssertPattern(RewritePatternSet *patterns) {
179   patterns->insert<AssertOpConverter>(patterns->getContext());
180 }
181 
PopulateEmbedTFFrameworkPatterns(RewritePatternSet * patterns)182 void PopulateEmbedTFFrameworkPatterns(RewritePatternSet *patterns) {
183   // clang-format off
184   patterns->insert<
185       AllocOpConverter,
186       AssertOpConverter,
187       DeallocOpConverter,
188       FuncOpConverter,
189       JITCompileFromStrOpConverter,
190       JITExecuteOpConverter>(patterns->getContext());
191   // clang-format on
192 }
193 
194 }  // namespace tf_framework
195 }  // namespace kernel_gen
196 }  // namespace mlir
197