1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.h"
16
17 #include "llvm/ADT/STLExtras.h"
18 #include "mlir/IR/Attributes.h" // from @llvm-project
19 #include "mlir/IR/Builders.h" // from @llvm-project
20 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
21 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
22 #include "mlir/IR/DialectImplementation.h" // from @llvm-project
23 #include "mlir/IR/Matchers.h" // from @llvm-project
24 #include "mlir/IR/OpDefinition.h" // from @llvm-project
25 #include "mlir/IR/OpImplementation.h" // from @llvm-project
26 #include "mlir/IR/OperationSupport.h" // from @llvm-project
27 #include "mlir/IR/PatternMatch.h" // from @llvm-project
28 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
29 #include "mlir/Support/LogicalResult.h" // from @llvm-project
30 #include "mlir/Transforms/InliningUtils.h" // from @llvm-project
31 #include "tensorflow/core/runtime_fallback/opdefs/tfrt_fallback.h"
32 #include "tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_common.h"
33 #include "tfrt/basic_kernels/opdefs/tfrt_base.h" // from @tf_runtime
34 #include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime
35 #include "tfrt/core_runtime/opdefs/attributes.h" // from @tf_runtime
36 #include "tfrt/core_runtime/opdefs/core_runtime.h" // from @tf_runtime
37 #include "tfrt/core_runtime/opdefs/sync/core_runtime.h" // from @tf_runtime
38 #include "tfrt/core_runtime/opdefs/types.h" // from @tf_runtime
39
40 namespace tfrt {
41 namespace fallback_async {
42
43 namespace {
44
45 struct FallbackInlinerInterface : public mlir::DialectInlinerInterface {
46 using DialectInlinerInterface::DialectInlinerInterface;
47
isLegalToInlinetfrt::fallback_async::__anon300ef72b0111::FallbackInlinerInterface48 bool isLegalToInline(Operation *op, Region *dest, bool would_be_cloned,
49 BlockAndValueMapping &) const final {
50 return true;
51 }
52 };
53
54 } // namespace
55
FallbackAsyncDialect(MLIRContext * context)56 FallbackAsyncDialect::FallbackAsyncDialect(MLIRContext *context)
57 : Dialect(/*name=*/"tfrt_fallback_async", context,
58 TypeID::get<FallbackAsyncDialect>()) {
59 context->getOrLoadDialect<tfrt::fallback::FallbackDialect>();
60 context->getOrLoadDialect<compiler::TFRTDialect>();
61 context->getOrLoadDialect<corert::CoreRTDialect>();
62
63 allowUnknownTypes();
64
65 allowUnknownOperations();
66
67 addInterfaces<FallbackInlinerInterface>();
68
69 addOperations<
70 #define GET_OP_LIST
71 #include "tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.cpp.inc"
72 >();
73 }
74
GetChainType(Builder * builder)75 static Type GetChainType(Builder *builder) {
76 return builder->getType<compiler::ChainType>();
77 }
78
verify(CreateOp op)79 static LogicalResult verify(CreateOp op) {
80 return fallback_common::VerifyFallbackExecuteOp(op);
81 }
verify(ExecuteOp op)82 static LogicalResult verify(ExecuteOp op) {
83 return fallback_common::VerifyFallbackExecuteOp(op);
84 }
verify(ExecuteOpSeq op)85 static LogicalResult verify(ExecuteOpSeq op) {
86 return fallback_common::VerifyFallbackExecuteOp(op);
87 }
verify(BatchFunctionOp op)88 static LogicalResult verify(BatchFunctionOp op) {
89 return fallback_common::VerifyExecuteOpCommon(op);
90 }
91
parseCreateOp(OpAsmParser & parser,OperationState & result)92 static ParseResult parseCreateOp(OpAsmParser &parser, OperationState &result) {
93 fallback_common::ParseExecuteOpOptions parse_options;
94 parse_options.has_chain = true;
95 parse_options.has_key = true;
96 parse_options.has_device = true;
97 parse_options.has_func_attr = true;
98 parse_options.has_cost = false;
99
100 auto &builder = parser.getBuilder();
101 if (mlir::failed(fallback_common::ParseExecuteOpCommon(
102 parser, builder, result, builder.getType<fallback::TFTensorType>(),
103 parse_options)))
104 return mlir::failure();
105
106 mlir::IntegerAttr num_args;
107 if (parser.parseKeyword("num_args") || parser.parseLParen() ||
108 parser.parseAttribute(num_args, "num_args", result.attributes) ||
109 parser.parseRParen())
110 return mlir::failure();
111
112 return mlir::success();
113 }
parseExecuteOp(OpAsmParser & parser,OperationState & result)114 static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
115 fallback_common::ParseExecuteOpOptions parse_options;
116 parse_options.has_chain = false;
117 parse_options.has_key = true;
118 parse_options.has_device = true;
119 parse_options.has_func_attr = true;
120 parse_options.has_cost = true;
121
122 auto &builder = parser.getBuilder();
123 return fallback_common::ParseExecuteOpCommon(
124 parser, builder, result, builder.getType<fallback::TFTensorType>(),
125 parse_options);
126 }
parseExecuteOpSeq(OpAsmParser & parser,OperationState & result)127 static ParseResult parseExecuteOpSeq(OpAsmParser &parser,
128 OperationState &result) {
129 fallback_common::ParseExecuteOpOptions parse_options;
130 parse_options.has_chain = true;
131 parse_options.has_key = true;
132 parse_options.has_device = true;
133 parse_options.has_func_attr = true;
134 parse_options.has_cost = true;
135
136 auto &builder = parser.getBuilder();
137 return fallback_common::ParseExecuteOpCommon(
138 parser, builder, result, builder.getType<fallback::TFTensorType>(),
139 parse_options);
140 }
141
parseBatchFunctionOp(OpAsmParser & parser,OperationState & result)142 static ParseResult parseBatchFunctionOp(OpAsmParser &parser,
143 OperationState &result) {
144 auto &builder = parser.getBuilder();
145 auto chain_type = GetChainType(&builder);
146 auto tensorhandle_type = builder.getType<corert::TensorHandleType>();
147
148 FlatSymbolRefAttr f;
149 SmallVector<OpAsmParser::OperandType, 4> in_chains;
150 SmallVector<OpAsmParser::OperandType, 4> operands;
151 NamedAttrList op_attrs;
152 auto loc = parser.getNameLoc();
153
154 if (parser.parseOperandList(in_chains,
155 /*requiredOperandCount=*/1,
156 OpAsmParser::Delimiter::Paren))
157 return failure();
158
159 if (parser.parseAttribute(f, "f", result.attributes) ||
160 parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
161 parser.parseOptionalAttrDict(op_attrs))
162 return failure();
163
164 int64_t num_results = 0;
165 if (succeeded(parser.parseOptionalColon())) {
166 IntegerAttr attr;
167 mlir::NamedAttrList attrs;
168 if (failed(parser.parseAttribute(attr, "num_results", attrs)))
169 return failure();
170 num_results = attr.getValue().getSExtValue();
171 }
172
173 SmallVector<Type, 4> operand_types;
174 operand_types.push_back(chain_type);
175 if (parser.resolveOperands(in_chains, operand_types, loc, result.operands) ||
176 parser.resolveOperands(operands, tensorhandle_type, result.operands))
177 return failure();
178
179 result.types.push_back(chain_type);
180 result.types.append(num_results, tensorhandle_type);
181
182 SmallVector<Attribute, 4> op_attr_array;
183 for (const auto &key_value : op_attrs) {
184 auto key = builder.getStringAttr(key_value.first.strref());
185 auto value = key_value.second;
186 op_attr_array.push_back(builder.getArrayAttr({key, value}));
187 }
188
189 result.attributes.push_back(
190 builder.getNamedAttr("op_attrs", builder.getArrayAttr(op_attr_array)));
191
192 return success();
193 }
194
print(OpAsmPrinter & p,CreateOp op)195 static void print(OpAsmPrinter &p, CreateOp op) {
196 p << "tfrt_fallback_async.createop(" << op.in_ch() << ") key("
197 << op->getAttrOfType<mlir::IntegerAttr>("op_key").getInt() << ") device("
198 << op->getAttr("device") << ") " << op->getAttr("op_name") << "()";
199
200 fallback_common::PrintExecuteOpCommon(p, op);
201 fallback_common::PrintExecuteOpFuncAttribute(p, op);
202
203 p << " num_args(" << op->getAttrOfType<mlir::IntegerAttr>("num_args").getInt()
204 << ')';
205 }
206
print(OpAsmPrinter & p,ExecuteOp op)207 static void print(OpAsmPrinter &p, ExecuteOp op) {
208 p << "tfrt_fallback_async.executeop key("
209 << op->getAttrOfType<mlir::IntegerAttr>("op_key").getInt() << ") cost("
210 << op->getAttrOfType<mlir::IntegerAttr>("_tfrt_cost").getInt()
211 << ") device(" << op->getAttr("device") << ") " << op->getAttr("op_name")
212 << '(' << op.operands() << ')';
213
214 fallback_common::PrintExecuteOpCommon(p, op);
215 fallback_common::PrintExecuteOpFuncAttribute(p, op);
216 if (!op.results().empty()) p << " : " << op.results().size();
217 }
218
print(OpAsmPrinter & p,ExecuteOpSeq op)219 static void print(OpAsmPrinter &p, ExecuteOpSeq op) {
220 p << "tfrt_fallback_async.executeop.seq(" << op.in_op_chain() << ") key("
221 << op->getAttrOfType<mlir::IntegerAttr>("op_key").getInt() << ") cost("
222 << op->getAttrOfType<mlir::IntegerAttr>("_tfrt_cost").getInt()
223 << ") device(" << op->getAttr("device") << ") " << op->getAttr("op_name")
224 << '(' << op.operands() << ')';
225
226 fallback_common::PrintExecuteOpCommon(p, op);
227 fallback_common::PrintExecuteOpFuncAttribute(p, op);
228 if (!op.results().empty()) p << " : " << op.results().size();
229 }
230
print(OpAsmPrinter & p,BatchFunctionOp op)231 static void print(OpAsmPrinter &p, BatchFunctionOp op) {
232 p << "tfrt_fallback_async.batch_function(" << op.in_op_chain() << ") "
233 << op->getAttr("f") << " (" << op.operands() << ") ";
234
235 fallback_common::PrintExecuteOpCommon(p, op);
236 if (!op.results().empty()) p << " : " << op.results().size();
237 }
238
getOpAttrs(SmallVectorImpl<std::pair<StringRef,Attribute>> * op_attrs)239 void ExecuteOp::getOpAttrs(
240 SmallVectorImpl<std::pair<StringRef, Attribute>> *op_attrs) {
241 fallback_common::GetExecuteOpAttrsCommon(
242 this->getContext(), this->op_attrs().getValue(), op_attrs);
243 }
244
245 //===----------------------------------------------------------------------===//
246 // ConstDenseTensorOp
247 //===----------------------------------------------------------------------===//
248
fold(ArrayRef<Attribute> operands)249 OpFoldResult ConstDenseTensorOp::fold(ArrayRef<Attribute> operands) {
250 return value();
251 }
252
253 //===----------------------------------------------------------------------===//
254 // CoreRTTensorHandleToFallbackTensorOp
255 //===----------------------------------------------------------------------===//
256
257 namespace {
258
259 // Simplifies pattern containing a corert const tensor op followed by a
260 // `tfrt_fallback_async.corert_tensorhandle_to_fallback_tensor` op to a single
261 // tfrt_fallback_async const tensor.
262 struct ConstCoreRTTensorHandleToFallbackTensorCanonicalization
263 : public OpRewritePattern<CoreRTTensorHandleToFallbackTensorOp> {
264 using OpRewritePattern<
265 CoreRTTensorHandleToFallbackTensorOp>::OpRewritePattern;
266
matchAndRewritetfrt::fallback_async::__anon300ef72b0211::ConstCoreRTTensorHandleToFallbackTensorCanonicalization267 LogicalResult matchAndRewrite(CoreRTTensorHandleToFallbackTensorOp op,
268 PatternRewriter &rewriter) const override {
269 SmallVector<Value, 1> new_values;
270 bool should_rewrite = false;
271 for (auto operand : op.operands()) {
272 if (auto corert_const_dense_tensor_op =
273 operand.getDefiningOp<corert::ConstDenseTensorOp>()) {
274 new_values.push_back(
275 rewriter.create<fallback_async::ConstDenseTensorOp>(
276 op.getLoc(), rewriter.getType<fallback::TFTensorType>(),
277 corert_const_dense_tensor_op.value()));
278 should_rewrite = true;
279 continue;
280 }
281 if (auto corert_const_string_tensor_op =
282 operand.getDefiningOp<corert::ConstStringTensorOp>()) {
283 new_values.push_back(
284 rewriter.create<fallback_async::ConstStringTensorOp>(
285 op.getLoc(), rewriter.getType<fallback::TFTensorType>(),
286 corert_const_string_tensor_op.shape(),
287 corert_const_string_tensor_op.value()));
288 should_rewrite = true;
289 continue;
290 }
291 // To guarantee that the new values are in the same order as the old
292 // ones, we create individual ops for the non-canonicalizable operands.
293 // For simplicity, we don't consolidate these ops when all the
294 // non-canonicalizable operands are adjacent.
295 new_values.push_back(
296 rewriter
297 .create<fallback_async::CoreRTTensorHandleToFallbackTensorOp>(
298 op.getLoc(), rewriter.getType<fallback::TFTensorType>(),
299 operand, op->getAttrOfType<mlir::StringAttr>("device"))
300 .getResult(0));
301 }
302
303 if (!should_rewrite) return failure();
304 rewriter.replaceOp(op, new_values);
305 return success();
306 }
307 };
308
309 // Removes the following double tensor conversion:
310 // %1 = tfrt_fallback_async.fallback_tensor_to_corert_tensorhandle %0
311 // %2 = tfrt_fallback_async.corert_tensorhandle_to_fallback_tensor %1
312 struct RemoveDoubleTensorConversion
313 : mlir::OpRewritePattern<CoreRTTensorHandleToFallbackTensorOp> {
314 using OpRewritePattern<
315 CoreRTTensorHandleToFallbackTensorOp>::OpRewritePattern;
316
matchAndRewritetfrt::fallback_async::__anon300ef72b0211::RemoveDoubleTensorConversion317 mlir::LogicalResult matchAndRewrite(
318 CoreRTTensorHandleToFallbackTensorOp op,
319 mlir::PatternRewriter &rewriter) const override {
320 // Currently only handles the case where there is only one value in the
321 // conversion op. This should be enough for most of the cases.
322 if (op.getNumOperands() > 1) return mlir::failure();
323
324 auto def =
325 op.getOperand(0).getDefiningOp<FallbackTensorToCoreRTTensorHandleOp>();
326 if (!def) return mlir::failure();
327
328 if (def.getNumResults() > 1) return mlir::failure();
329
330 rewriter.replaceOp(op, def.getOperand(0));
331
332 return mlir::success();
333 }
334 };
335
336 } // namespace
337
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)338 void CoreRTTensorHandleToFallbackTensorOp::getCanonicalizationPatterns(
339 OwningRewritePatternList &results, MLIRContext *context) {
340 results.insert<ConstCoreRTTensorHandleToFallbackTensorCanonicalization,
341 RemoveDoubleTensorConversion>(context);
342 }
343
344 } // namespace fallback_async
345 } // namespace tfrt
346
347 //===----------------------------------------------------------------------===//
348 // TableGen'd op method definitions
349 //===----------------------------------------------------------------------===//
350
351 #define GET_OP_CLASSES
352 #include "tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.cpp.inc"
353