• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_RUNTIME_FALLBACK_RUNTIME_FALLBACK_EXECUTOR_H_
17 #define TENSORFLOW_COMPILER_MLIR_TFRT_RUNTIME_FALLBACK_RUNTIME_FALLBACK_EXECUTOR_H_
18 
19 #include <memory>
20 
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/platform/threadpool_interface.h"
26 #include "tfrt/bef/bef_buffer.h"  // from @tf_runtime
27 #include "tfrt/bef_executor/bef_file.h"  // from @tf_runtime
28 #include "tfrt/host_context/execution_context.h"  // from @tf_runtime
29 #include "tfrt/host_context/host_context.h"  // from @tf_runtime
30 #include "tfrt/host_context/resource_context.h"  // from @tf_runtime
31 #include "tfrt/support/ref_count.h"  // from @tf_runtime
32 
33 namespace tensorflow {
34 
35 class RuntimeFallbackExecutor {
36  public:
37   explicit RuntimeFallbackExecutor(int64_t num_threads);
38 
39   // Prepare() needs to be called once before calling Execute(). It sets up all
40   // things necessary to execute the given 'mlir_input' with the fallback to
41   // tensorflow.
42   void Prepare(llvm::StringRef mlir_input);
43 
44   // Execute() can be called several times after the call to Prepare() (e.g. for
45   // benchmarking).
46   llvm::SmallVector<Tensor> Execute(llvm::StringRef function_name,
47                                     llvm::ArrayRef<Tensor> arguments);
48 
49  private:
50   void RunTfrtInitializer();
51 
52   std::unique_ptr<thread::ThreadPoolInterface> intra_op_;
53   std::unique_ptr<tfrt::HostContext> host_context_;
54   tfrt::ResourceContext resource_context_;
55   std::unique_ptr<tfrt::ExecutionContext> exec_ctx_;
56   tfrt::BefBuffer bef_buffer_;
57   tfrt::RCReference<tfrt::BEFFile> bef_file_;
58 };
59 
60 }  // namespace tensorflow
61 
#endif  // TENSORFLOW_COMPILER_MLIR_TFRT_RUNTIME_FALLBACK_RUNTIME_FALLBACK_EXECUTOR_H_
63