/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h"

#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/runtime_fallback/opdefs/tfrt_fallback.h"
#include "tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.h"
#include "tfrt/basic_kernels/opdefs/types.h"  // from @tf_runtime
#include "tfrt/core_runtime/opdefs/types.h"  // from @tf_runtime

namespace tensorflow {
namespace tfrt_compiler {
namespace {
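// Default device used when a value's defining op does not carry an explicit
// `device` attribute.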
constexpr char kCpuDeviceName[] =
    "/job:localhost/replica:0/task:0/device:CPU:0";
}

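// Registers the type conversions used by this converter: chain and fallback
// tensor types pass through unchanged, tensor types (except ref types) convert
// to !tfrt_fallback.tf_tensor, and i1 is the only other type kept as-is.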
FallbackConverter::FallbackConverter(mlir::MLIRContext *context)
    : builder_(context) {
  addConversion([](tfrt::compiler::ChainType type) { return type; });
  addConversion([](tfrt::fallback::TFTensorType type) { return type; });
  addConversion([=](mlir::TensorType type) -> llvm::Optional<mlir::Type> {
    // Ref types are not supported by either the compiler or the runtime.
    if (type.getElementType().isa<mlir::TF::TensorFlowRefType>()) {
      return llvm::None;
    }

    return builder_.getType<tfrt::fallback::TFTensorType>();
  });
  addConversion([=](mlir::Type type) -> llvm::Optional<mlir::Type> {
    if (type == builder_.getI1Type()) return type;
    return llvm::None;
  });
}

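// Converts `value` to a !tfrt_fallback.tf_tensor on `device`. Returns `value`
// unchanged if it is already a fallback tensor, and a null value if it is not
// a !corert.tensorhandle. The conversion op is inserted right after the
// defining op, or at the start of the parent block for block arguments.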
mlir::Value ConvertCoreRTTensorHandleToFallbackTensor(
    mlir::Location loc, llvm::StringRef device, mlir::Value value,
    mlir::ConversionPatternRewriter &rewriter) {
  if (value.getType().isa<tfrt::fallback::TFTensorType>()) return value;

  if (!value.getType().isa<tfrt::corert::TensorHandleType>()) return {};

  mlir::OpBuilder::InsertionGuard guard(rewriter);

  auto *def = value.getDefiningOp();
  if (def) {
    rewriter.setInsertionPointAfter(def);
  } else {
    rewriter.setInsertionPointToStart(value.getParentBlock());
  }

  return rewriter
      .create<tfrt::fallback_async::CoreRTTensorHandleToFallbackTensorOp>(
          loc, rewriter.getType<tfrt::fallback::TFTensorType>(), value, device)
      .getResult(0);
}

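// Converts `value` to a !corert.tensorhandle. Returns `value` unchanged if it
// is already a tensorhandle, and a null value if it is not a
// !tfrt_fallback.tf_tensor. The target device is taken from the defining op's
// `device` attribute, falling back to the local CPU device.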
mlir::Value ConvertFallbackTensorToCoreRTTensorHandle(
    mlir::Location loc, mlir::Value value,
    mlir::ConversionPatternRewriter &rewriter) {
  if (value.getType().isa<tfrt::corert::TensorHandleType>()) return value;

  if (!value.getType().isa<tfrt::fallback::TFTensorType>()) return {};

  // Use CPU device by default if no device is specified.
  std::string device = kCpuDeviceName;
  if (auto *def = value.getDefiningOp()) {
    if (auto device_attr = def->getAttrOfType<mlir::StringAttr>("device")) {
      device = device_attr.getValue().str();
    }
  }

  return rewriter
      .create<tfrt::fallback_async::FallbackTensorToCoreRTTensorHandleOp>(
          loc, rewriter.getType<tfrt::corert::TensorHandleType>(), value,
          device)
      .getResult(0);
}

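// Converts each of `operands` to a !corert.tensorhandle and appends the
// results to `new_operands`. Conversion ops are inserted immediately before
// `op`. Emits a warning and returns failure if any operand cannot be
// converted.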
mlir::LogicalResult ConvertCoreRTOperands(
    mlir::Operation *op, llvm::ArrayRef<mlir::Value> operands,
    llvm::SmallVectorImpl<mlir::Value> *new_operands,
    mlir::ConversionPatternRewriter &rewriter) {
  mlir::OpBuilder::InsertionGuard guard(rewriter);
  // Insert before the current op.
  rewriter.setInsertionPoint(op);

  for (auto operand : operands) {
    auto value = ConvertFallbackTensorToCoreRTTensorHandle(op->getLoc(),
                                                           operand, rewriter);
    if (!value) {
      return op->emitWarning("failed to convert to !corert.tensorhandle")
             << operand.getType();
    }

    new_operands->push_back(value);
  }
  return success();
}

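// Converts each of `operands` to a !tfrt_fallback.tf_tensor on `device` and
// appends the results to `new_operands`. Operands that are already fallback
// tensors are passed through unchanged. Emits a warning and returns failure
// if any operand cannot be converted.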
mlir::LogicalResult ConvertFallbackOperands(
    mlir::Operation *op, llvm::StringRef device,
    llvm::ArrayRef<mlir::Value> operands,
    llvm::SmallVectorImpl<mlir::Value> *new_operands,
    mlir::ConversionPatternRewriter &rewriter) {
  for (auto operand : operands) {
    if (!operand.getType().isa<tfrt::fallback::TFTensorType>()) {
      auto new_operand = ConvertCoreRTTensorHandleToFallbackTensor(
          op->getLoc(), device, operand, rewriter);
      if (!new_operand)
        return op->emitWarning(
            "failed to convert the operand to fallback tensor.");
      new_operands->push_back(new_operand);
    } else {
      new_operands->push_back(operand);
    }
  }
  return success();
}

}  // namespace tfrt_compiler
}  // namespace tensorflow