• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_FALLBACK_CONVERTER_H_
16 #define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_FALLBACK_CONVERTER_H_
17 
18 #include "mlir/IR/Operation.h"  // from @llvm-project
19 #include "mlir/IR/Value.h"  // from @llvm-project
20 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
21 
22 namespace tensorflow {
23 namespace tfrt_compiler {
24 
25 class FallbackConverter : public mlir::TypeConverter {
26  public:
27   explicit FallbackConverter(mlir::MLIRContext *context);
28 
29   // Return the next dense key for fallback ops. The key is simply an array
30   // index so that in runtime, the fallback ops can be efficiently retrieved.
GetNextFallbackKey()31   int64_t GetNextFallbackKey() const { return fallback_ops_.size(); }
32 
RegisterFallbackOp(mlir::Operation * op)33   void RegisterFallbackOp(mlir::Operation *op) { fallback_ops_.push_back(op); }
34 
GetFallbackOps()35   llvm::ArrayRef<mlir::Operation *> GetFallbackOps() const {
36     return fallback_ops_;
37   }
38 
39  private:
40   mlir::Builder builder_;
41   // Using a vector to keep fallback ops in order, and the key for a fallback op
42   // is its corresponding index here.
43   llvm::SmallVector<mlir::Operation *, 8> fallback_ops_;
44 };
45 
46 // Convert the `value` that is a !corert.tensorhandle to
47 // !tfrt_fallback.tf_tensor. If needed, tensor conversion kernels will be added.
48 // On error it returns nullptr.
49 mlir::Value ConvertCoreRTTensorHandleToFallbackTensor(
50     mlir::Location loc, llvm::StringRef device, mlir::Value value,
51     mlir::ConversionPatternRewriter &rewriter);
52 
53 // Convert the `value` that is a !tfrt_fallback.tf_tensor to
54 // !corert.tensorhandle. If needed, tensor conversion kernels will be added. On
55 // error it returns nullptr.
56 mlir::Value ConvertFallbackTensorToCoreRTTensorHandle(
57     mlir::Location loc, mlir::Value value,
58     mlir::ConversionPatternRewriter &rewriter);
59 
60 // Convert operands that might be !tfrt_fallback.tf_tensor for corert operations
61 // that take only !corert.tensorhandle.
62 mlir::LogicalResult ConvertCoreRTOperands(
63     mlir::Operation *op, llvm::ArrayRef<mlir::Value> operands,
64     llvm::SmallVectorImpl<mlir::Value> *new_operands,
65     mlir::ConversionPatternRewriter &rewriter);
66 
67 // Convert operands that might be !corert.tensorhandle for fallback operations
68 // that take only !tfrt_fallback.tf_tensor.
69 mlir::LogicalResult ConvertFallbackOperands(
70     mlir::Operation *op, llvm::StringRef device,
71     llvm::ArrayRef<mlir::Value> operands,
72     llvm::SmallVectorImpl<mlir::Value> *new_operands,
73     mlir::ConversionPatternRewriter &rewriter);
74 
75 }  // namespace tfrt_compiler
76 }  // namespace tensorflow
77 
78 #endif  // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_FALLBACK_CONVERTER_H_
79