• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_FALLBACK_CONVERTER_H_
16 #define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_FALLBACK_CONVERTER_H_
17 
18 #include "mlir/IR/Operation.h"  // from @llvm-project
19 #include "mlir/IR/Value.h"  // from @llvm-project
20 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
21 
22 namespace tensorflow {
23 namespace tfrt_compiler {
24 
GetDefaultCpuDeviceName()25 inline llvm::StringRef GetDefaultCpuDeviceName() {
26   static constexpr char kCpuDeviceName[] =
27       "/job:localhost/replica:0/task:0/device:CPU:0";
28   return kCpuDeviceName;
29 }
30 
31 class FallbackConverter : public mlir::TypeConverter {
32  public:
33   explicit FallbackConverter(mlir::MLIRContext *context);
34 
35   // Return the next dense key for fallback ops. The key is simply an array
36   // index so that in runtime, the fallback ops can be efficiently retrieved.
GetNextFallbackKey()37   int64_t GetNextFallbackKey() const { return fallback_ops_.size(); }
38 
RegisterFallbackOp(mlir::Operation * op)39   void RegisterFallbackOp(mlir::Operation *op) { fallback_ops_.push_back(op); }
40 
ReplaceFallbackOp(int64_t key,mlir::Operation * op)41   void ReplaceFallbackOp(int64_t key, mlir::Operation *op) {
42     fallback_ops_[key] = op;
43   }
44 
GetFallbackOps()45   llvm::ArrayRef<mlir::Operation *> GetFallbackOps() const {
46     return fallback_ops_;
47   }
48 
49  private:
50   mlir::Builder builder_;
51   // Using a vector to keep fallback ops in order, and the key for a fallback op
52   // is its corresponding index here.
53   llvm::SmallVector<mlir::Operation *, 8> fallback_ops_;
54 };
55 
56 // Convert the `value` that is a !corert.tensorhandle to
57 // !tfrt_fallback.tf_tensor. If needed, tensor conversion kernels will be added.
58 // On error it returns nullptr.
59 mlir::Value ConvertCoreRTTensorHandleToFallbackTensor(
60     mlir::Location loc, llvm::StringRef device, mlir::Value value,
61     mlir::ConversionPatternRewriter &rewriter);
62 
63 // Convert the `value` that is a !tfrt_fallback.tf_tensor to
64 // !corert.tensorhandle. If needed, tensor conversion kernels will be added. On
65 // error it returns nullptr.
66 mlir::Value ConvertFallbackTensorToCoreRTTensorHandle(
67     mlir::Location loc, mlir::Value value,
68     mlir::ConversionPatternRewriter &rewriter);
69 
70 // Convert operands that might be !tfrt_fallback.tf_tensor for corert operations
71 // that take only !corert.tensorhandle.
72 mlir::LogicalResult ConvertCoreRTOperands(
73     mlir::Operation *op, mlir::ValueRange operands,
74     llvm::SmallVectorImpl<mlir::Value> *new_operands,
75     mlir::ConversionPatternRewriter &rewriter);
76 
77 // Convert operands that might be !corert.tensorhandle for fallback operations
78 // that take only !tfrt_fallback.tf_tensor.
79 mlir::LogicalResult ConvertFallbackOperands(
80     mlir::Operation *op, llvm::StringRef device, mlir::ValueRange operands,
81     llvm::SmallVectorImpl<mlir::Value> *new_operands,
82     mlir::ConversionPatternRewriter &rewriter);
83 
84 }  // namespace tfrt_compiler
85 }  // namespace tensorflow
86 
87 #endif  // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_FALLBACK_CONVERTER_H_
88