• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/distributed_runtime/worker_session.h"
16 
17 #include "tensorflow/core/lib/monitoring/collection_registry.h"
18 #include "tensorflow/core/lib/monitoring/gauge.h"
19 
20 namespace tensorflow {
21 
22 namespace {
23 
24 auto* worker_session_created =
25     monitoring::Gauge<bool, 0>::New("/tensorflow/core/worker_session_created",
26                                     "True if a worker session was created.");
27 
28 // A private cache that wraps worker_cache and allows reuse of
29 // WorkerInterface objects.
30 class WorkerFreeListCache : public WorkerCacheInterface {
31  public:
WorkerFreeListCache(std::unique_ptr<WorkerCacheInterface> w)32   explicit WorkerFreeListCache(std::unique_ptr<WorkerCacheInterface> w)
33       : wrapped_(std::move(w)) {}
34 
~WorkerFreeListCache()35   ~WorkerFreeListCache() final {
36     for (auto& p : workers_) {
37       wrapped_->ReleaseWorker(p.first, p.second.worker);
38     }
39   }
40 
ListWorkers(std::vector<string> * workers) const41   void ListWorkers(std::vector<string>* workers) const override {
42     wrapped_->ListWorkers(workers);
43   }
44 
ListWorkersInJob(const string & job_name,std::vector<string> * workers) const45   void ListWorkersInJob(const string& job_name,
46                         std::vector<string>* workers) const override {
47     wrapped_->ListWorkersInJob(job_name, workers);
48   }
49 
GetOrCreateWorker(const string & target)50   WorkerInterface* GetOrCreateWorker(const string& target) override {
51     mutex_lock l(mu_);
52     auto p = workers_.find(target);
53     if (p != workers_.end()) {
54       return p->second.worker;
55     }
56     WorkerState state;
57     state.worker = wrapped_->GetOrCreateWorker(target);
58     if (state.worker != nullptr) {
59       workers_.insert(std::make_pair(target, state));
60     }
61     return state.worker;
62   }
63 
GetEagerClientCache(std::unique_ptr<eager::EagerClientCache> * eager_client_cache)64   Status GetEagerClientCache(
65       std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
66     return wrapped_->GetEagerClientCache(eager_client_cache);
67   }
68 
ReleaseWorker(const string & target,WorkerInterface * worker)69   void ReleaseWorker(const string& target, WorkerInterface* worker) override {
70     // TODO(jeff,sanjay): Should decrement ref-count when we implement eviction.
71   }
72 
GetDeviceLocalityNonBlocking(const string & device,DeviceLocality * locality)73   bool GetDeviceLocalityNonBlocking(const string& device,
74                                     DeviceLocality* locality) override {
75     return wrapped_->GetDeviceLocalityNonBlocking(device, locality);
76   }
77 
GetDeviceLocalityAsync(const string & device,DeviceLocality * locality,StatusCallback done)78   void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
79                               StatusCallback done) override {
80     wrapped_->GetDeviceLocalityAsync(device, locality, done);
81   }
82 
SetLogging(bool active)83   void SetLogging(bool active) override { wrapped_->SetLogging(active); }
84 
ClearLogs()85   void ClearLogs() override { wrapped_->ClearLogs(); }
86 
RetrieveLogs(int64 step_id,StepStats * ss)87   bool RetrieveLogs(int64 step_id, StepStats* ss) override {
88     return wrapped_->RetrieveLogs(step_id, ss);
89   }
90 
91  private:
92   std::unique_ptr<WorkerCacheInterface> wrapped_;
93 
94   // Information kept per created WorkerInterface.
95   struct WorkerState {
96     WorkerInterface* worker;
97     // TODO(jeff,sanjay): Add reference count if we support eviction.
98   };
99 
100   // TODO(jeff,sanjay): Eviction when the map becomes too big.
101   mutex mu_;
102   std::unordered_map<string, WorkerState> workers_ TF_GUARDED_BY(mu_);
103 };
104 
105 }  // namespace
106 
WorkerSession(const string & session_name,const string & worker_name,std::unique_ptr<WorkerCacheInterface> worker_cache,std::unique_ptr<DeviceMgr> device_mgr,std::unique_ptr<GraphMgr> graph_mgr,std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)107 WorkerSession::WorkerSession(
108     const string& session_name, const string& worker_name,
109     std::unique_ptr<WorkerCacheInterface> worker_cache,
110     std::unique_ptr<DeviceMgr> device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
111     std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)
112     : session_name_(session_name),
113       worker_name_(worker_name),
114       worker_cache_(new WorkerFreeListCache(std::move(worker_cache))),
115       graph_mgr_(std::move(graph_mgr)),
116       cluster_flr_(new ClusterFunctionLibraryRuntime(
117           this, !session_name.empty(),
118           remote_device_mgr ? remote_device_mgr.get() : nullptr)),
119       device_mgr_(std::move(device_mgr)),
120       borrowed_device_mgr_(nullptr),
121       remote_device_mgr_(std::move(remote_device_mgr)) {
122   // Starts exporting metrics through a platform-specific monitoring API (if
123   // provided). For builds using "tensorflow/core/platform/default", this is
124   // currently a no-op.
125   worker_session_created->GetCell()->Set(true);
126 }
127 
UpdateWorkerCacheAndDevices(std::unique_ptr<WorkerCacheInterface> new_worker_cache,std::vector<std::unique_ptr<Device>> added_remote_devices,const std::vector<Device * > & removed_remote_devices)128 Status WorkerSession::UpdateWorkerCacheAndDevices(
129     std::unique_ptr<WorkerCacheInterface> new_worker_cache,
130     std::vector<std::unique_ptr<Device>> added_remote_devices,
131     const std::vector<Device*>& removed_remote_devices) {
132   {
133     mutex_lock l(worker_session_state_mu_);
134     worker_cache_ = std::shared_ptr<WorkerCacheInterface>(
135         new WorkerFreeListCache(std::move(new_worker_cache)));
136   }
137   TF_RETURN_IF_ERROR(remote_device_mgr_->RemoveDevices(removed_remote_devices));
138   TF_RETURN_IF_ERROR(
139       remote_device_mgr_->AddDevices(std::move(added_remote_devices)));
140   return Status::OK();
141 }
142 
143 /* static */
CreateWithBorrowedDeviceMgr(const string & session_name,const string & worker_name,std::unique_ptr<WorkerCacheInterface> worker_cache,const DeviceMgr * borrowed_device_mgr,std::unique_ptr<GraphMgr> graph_mgr,std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)144 std::shared_ptr<WorkerSession> WorkerSession::CreateWithBorrowedDeviceMgr(
145     const string& session_name, const string& worker_name,
146     std::unique_ptr<WorkerCacheInterface> worker_cache,
147     const DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
148     std::unique_ptr<DynamicDeviceMgr> remote_device_mgr) {
149   return std::shared_ptr<WorkerSession>(new WorkerSession(
150       session_name, worker_name, std::move(worker_cache), borrowed_device_mgr,
151       std::move(graph_mgr), std::move(remote_device_mgr)));
152 }
153 
WorkerSession(const string & session_name,const string & worker_name,std::unique_ptr<WorkerCacheInterface> worker_cache,const DeviceMgr * borrowed_device_mgr,std::unique_ptr<GraphMgr> graph_mgr,std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)154 WorkerSession::WorkerSession(
155     const string& session_name, const string& worker_name,
156     std::unique_ptr<WorkerCacheInterface> worker_cache,
157     const DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
158     std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)
159     : session_name_(session_name),
160       worker_name_(worker_name),
161       worker_cache_(new WorkerFreeListCache(std::move(worker_cache))),
162       graph_mgr_(std::move(graph_mgr)),
163       cluster_flr_(new ClusterFunctionLibraryRuntime(
164           this, !session_name.empty(), remote_device_mgr.get())),
165       device_mgr_(nullptr),
166       borrowed_device_mgr_(borrowed_device_mgr),
167       remote_device_mgr_(std::move(remote_device_mgr)) {
168   // Starts exporting metrics through a platform-specific monitoring API (if
169   // provided). For builds using "tensorflow/core/platform/default", this is
170   // currently a no-op.
171   worker_session_created->GetCell()->Set(true);
172 }
173 
~WorkerSession()174 WorkerSession::~WorkerSession() {
175   if (graph_mgr_) {
176     Status s = graph_mgr_->DeregisterAll();
177     if (!s.ok()) {
178       LOG(WARNING) << "Error during worker session deletion: " << s;
179     }
180   }
181 }
182 
183 }  // namespace tensorflow
184