1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_RUNTIME_FALLBACK_RUNTIME_FALLBACK_EXECUTOR_H_ 17 #define TENSORFLOW_COMPILER_MLIR_TFRT_RUNTIME_FALLBACK_RUNTIME_FALLBACK_EXECUTOR_H_ 18 19 #include <memory> 20 21 #include "llvm/ADT/ArrayRef.h" 22 #include "llvm/ADT/SmallVector.h" 23 #include "llvm/ADT/StringRef.h" 24 #include "tensorflow/core/framework/tensor.h" 25 #include "tensorflow/core/platform/threadpool_interface.h" 26 #include "tfrt/bef/bef_buffer.h" // from @tf_runtime 27 #include "tfrt/bef_executor/bef_file.h" // from @tf_runtime 28 #include "tfrt/host_context/execution_context.h" // from @tf_runtime 29 #include "tfrt/host_context/host_context.h" // from @tf_runtime 30 #include "tfrt/host_context/resource_context.h" // from @tf_runtime 31 #include "tfrt/support/ref_count.h" // from @tf_runtime 32 33 namespace tensorflow { 34 35 class RuntimeFallbackExecutor { 36 public: 37 explicit RuntimeFallbackExecutor(int64_t num_threads); 38 39 // Prepare() needs to be called once before calling Execute(). It sets up all 40 // things necessary to execute the given 'mlir_input' with the fallback to 41 // tensorflow. 42 void Prepare(llvm::StringRef mlir_input); 43 44 // Execute() can be called several times after the call to Prepare() (e.g. for 45 // benchmarking). 46 llvm::SmallVector<Tensor> Execute(llvm::StringRef function_name, 47 llvm::ArrayRef<Tensor> arguments); 48 49 private: 50 void RunTfrtInitializer(); 51 52 std::unique_ptr<thread::ThreadPoolInterface> intra_op_; 53 std::unique_ptr<tfrt::HostContext> host_context_; 54 tfrt::ResourceContext resource_context_; 55 std::unique_ptr<tfrt::ExecutionContext> exec_ctx_; 56 tfrt::BefBuffer bef_buffer_; 57 tfrt::RCReference<tfrt::BEFFile> bef_file_; 58 }; 59 60 } // namespace tensorflow 61 62 #endif // THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_TFRT_RUNTIME_FALLBACK_RUNTIME_FALLBACK_EXECUTOR_H_ 63