• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_TFRT_EAGER_TFRT_CONTEXT_H_
16 #define TENSORFLOW_CORE_TFRT_EAGER_TFRT_CONTEXT_H_
17 
#include <functional>
#include <memory>
#include <utility>

#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/core/platform/threadpool_interface.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/tfrt/runtime/tf_threadpool_concurrent_work_queue.h"
#include "tfrt/host_context/resource_context.h"  // from @tf_runtime
26 
27 namespace tensorflow {
28 class EagerContext;
29 class DynamicDeviceMgr;
30 }
31 namespace tfrt {
32 class HostContext;
33 class CoreRuntime;
34 class OpHandler;
35 
36 namespace tf {
37 
38 // Wraps an `Eigen::ThreadPoolInterface` as a
39 // `tensorflow::thread::ThreadPoolInterface`.
40 //
41 // Copied from internal directory: http://shortn/_jsmzLpQu7q
42 class ThreadPoolInterfaceWrapper
43     : public tensorflow::thread::ThreadPoolInterface {
44  public:
ThreadPoolInterfaceWrapper(Eigen::ThreadPoolInterface * thread_pool)45   explicit ThreadPoolInterfaceWrapper(Eigen::ThreadPoolInterface* thread_pool)
46       : thread_pool_{thread_pool} {
47     DCHECK(thread_pool);
48   }
49 
Schedule(std::function<void ()> fn)50   void Schedule(std::function<void()> fn) override {
51     return thread_pool().Schedule(std::move(fn));
52   }
53 
ScheduleWithHint(std::function<void ()> fn,int start,int end)54   void ScheduleWithHint(std::function<void()> fn, int start, int end) override {
55     return thread_pool().ScheduleWithHint(std::move(fn), start, end);
56   }
57 
Cancel()58   void Cancel() override { thread_pool().Cancel(); }
59 
NumThreads()60   int NumThreads() const override { return thread_pool().NumThreads(); }
61 
CurrentThreadId()62   int CurrentThreadId() const override {
63     return thread_pool().CurrentThreadId();
64   }
65 
66  private:
thread_pool()67   Eigen::ThreadPoolInterface& thread_pool() const {
68     DCHECK(thread_pool_);
69     return *thread_pool_;
70   }
71 
72   // Not owning pointer to the thread pool.
73   Eigen::ThreadPoolInterface* thread_pool_ = nullptr;
74 };
75 
76 // This class defines a list of objects needed to support execution with TFRT.
77 class TfrtContext {
78  public:
79   TfrtContext(
80       const tensorflow::SessionOptions& opts,
81       tensorflow::ContextDevicePlacementPolicy default_device_placement_policy,
82       bool is_async);
83   ~TfrtContext();
84 
GetHostContext()85   HostContext* GetHostContext() { return host_context_; }
GetCoreRuntime()86   CoreRuntime* GetCoreRuntime() { return corert_.get(); }
GetEagerContext()87   tensorflow::EagerContext* GetEagerContext() { return eager_context_; }
GetEagerContext()88   const tensorflow::EagerContext* GetEagerContext() const {
89     return eager_context_;
90   }
GetFallbackOpHandler()91   OpHandler* GetFallbackOpHandler() { return fallback_op_handler_; }
92 
GetResourceContext()93   ResourceContext* GetResourceContext() { return &resource_context_; }
94 
GetTfThreadPoolWorkQueue()95   tensorflow::tfrt_stub::TfThreadPoolWorkQueue* GetTfThreadPoolWorkQueue() {
96     return tf_thread_pool_work_queue_.get();
97   }
98 
99   const tensorflow::DeviceNameUtils::ParsedName& HostCPUParsedName() const;
100 
101   bool IsAsync() const;
102 
103  private:
104   std::unique_ptr<CoreRuntime> corert_;
105   ::tfrt::HostContext* host_context_;
106   OpHandler* fallback_op_handler_;
107   ResourceContext resource_context_;
108   tensorflow::EagerContext* eager_context_;
109   std::unique_ptr<ThreadPoolInterfaceWrapper> eager_ctx_thread_pool_;
110 
111   // Manage the local thread pool's lifetime because the wrapper does not own
112   // the thread pool.
113   std::unique_ptr<tensorflow::thread::ThreadPool> local_thread_pool_;
114   std::unique_ptr<ThreadPoolInterfaceWrapper> local_thread_pool_wrapper_;
115   std::unique_ptr<tensorflow::tfrt_stub::TfThreadPoolWorkQueue>
116       tf_thread_pool_work_queue_;
117 };
118 
119 }  // namespace tf
120 }  // namespace tfrt
121 
122 #endif  // TENSORFLOW_CORE_TFRT_EAGER_TFRT_CONTEXT_H_
123