/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.h"

#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/DialectImplementation.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
#include "tensorflow/core/runtime_fallback/opdefs/tfrt_fallback.h"
#include "tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_common.h"
#include "tfrt/basic_kernels/opdefs/tfrt_base.h"  // from @tf_runtime
#include "tfrt/basic_kernels/opdefs/types.h"  // from @tf_runtime
#include "tfrt/core_runtime/opdefs/attributes.h"  // from @tf_runtime
#include "tfrt/core_runtime/opdefs/core_runtime.h"  // from @tf_runtime
#include "tfrt/core_runtime/opdefs/sync/core_runtime.h"  // from @tf_runtime
#include "tfrt/core_runtime/opdefs/types.h"  // from @tf_runtime

namespace tfrt {
namespace fallback_async {

namespace {

struct FallbackInlinerInterface : public mlir::DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  bool isLegalToInline(Operation *op, Region *dest, bool would_be_cloned,
                       BlockAndValueMapping &) const final {
    return true;
  }
};

}  // namespace

FallbackAsyncDialect::FallbackAsyncDialect(MLIRContext *context)
    : Dialect(/*name=*/"tfrt_fallback_async", context,
              TypeID::get<FallbackAsyncDialect>()) {
  context->getOrLoadDialect<tfrt::fallback::FallbackDialect>();
  context->getOrLoadDialect<compiler::TFRTDialect>();
  context->getOrLoadDialect<corert::CoreRTDialect>();

  allowUnknownTypes();

  allowUnknownOperations();

  addInterfaces<FallbackInlinerInterface>();

  addOperations<
#define GET_OP_LIST
#include "tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.cpp.inc"
      >();
}

static Type GetChainType(Builder *builder) {
  return builder->getType<compiler::ChainType>();
}

// Verifiers for the fallback execute-style ops; they delegate to the shared
// checks in tfrt_fallback_common.
static LogicalResult verify(CreateOp op) {
  return fallback_common::VerifyFallbackExecuteOp(op);
}
static LogicalResult verify(ExecuteOp op) {
  return fallback_common::VerifyFallbackExecuteOp(op);
}
static LogicalResult verify(ExecuteOpSeq op) {
  return fallback_common::VerifyFallbackExecuteOp(op);
}
static LogicalResult verify(BatchFunctionOp op) {
  return fallback_common::VerifyExecuteOpCommon(op);
}

static ParseResult parseCreateOp(OpAsmParser &parser, OperationState &result) {
  fallback_common::ParseExecuteOpOptions parse_options;
  parse_options.has_chain = true;
  parse_options.has_key = true;
  parse_options.has_device = true;
  parse_options.has_func_attr = true;
  parse_options.has_cost = false;

  auto &builder = parser.getBuilder();
  if (mlir::failed(fallback_common::ParseExecuteOpCommon(
          parser, builder, result, builder.getType<fallback::TFTensorType>(),
          parse_options)))
    return mlir::failure();

  mlir::IntegerAttr num_args;
  if (parser.parseKeyword("num_args") || parser.parseLParen() ||
      parser.parseAttribute(num_args, "num_args", result.attributes) ||
      parser.parseRParen())
    return mlir::failure();

  return mlir::success();
}
static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
  fallback_common::ParseExecuteOpOptions parse_options;
  parse_options.has_chain = false;
  parse_options.has_key = true;
  parse_options.has_device = true;
  parse_options.has_func_attr = true;
  parse_options.has_cost = true;

  auto &builder = parser.getBuilder();
  return fallback_common::ParseExecuteOpCommon(
      parser, builder, result, builder.getType<fallback::TFTensorType>(),
      parse_options);
}
static ParseResult parseExecuteOpSeq(OpAsmParser &parser,
                                     OperationState &result) {
  fallback_common::ParseExecuteOpOptions parse_options;
  parse_options.has_chain = true;
  parse_options.has_key = true;
  parse_options.has_device = true;
  parse_options.has_func_attr = true;
  parse_options.has_cost = true;

  auto &builder = parser.getBuilder();
  return fallback_common::ParseExecuteOpCommon(
      parser, builder, result, builder.getType<fallback::TFTensorType>(),
      parse_options);
}

static ParseResult parseBatchFunctionOp(OpAsmParser &parser,
                                        OperationState &result) {
  auto &builder = parser.getBuilder();
  auto chain_type = GetChainType(&builder);
  auto tensorhandle_type = builder.getType<corert::TensorHandleType>();

  FlatSymbolRefAttr f;
  SmallVector<OpAsmParser::OperandType, 4> in_chains;
  SmallVector<OpAsmParser::OperandType, 4> operands;
  NamedAttrList op_attrs;
  auto loc = parser.getNameLoc();

  if (parser.parseOperandList(in_chains,
                              /*requiredOperandCount=*/1,
                              OpAsmParser::Delimiter::Paren))
    return failure();

  if (parser.parseAttribute(f, "f", result.attributes) ||
      parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
      parser.parseOptionalAttrDict(op_attrs))
    return failure();

  int64_t num_results = 0;
  if (succeeded(parser.parseOptionalColon())) {
    IntegerAttr attr;
    mlir::NamedAttrList attrs;
    if (failed(parser.parseAttribute(attr, "num_results", attrs)))
      return failure();
    num_results = attr.getValue().getSExtValue();
  }

  SmallVector<Type, 4> operand_types;
  operand_types.push_back(chain_type);
  if (parser.resolveOperands(in_chains, operand_types, loc, result.operands) ||
      parser.resolveOperands(operands, tensorhandle_type, result.operands))
    return failure();

  result.types.push_back(chain_type);
  result.types.append(num_results, tensorhandle_type);

  SmallVector<Attribute, 4> op_attr_array;
  for (const auto &key_value : op_attrs) {
    auto key = builder.getStringAttr(key_value.first.strref());
    auto value = key_value.second;
    op_attr_array.push_back(builder.getArrayAttr({key, value}));
  }

  result.attributes.push_back(
      builder.getNamedAttr("op_attrs", builder.getArrayAttr(op_attr_array)));

  return success();
}

static void print(OpAsmPrinter &p, CreateOp op) {
  p << "tfrt_fallback_async.createop(" << op.in_ch() << ") key("
    << op->getAttrOfType<mlir::IntegerAttr>("op_key").getInt() << ") device("
    << op->getAttr("device") << ") " << op->getAttr("op_name") << "()";

  fallback_common::PrintExecuteOpCommon(p, op);
  fallback_common::PrintExecuteOpFuncAttribute(p, op);

  p << " num_args(" << op->getAttrOfType<mlir::IntegerAttr>("num_args").getInt()
    << ')';
}
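
// For illustration only (operand name, device, op name and attributes below
// are made up): the printer above emits createop roughly in this textual form:
//
//   tfrt_fallback_async.createop(%in_ch) key(0) device("/CPU:0")
//     "tf.Relu"() {T = f32} num_args(1)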

static void print(OpAsmPrinter &p, ExecuteOp op) {
  p << "tfrt_fallback_async.executeop key("
    << op->getAttrOfType<mlir::IntegerAttr>("op_key").getInt() << ") cost("
    << op->getAttrOfType<mlir::IntegerAttr>("_tfrt_cost").getInt()
    << ") device(" << op->getAttr("device") << ") " << op->getAttr("op_name")
    << '(' << op.operands() << ')';

  fallback_common::PrintExecuteOpCommon(p, op);
  fallback_common::PrintExecuteOpFuncAttribute(p, op);
  if (!op.results().empty()) p << " : " << op.results().size();
}
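
// For illustration only (made-up names and attributes), the printer above
// produces text along the lines of:
//
//   %out = tfrt_fallback_async.executeop key(0) cost(100) device("/CPU:0")
//     "tf.Relu"(%in) {T = f32} : 1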

static void print(OpAsmPrinter &p, ExecuteOpSeq op) {
  p << "tfrt_fallback_async.executeop.seq(" << op.in_op_chain() << ") key("
    << op->getAttrOfType<mlir::IntegerAttr>("op_key").getInt() << ") cost("
    << op->getAttrOfType<mlir::IntegerAttr>("_tfrt_cost").getInt()
    << ") device(" << op->getAttr("device") << ") " << op->getAttr("op_name")
    << '(' << op.operands() << ')';

  fallback_common::PrintExecuteOpCommon(p, op);
  fallback_common::PrintExecuteOpFuncAttribute(p, op);
  if (!op.results().empty()) p << " : " << op.results().size();
}

static void print(OpAsmPrinter &p, BatchFunctionOp op) {
  p << "tfrt_fallback_async.batch_function(" << op.in_op_chain() << ") "
    << op->getAttr("f") << " (" << op.operands() << ") ";

  fallback_common::PrintExecuteOpCommon(p, op);
  if (!op.results().empty()) p << " : " << op.results().size();
}
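
// For illustration only (made-up symbol and operand names), batch_function is
// printed roughly as:
//
//   %out_ch, %res = tfrt_fallback_async.batch_function(%in_ch) @batched_func
//     (%th) : 1
//
// where the trailing `: 1` is the number of tensorhandle results, excluding
// the output chain.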

void ExecuteOp::getOpAttrs(
    SmallVectorImpl<std::pair<StringRef, Attribute>> *op_attrs) {
  fallback_common::GetExecuteOpAttrsCommon(
      this->getContext(), this->op_attrs().getValue(), op_attrs);
}

//===----------------------------------------------------------------------===//
// ConstDenseTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstDenseTensorOp::fold(ArrayRef<Attribute> operands) {
  return value();
}
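
// Note on the fold above: returning the backing dense attribute marks the op
// as a constant for MLIR's folding infrastructure, so it can, for example, be
// matched with mlir::m_Constant in patterns.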

//===----------------------------------------------------------------------===//
// CoreRTTensorHandleToFallbackTensorOp
//===----------------------------------------------------------------------===//

namespace {

// Simplifies a pattern consisting of a corert const tensor op followed by a
// `tfrt_fallback_async.corert_tensorhandle_to_fallback_tensor` op into a
// single tfrt_fallback_async const tensor op.
struct ConstCoreRTTensorHandleToFallbackTensorCanonicalization
    : public OpRewritePattern<CoreRTTensorHandleToFallbackTensorOp> {
  using OpRewritePattern<
      CoreRTTensorHandleToFallbackTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CoreRTTensorHandleToFallbackTensorOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 1> new_values;
    bool should_rewrite = false;
    for (auto operand : op.operands()) {
      if (auto corert_const_dense_tensor_op =
              operand.getDefiningOp<corert::ConstDenseTensorOp>()) {
        new_values.push_back(
            rewriter.create<fallback_async::ConstDenseTensorOp>(
                op.getLoc(), rewriter.getType<fallback::TFTensorType>(),
                corert_const_dense_tensor_op.value()));
        should_rewrite = true;
        continue;
      }
      if (auto corert_const_string_tensor_op =
              operand.getDefiningOp<corert::ConstStringTensorOp>()) {
        new_values.push_back(
            rewriter.create<fallback_async::ConstStringTensorOp>(
                op.getLoc(), rewriter.getType<fallback::TFTensorType>(),
                corert_const_string_tensor_op.shape(),
                corert_const_string_tensor_op.value()));
        should_rewrite = true;
        continue;
      }
      // To guarantee that the new values are in the same order as the old
      // ones, we create individual ops for the non-canonicalizable operands.
      // For simplicity, we don't consolidate these ops when all the
      // non-canonicalizable operands are adjacent.
      new_values.push_back(
          rewriter
              .create<fallback_async::CoreRTTensorHandleToFallbackTensorOp>(
                  op.getLoc(), rewriter.getType<fallback::TFTensorType>(),
                  operand, op->getAttrOfType<mlir::StringAttr>("device"))
              .getResult(0));
    }

    if (!should_rewrite) return failure();
    rewriter.replaceOp(op, new_values);
    return success();
  }
};
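
// For illustration only (the const op mnemonics shown are approximate), the
// pattern above rewrites IR of the form
//
//   %th = corert.const_dense_tensor dense<1.0> : tensor<f32>
//   %t  = tfrt_fallback_async.corert_tensorhandle_to_fallback_tensor %th
//
// into a single fallback constant
//
//   %t = tfrt_fallback_async.const_dense_tensor dense<1.0> : tensor<f32>
//
// Non-constant operands keep individual conversion ops so that the result
// order of the original op is preserved.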

// Removes the following double tensor conversion:
//  %1 = tfrt_fallback_async.fallback_tensor_to_corert_tensorhandle %0
//  %2 = tfrt_fallback_async.corert_tensorhandle_to_fallback_tensor %1
struct RemoveDoubleTensorConversion
    : mlir::OpRewritePattern<CoreRTTensorHandleToFallbackTensorOp> {
  using OpRewritePattern<
      CoreRTTensorHandleToFallbackTensorOp>::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      CoreRTTensorHandleToFallbackTensorOp op,
      mlir::PatternRewriter &rewriter) const override {
    // Currently this only handles conversion ops with a single value, which
    // covers most cases in practice.
    if (op.getNumOperands() > 1) return mlir::failure();

    auto def =
        op.getOperand(0).getDefiningOp<FallbackTensorToCoreRTTensorHandleOp>();
    if (!def) return mlir::failure();

    if (def.getNumResults() > 1) return mlir::failure();

    rewriter.replaceOp(op, def.getOperand(0));

    return mlir::success();
  }
};
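
// In the example above, uses of %2 are replaced directly with %0; the
// corert_tensorhandle_to_fallback_tensor op is erased, and the remaining
// conversion op becomes dead code if %1 has no other users.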

}  // namespace

void CoreRTTensorHandleToFallbackTensorOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<ConstCoreRTTensorHandleToFallbackTensorCanonicalization,
                 RemoveDoubleTensorConversion>(context);
}

}  // namespace fallback_async
}  // namespace tfrt

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.cpp.inc"