1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_FALLBACK_CONVERTER_H_ 16 #define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_FALLBACK_CONVERTER_H_ 17 18 #include "mlir/IR/Operation.h" // from @llvm-project 19 #include "mlir/IR/Value.h" // from @llvm-project 20 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project 21 22 namespace tensorflow { 23 namespace tfrt_compiler { 24 25 class FallbackConverter : public mlir::TypeConverter { 26 public: 27 explicit FallbackConverter(mlir::MLIRContext *context); 28 29 // Return the next dense key for fallback ops. The key is simply an array 30 // index so that in runtime, the fallback ops can be efficiently retrieved. GetNextFallbackKey()31 int64_t GetNextFallbackKey() const { return fallback_ops_.size(); } 32 RegisterFallbackOp(mlir::Operation * op)33 void RegisterFallbackOp(mlir::Operation *op) { fallback_ops_.push_back(op); } 34 GetFallbackOps()35 llvm::ArrayRef<mlir::Operation *> GetFallbackOps() const { 36 return fallback_ops_; 37 } 38 39 private: 40 mlir::Builder builder_; 41 // Using a vector to keep fallback ops in order, and the key for a fallback op 42 // is its corresponding index here. 43 llvm::SmallVector<mlir::Operation *, 8> fallback_ops_; 44 }; 45 46 // Convert the `value` that is a !corert.tensorhandle to 47 // !tfrt_fallback.tf_tensor. If needed, tensor conversion kernels will be added. 48 // On error it returns nullptr. 49 mlir::Value ConvertCoreRTTensorHandleToFallbackTensor( 50 mlir::Location loc, llvm::StringRef device, mlir::Value value, 51 mlir::ConversionPatternRewriter &rewriter); 52 53 // Convert the `value` that is a !tfrt_fallback.tf_tensor to 54 // !corert.tensorhandle. If needed, tensor conversion kernels will be added. On 55 // error it returns nullptr. 56 mlir::Value ConvertFallbackTensorToCoreRTTensorHandle( 57 mlir::Location loc, mlir::Value value, 58 mlir::ConversionPatternRewriter &rewriter); 59 60 // Convert operands that might be !tfrt_fallback.tf_tensor for corert operations 61 // that take only !corert.tensorhandle. 62 mlir::LogicalResult ConvertCoreRTOperands( 63 mlir::Operation *op, llvm::ArrayRef<mlir::Value> operands, 64 llvm::SmallVectorImpl<mlir::Value> *new_operands, 65 mlir::ConversionPatternRewriter &rewriter); 66 67 // Convert operands that might be !corert.tensorhandle for fallback operations 68 // that take only !tfrt_fallback.tf_tensor. 69 mlir::LogicalResult ConvertFallbackOperands( 70 mlir::Operation *op, llvm::StringRef device, 71 llvm::ArrayRef<mlir::Value> operands, 72 llvm::SmallVectorImpl<mlir::Value> *new_operands, 73 mlir::ConversionPatternRewriter &rewriter); 74 75 } // namespace tfrt_compiler 76 } // namespace tensorflow 77 78 #endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_FALLBACK_CONVERTER_H_ 79