• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
17 
#include <algorithm>
#include <atomic>
#include <cstddef>
#include <functional>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>
25 
26 #include "tensorflow/core/common_runtime/device_factory.h"
27 #include "tensorflow/core/common_runtime/device_mgr.h"
28 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
29 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
30 #include "tensorflow/core/common_runtime/function.h"
31 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
32 #include "tensorflow/core/example/example.pb.h"
33 #include "tensorflow/core/platform/env.h"
34 #ifndef __ANDROID__
35 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
36 #include "tensorflow/core/distributed_runtime/server_lib.h"
37 #include "tensorflow/core/distributed_runtime/worker_cache.h"
38 #endif
39 #include "tensorflow/core/framework/collective.h"
40 #include "tensorflow/core/framework/log_memory.h"
41 #include "tensorflow/core/framework/rendezvous.h"
42 #include "tensorflow/core/lib/core/stringpiece.h"
43 #include "tensorflow/core/lib/core/threadpool.h"
44 #include "tensorflow/core/lib/gtl/flatmap.h"
45 #include "tensorflow/core/lib/gtl/flatset.h"
46 #include "tensorflow/core/lib/gtl/inlined_vector.h"
47 #include "tensorflow/core/lib/gtl/map_util.h"
48 #include "tensorflow/core/lib/gtl/stl_util.h"
49 #include "tensorflow/core/platform/fingerprint.h"
50 #include "tensorflow/core/platform/mutex.h"
51 #include "tensorflow/core/platform/thread_annotations.h"
52 #include "tensorflow/core/public/session_options.h"
53 #include "tensorflow/core/public/version.h"
54 
55 namespace tensorflow {
56 
// Policy describing what happens when an op receives an input tensor that
// lives on a different device than the one the op runs on.
//
// NOTE: eager/c_api.h contains a copy of this enum; the two definitions must be
// kept in sync (same enumerators, same numeric values).
enum ContextDevicePlacementPolicy {
  // Fail: running an op whose inputs are on the wrong device is an error.
  DEVICE_PLACEMENT_EXPLICIT = 0,
  // Copy the input to the op's device, but log a warning when doing so.
  DEVICE_PLACEMENT_WARN = 1,
  // Copy the input without any warning. This is the default policy; it has a
  // performance cost because the op is blocked until the copy completes.
  DEVICE_PLACEMENT_SILENT = 2,
  // Silently copy int32 inputs only; inputs of any other dtype are not copied.
  DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
};
69 
70 class RunMetadataListener {
71  public:
~RunMetadataListener()72   virtual ~RunMetadataListener() {}
73   virtual void BeforeClearRunMetadata() = 0;
74 };
75 
76 class EagerContext {
77  public:
78   // TODO: remove this constructor once we migrate all callers to the next one.
79   EagerContext(const SessionOptions& opts,
80                ContextDevicePlacementPolicy default_policy, bool async,
81                std::unique_ptr<const DeviceMgr> device_mgr,
82                Rendezvous* rendezvous);
83 
84   EagerContext(const SessionOptions& opts,
85                ContextDevicePlacementPolicy default_policy, bool async,
86                const DeviceMgr* device_mgr, bool device_mgr_owned,
87                Rendezvous* rendezvous);
88 
89   ~EagerContext();
90 
91   // Returns the function library runtime for the given device.
func_lib(Device * d)92   FunctionLibraryRuntime* func_lib(Device* d) const {
93     return pflr_->GetFLR(d->name());
94   }
95 
pflr()96   ProcessFunctionLibraryRuntime* pflr() const { return pflr_.get(); }
97 
98   // True if running in asynchronous mode.
99   bool Async() const;
100 
Executor()101   EagerExecutor* Executor() { return &executor_; }
102 
runner()103   std::function<void(std::function<void()>)>* runner() { return &runner_; }
104 
105   // Sets whether this thread should run in synchronous or asynchronous mode.
106   Status SetAsyncForThread(bool async);
107 
108   // TODO(apassos) make this return a constant reference
device_map()109   gtl::FlatMap<string, Device*, StringPieceHasher>* device_map() {
110     return &devices_map_;
111   }
112 
113   // TODO(apassos) make this return a constant reference
devices()114   std::vector<Device*>* devices() { return &devices_; }
prioritized_device_type_list()115   const std::vector<DeviceType>& prioritized_device_type_list() {
116     return prioritized_device_type_list_;
117   }
118 
119   // Clears the kernel caches.
120   Status ClearCaches();
121 
122   // Sets the device placement policy for the current thread.
123   void SetThreadLocalDevicePlacementPolicy(ContextDevicePlacementPolicy policy);
124 
125   // Returns the device placement policy for the current thread.
126   ContextDevicePlacementPolicy GetDevicePlacementPolicy();
127 
AsyncWait()128   Status AsyncWait() { return executor_.WaitForAllPendingNodes(); }
129 
GetStatus()130   Status GetStatus() { return executor_.status(); }
131 
ClearAsyncError()132   void ClearAsyncError() { executor_.ClearError(); }
133 
134   bool FindFunctionByName(const string& name);
135 
136   Status FindFunctionOpData(const string& name,
137                             const tensorflow::OpRegistrationData** op_data);
138 
139   const FunctionDef* FindFunctionDef(const string& name);
140 
141   Status FindDeviceByName(const string& name, Device** result);
142 
HostCPU()143   Device* HostCPU() const { return devices_[0]; }
144 
GetGraphCollector()145   GraphCollector* GetGraphCollector() { return &graph_collector_; }
146 
NextId()147   uint64 NextId() { return executor_.NextId(); }
148 
ExecutorAdd(EagerNode * node)149   void ExecutorAdd(EagerNode* node) { executor_.Add(node); }
150 
151   Status AddFunctionDef(const FunctionDef& fdef);
152 
153   KernelAndDevice* GetCachedKernel(Fprint128 cache_key);
154 
155   void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
156 
LogDevicePlacement()157   bool LogDevicePlacement() const { return log_device_placement_; }
LogMemory()158   bool LogMemory() const { return log_memory_; }
159 
GetRendezvous()160   Rendezvous* GetRendezvous() const { return rendezvous_; }
collective_executor_mgr()161   CollectiveExecutorMgrInterface* collective_executor_mgr() {
162     return (collective_executor_mgr_ != nullptr)
163                ? collective_executor_mgr_.get()
164                : unowned_collective_executor_mgr_;
165   }
GetCollectiveExecutorHandle()166   std::unique_ptr<CollectiveExecutor::Handle> GetCollectiveExecutorHandle() {
167     return std::unique_ptr<CollectiveExecutor::Handle>(
168         new CollectiveExecutor::Handle(
169             collective_executor_mgr()->FindOrCreate(0), true /*inherit_ref*/));
170   }
171 
local_device_mgr()172   const tensorflow::DeviceMgr* local_device_mgr() const {
173     return (local_device_manager_ != nullptr) ? local_device_manager_.get()
174                                               : local_unowned_device_manager_;
175   }
remote_device_mgr()176   const tensorflow::DeviceMgr* remote_device_mgr() const {
177     return remote_device_manager_.get();
178   }
179 
180   // TODO(apassos) remove the need for this
ReleaseDeviceMgr()181   void ReleaseDeviceMgr() { local_device_manager_.release(); }
182 
183   // TODO(apassos) clean up RunMetadata storage.
MetadataMu()184   mutex* MetadataMu() LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; }
185   bool ShouldStoreStepStats() LOCKS_EXCLUDED(metadata_mu_);
186   void SetShouldStoreStepStats(bool value);
187   bool ShouldStoreGraphs() LOCKS_EXCLUDED(metadata_mu_);
188   void SetShouldStoreGraphs(bool value);
RunMetadataProto()189   RunMetadata* RunMetadataProto() { return &run_metadata_; }
190   void ClearRunMetadata() EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_);
191 
192   Status RegisterRunMetadataListener(RunMetadataListener* listener)
193       LOCKS_EXCLUDED(metadata_mu_);
194   void ClearRunMetadataListener() LOCKS_EXCLUDED(metadata_mu_);
195 
196   void StartStep();
197   void EndStep();
198   ScopedStepContainer* StepContainer();
199 
FuncLibDef()200   FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; }
201 
202 #ifndef __ANDROID__
203   Status GetClientAndContextID(Device* device, eager::EagerClient** client,
204                                uint64* context_id);
205 
206   // TODO(nareshmodi): Encapsulate remote state into a separate
207   // class/struct.
208   //
209   // Enables the eager context to communicate with remote devices.
210   //
211   // - server: A ServerInterface that exports the tensorflow.WorkerService.
212   // Note that this class expects the server to already have been started.
213   // - remote_eager_workers: A cache from which we can get "EagerClient"s to
214   // communicate with remote eager services.
215   // - remote_device_mgr: A DeviceMgr* which contains all remote devices
216   // (should contain no local devices).
217   // - remote_contexts: A map containing task name to remote context ID.
218   Status InitializeRemote(
219       std::unique_ptr<ServerInterface> server,
220       std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
221       std::unique_ptr<DeviceMgr> remote_device_manager,
222       const gtl::FlatMap<string, uint64>& remote_contexts, Rendezvous* r,
223       DeviceMgr* local_device_mgr, int keep_alive_secs);
224 
HasActiveRemoteContext(uint64 context_id)225   bool HasActiveRemoteContext(uint64 context_id) {
226     return active_remote_contexts_.find(context_id) !=
227            active_remote_contexts_.end();
228   }
229 
230   Status StoreCollectiveOpsServer(
231       std::unique_ptr<ServerInterface> server, DeviceMgr* device_mgr,
232       CollectiveExecutorMgrInterface* rpc_collective_executor_mgr);
233 #endif
234 
235   // If true, then tensors should be shipped across processes via the
236   // EagerService.SendTensor RPC. If false, _Send/_Recv ops should be used
237   // instead (which in-turn use WorkerService.RecvTensor RPCs).
UseSendTensorRPC()238   bool UseSendTensorRPC() { return use_send_tensor_rpc_; }
PinSmallOpsToCPU()239   bool PinSmallOpsToCPU() { return pin_small_ops_to_cpu_; }
240 
TFEnv()241   tensorflow::Env* TFEnv() const { return env_; }
242 
243   // All child threads will be reset() when destructing EagerContext.
244   void AddChildThread(std::unique_ptr<Thread> thread);
245 
246  private:
247   void InitDeviceMapAndAsync();
248   Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef);
249 
250   const ContextDevicePlacementPolicy policy_;
251 
252   // Note: we cannot use C++11 thread_local here as there is no concept of a
253   // thread-local-object-local variable in C++11.
254   mutex policy_map_mu_;
255   std::unordered_map<std::thread::id, ContextDevicePlacementPolicy>
256       thread_local_policies_ GUARDED_BY(policy_map_mu_);
257 
258   // Only one of the below is set.
259   std::unique_ptr<const DeviceMgr> local_device_manager_;
260   const DeviceMgr* local_unowned_device_manager_;
261   std::unique_ptr<DeviceMgr> remote_device_manager_;
262 
263   // Devices owned by device_manager
264   std::vector<Device*> devices_;
265   std::vector<DeviceType> prioritized_device_type_list_;
266   // All devices are not owned.
267   gtl::FlatMap<string, Device*, StringPieceHasher> devices_map_;
268   Rendezvous* rendezvous_;
269 
270   mutex functions_mu_;
GUARDED_BY(functions_mu_)271   FunctionLibraryDefinition func_lib_def_ GUARDED_BY(functions_mu_){
272       OpRegistry::Global(), {}};
273 
274   std::unique_ptr<thread::ThreadPool> thread_pool_;
275 
276   // One FunctionLibraryRuntime per device.
277   // func_libs[i] is the FunctionLibraryRuntime corresponding to
278   // session->devices[i].
279   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
280 
281   std::function<void(std::function<void()>)> runner_;
282 
283   mutex cache_mu_;
284   std::unordered_map<Fprint128, KernelAndDevice*, Fprint128Hasher> kernel_cache_
285       GUARDED_BY(cache_mu_);
286 
287   // Whether we should compute RunMetadata.
288   std::atomic<bool> should_store_step_stats_{false};
289   std::atomic<bool> should_store_graphs_{false};
290   mutex metadata_mu_;
291   RunMetadata run_metadata_ GUARDED_BY(metadata_mu_);
292   RunMetadataListener* metadata_listener_ GUARDED_BY(metadata_mu_) = nullptr;
293   GraphCollector graph_collector_;
294   const bool log_device_placement_;
295   // EagerExecutor for async execution.
296   EagerExecutor executor_;
297 
298   // Information related to step containers.
299   std::atomic<int> num_active_steps_;
300   std::unique_ptr<ScopedStepContainer> step_container_ GUARDED_BY(metadata_mu_);
301 
302   // True if the default value for execution mode is async. Note that this value
303   // can be overridden per thread based on `thread_local_async` overrides.
304   const bool async_default_;
305   mutable mutex async_map_mu_;
306   std::unordered_map<std::thread::id, bool> thread_local_async_
307       GUARDED_BY(async_map_mu_);
308 
309   const bool log_memory_;
310 
311   Env* const env_;
312 
313   std::unique_ptr<CollectiveExecutorMgrInterface> collective_executor_mgr_;
314   CollectiveExecutorMgrInterface* unowned_collective_executor_mgr_ = nullptr;
315 
316 #ifndef __ANDROID__
317   void CloseRemoteContexts();
318 
319   // The server_ is not const since we release it when the context is destroyed.
320   // Therefore the server_ object is not marked as const (even though it should
321   // be).
322   std::unique_ptr<ServerInterface> server_;
323   std::unique_ptr<eager::EagerClientCache> remote_eager_workers_;
324 
325   mutex remote_state_mu_;
326 
327   gtl::FlatMap<string, uint64> remote_contexts_;
328   gtl::FlatSet<uint64> active_remote_contexts_;
329   gtl::FlatMap<Device*, std::pair<eager::EagerClient*, uint64>>
330       device_to_client_cache_;
331 
332   int keep_alive_secs_ GUARDED_BY(remote_state_mu_);
333   std::atomic<int> sleep_for_secs_;
334 
335   std::unique_ptr<Thread> keep_alive_thread_;
336   mutex keep_alive_thread_shutdown_mu_;
337   condition_variable keep_alive_thread_cv_;
338   bool shutting_down_ GUARDED_BY(keep_alive_thread_shutdown_mu_) = false;
339 #endif
340 
341   bool use_send_tensor_rpc_;
342   const bool pin_small_ops_to_cpu_;
343   std::vector<std::unique_ptr<tensorflow::Thread>> child_threads_;
344 };
345 
346 }  // namespace tensorflow
347 
348 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
349