1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h"
16
17 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
18 #include "tensorflow/core/runtime_fallback/opdefs/tfrt_fallback.h"
19 #include "tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.h"
20 #include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime
21 #include "tfrt/core_runtime/opdefs/types.h" // from @tf_runtime
22
23 namespace tensorflow {
24 namespace tfrt_compiler {
namespace {

// Device used by ConvertFallbackTensorToCoreRTTensorHandle when the value's
// defining op carries no "device" string attribute (or the value has no
// defining op at all).
constexpr char kCpuDeviceName[] =
    "/job:localhost"
    "/replica:0"
    "/task:0"
    "/device:CPU:0";

}  // namespace
29
FallbackConverter(mlir::MLIRContext * context)30 FallbackConverter::FallbackConverter(mlir::MLIRContext *context)
31 : builder_(context) {
32 addConversion([](tfrt::compiler::ChainType type) { return type; });
33 addConversion([](tfrt::fallback::TFTensorType type) { return type; });
34 addConversion([=](mlir::TensorType type) -> llvm::Optional<mlir::Type> {
35 // Ref types are not supported in both compiler and runtime.
36 if (type.getElementType().isa<mlir::TF::TensorFlowRefType>()) {
37 return llvm::None;
38 }
39
40 return builder_.getType<tfrt::fallback::TFTensorType>();
41 });
42 addConversion([=](mlir::Type type) -> llvm::Optional<mlir::Type> {
43 if (type == builder_.getI1Type()) return type;
44 return llvm::None;
45 });
46 }
47
ConvertCoreRTTensorHandleToFallbackTensor(mlir::Location loc,llvm::StringRef device,mlir::Value value,mlir::ConversionPatternRewriter & rewriter)48 mlir::Value ConvertCoreRTTensorHandleToFallbackTensor(
49 mlir::Location loc, llvm::StringRef device, mlir::Value value,
50 mlir::ConversionPatternRewriter &rewriter) {
51 if (value.getType().isa<tfrt::fallback::TFTensorType>()) return value;
52
53 if (!value.getType().isa<tfrt::corert::TensorHandleType>()) return {};
54
55 mlir::OpBuilder::InsertionGuard guard(rewriter);
56
57 auto *def = value.getDefiningOp();
58 if (def) {
59 rewriter.setInsertionPointAfter(def);
60 } else {
61 rewriter.setInsertionPointToStart(value.getParentBlock());
62 }
63
64 return rewriter
65 .create<tfrt::fallback_async::CoreRTTensorHandleToFallbackTensorOp>(
66 loc, rewriter.getType<tfrt::fallback::TFTensorType>(), value, device)
67 .getResult(0);
68 }
69
ConvertFallbackTensorToCoreRTTensorHandle(mlir::Location loc,mlir::Value value,mlir::ConversionPatternRewriter & rewriter)70 mlir::Value ConvertFallbackTensorToCoreRTTensorHandle(
71 mlir::Location loc, mlir::Value value,
72 mlir::ConversionPatternRewriter &rewriter) {
73 if (value.getType().isa<tfrt::corert::TensorHandleType>()) return value;
74
75 if (!value.getType().isa<tfrt::fallback::TFTensorType>()) return {};
76
77 // Use CPU device by default if no device is specified.
78 std::string device = kCpuDeviceName;
79 if (auto *def = value.getDefiningOp()) {
80 if (auto device_attr = def->getAttrOfType<mlir::StringAttr>("device")) {
81 device = device_attr.getValue().str();
82 }
83 }
84
85 return rewriter
86 .create<tfrt::fallback_async::FallbackTensorToCoreRTTensorHandleOp>(
87 loc, rewriter.getType<tfrt::corert::TensorHandleType>(), value,
88 device)
89 .getResult(0);
90 }
91
ConvertCoreRTOperands(mlir::Operation * op,llvm::ArrayRef<mlir::Value> operands,llvm::SmallVectorImpl<mlir::Value> * new_operands,mlir::ConversionPatternRewriter & rewriter)92 mlir::LogicalResult ConvertCoreRTOperands(
93 mlir::Operation *op, llvm::ArrayRef<mlir::Value> operands,
94 llvm::SmallVectorImpl<mlir::Value> *new_operands,
95 mlir::ConversionPatternRewriter &rewriter) {
96 mlir::OpBuilder::InsertionGuard guard(rewriter);
97 // Insert before the current op.
98 rewriter.setInsertionPoint(op);
99
100 for (auto operand : operands) {
101 auto value = ConvertFallbackTensorToCoreRTTensorHandle(op->getLoc(),
102 operand, rewriter);
103 if (!value) {
104 return op->emitWarning("failed to convert to !corert.tensorhandle")
105 << operand.getType();
106 }
107
108 new_operands->push_back(value);
109 }
110 return success();
111 }
112
ConvertFallbackOperands(mlir::Operation * op,llvm::StringRef device,llvm::ArrayRef<mlir::Value> operands,llvm::SmallVectorImpl<mlir::Value> * new_operands,mlir::ConversionPatternRewriter & rewriter)113 mlir::LogicalResult ConvertFallbackOperands(
114 mlir::Operation *op, llvm::StringRef device,
115 llvm::ArrayRef<mlir::Value> operands,
116 llvm::SmallVectorImpl<mlir::Value> *new_operands,
117 mlir::ConversionPatternRewriter &rewriter) {
118 for (auto operand : operands) {
119 if (!operand.getType().isa<tfrt::fallback::TFTensorType>()) {
120 auto new_operand = ConvertCoreRTTensorHandleToFallbackTensor(
121 op->getLoc(), device, operand, rewriter);
122 if (!new_operand)
123 return op->emitWarning(
124 "failed to convert the operand to fallback tensor.");
125 new_operands->push_back(new_operand);
126 } else {
127 new_operands->push_back(operand);
128 }
129 }
130 return success();
131 }
132
133 } // namespace tfrt_compiler
134 } // namespace tensorflow
135