• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
17 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
18 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
19 #include "mlir/IR/TypeRange.h"  // from @llvm-project
20 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
21 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
22 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
23 
24 namespace mlir {
25 namespace kernel_gen {
26 namespace tf_framework {
27 namespace {
28 
29 // Prepends argument type list of the function with an OpKernelContextType arg.
30 class FuncOpConverter : public OpConversionPattern<FuncOp> {
31  public:
32   using OpConversionPattern<FuncOp>::OpConversionPattern;
33 
matchAndRewrite(FuncOp func,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const34   LogicalResult matchAndRewrite(
35       FuncOp func, ArrayRef<Value> operands,
36       ConversionPatternRewriter &rewriter) const override {
37     // Convert function arguments using the provided TypeConverter.
38     auto func_type = func.getType();
39     TypeConverter::SignatureConversion conversion(func_type.getNumInputs());
40 
41     conversion.addInputs(OpKernelContextType::get(rewriter.getContext()));
42     for (auto arg_type : llvm::enumerate(func_type.getInputs())) {
43       conversion.addInputs(arg_type.index(), arg_type.value());
44     }
45 
46     rewriter.applySignatureConversion(&func.getBody(), conversion);
47 
48     // Update the signature of the function.
49     rewriter.updateRootInPlace(func, [&] {
50       func.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
51                                             func_type.getResults()));
52     });
53     return success();
54   }
55 };
56 
57 // Converts std.alloc to tf_framework.alloc_raw using OpKernelContextType arg of
58 // the parent function.
59 class TFAllocOpConverter : public OpConversionPattern<AllocOp> {
60  public:
61   using OpConversionPattern<AllocOp>::OpConversionPattern;
62 
matchAndRewrite(AllocOp alloc,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const63   LogicalResult matchAndRewrite(
64       AllocOp alloc, ArrayRef<Value> operands,
65       ConversionPatternRewriter &rewriter) const override {
66     auto func = alloc->getParentOfType<FuncOp>();
67     if (func.getNumArguments() == 0) {
68       return failure();
69     }
70     Value ctx = func.getArgument(0);
71     if (!ctx.getType().isa<OpKernelContextType>()) {
72       return failure();
73     }
74     // Symbolic operands that bind to the symbols of the memref's layout map are
75     // not supported by TFAllocOp.
76     if (!alloc.symbolOperands().empty()) {
77       return failure();
78     }
79     auto reuse_input_candidates = alloc->getAttrOfType<ArrayAttr>(
80         TFAllocOp::kReuseInputCandidatesAttrName);
81     auto reuse_output_index =
82         alloc->getAttrOfType<IntegerAttr>(TFAllocOp::kReuseOutputAttrName);
83     rewriter.replaceOpWithNewOp<TFAllocOp>(alloc, alloc.getType(), ctx,
84                                            operands, reuse_input_candidates,
85                                            reuse_output_index);
86     return success();
87   }
88 };
89 
90 // Converts std.dealloc to tf_framework.dealloc_raw using OpKernelContextType
91 // arg of the parent function.
92 class TFDeallocOpConverter : public OpConversionPattern<DeallocOp> {
93  public:
94   using OpConversionPattern<DeallocOp>::OpConversionPattern;
95 
matchAndRewrite(DeallocOp dealloc,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const96   LogicalResult matchAndRewrite(
97       DeallocOp dealloc, ArrayRef<Value> operands,
98       ConversionPatternRewriter &rewriter) const override {
99     auto func = dealloc->getParentOfType<FuncOp>();
100     if (func.getNumArguments() == 0) {
101       return failure();
102     }
103     Value ctx = func.getArgument(0);
104     if (!ctx.getType().isa<OpKernelContextType>()) {
105       return failure();
106     }
107     // Operand with no layout is expected.
108     auto operand_memref_type = dealloc.memref().getType().cast<MemRefType>();
109     if (!operand_memref_type.getAffineMaps().empty()) {
110       return failure();
111     }
112     DeallocOp::Adaptor transformed(operands);
113     rewriter.replaceOpWithNewOp<TFDeallocOp>(dealloc, ctx,
114                                              transformed.memref());
115     return success();
116   }
117 };
118 
119 // Converts std.assert to tf_framework.assert with using OpKernelContextType
120 // arg of the parent function.
121 class TFAssertOpConverter : public OpConversionPattern<AssertOp> {
122  public:
123   using OpConversionPattern<AssertOp>::OpConversionPattern;
124 
matchAndRewrite(AssertOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const125   LogicalResult matchAndRewrite(
126       AssertOp op, ArrayRef<Value> operands,
127       ConversionPatternRewriter &rewriter) const override {
128     auto func = op->getParentOfType<FuncOp>();
129     if (func.getNumArguments() == 0) {
130       return failure();
131     }
132     Value ctx = func.getArgument(0);
133     if (!ctx.getType().isa<OpKernelContextType>()) {
134       return failure();
135     }
136     Location loc = op.getLoc();
137     AssertOp::Adaptor transformed(operands, op->getAttrDictionary());
138 
139     // Split the block to insert CondBr.
140     OpBuilder::InsertPoint ip = rewriter.saveInsertionPoint();
141     Block *split_block = rewriter.splitBlock(
142         rewriter.getInsertionBlock(), std::next(rewriter.getInsertionPoint()));
143 
144     Block *error_reporting_block =
145         rewriter.createBlock(&func.getRegion(), {}, {});
146     rewriter.create<ReportErrorOp>(loc, ctx, ErrorCode::INVALID_ARGUMENT,
147                                    transformed.msg().getValue());
148 
149     SmallVector<Value, 2> null_memrefs;
150     for (auto type : func.getType().getResults()) {
151       // This can be extended to support various result types if necessary.
152       if (!type.isa<UnrankedMemRefType>()) {
153         op.emitError("only UnrankedMemRefType results are supported");
154         return failure();
155       }
156       null_memrefs.push_back(rewriter.create<NullMemRefOp>(loc, type));
157     }
158     rewriter.create<ReturnOp>(loc, null_memrefs);
159 
160     rewriter.restoreInsertionPoint(ip);
161     rewriter.replaceOpWithNewOp<CondBranchOp>(
162         op, transformed.arg(), split_block, llvm::None, error_reporting_block,
163         llvm::None);
164     return success();
165   }
166 };
167 
168 }  // namespace
169 
PopulateEmbedTFFrameworkFunctionAndAllocConversionPatterns(MLIRContext * context,OwningRewritePatternList * patterns)170 void PopulateEmbedTFFrameworkFunctionAndAllocConversionPatterns(
171     MLIRContext *context, OwningRewritePatternList *patterns) {
172   patterns->insert<TFAllocOpConverter, TFDeallocOpConverter, FuncOpConverter>(
173       context);
174 }
175 
PopulateEmbedTFFrameworkAssertConversionPatterns(MLIRContext * context,OwningRewritePatternList * patterns)176 void PopulateEmbedTFFrameworkAssertConversionPatterns(
177     MLIRContext *context, OwningRewritePatternList *patterns) {
178   patterns->insert<TFAssertOpConverter, FuncOpConverter>(context);
179 }
180 
181 }  // namespace tf_framework
182 }  // namespace kernel_gen
183 }  // namespace mlir
184