• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/distributed_runtime/worker_session.h"
16 
17 #include "tensorflow/core/lib/monitoring/collection_registry.h"
18 #include "tensorflow/core/lib/monitoring/gauge.h"
19 
20 namespace tensorflow {
21 
22 namespace {
23 
24 auto* worker_session_created =
25     monitoring::Gauge<bool, 0>::New("/tensorflow/core/worker_session_created",
26                                     "True if a worker session was created.");
27 
28 // A private cache that wraps worker_cache and allows reuse of
29 // WorkerInterface objects.
30 class WorkerFreeListCache : public WorkerCacheInterface {
31  public:
WorkerFreeListCache(std::unique_ptr<WorkerCacheInterface> w)32   explicit WorkerFreeListCache(std::unique_ptr<WorkerCacheInterface> w)
33       : wrapped_(std::move(w)) {}
34 
~WorkerFreeListCache()35   ~WorkerFreeListCache() final {
36     for (auto& p : workers_) {
37       wrapped_->ReleaseWorker(p.first, p.second.worker);
38     }
39   }
40 
ListWorkers(std::vector<string> * workers) const41   void ListWorkers(std::vector<string>* workers) const override {
42     wrapped_->ListWorkers(workers);
43   }
44 
ListWorkersInJob(const string & job_name,std::vector<string> * workers) const45   void ListWorkersInJob(const string& job_name,
46                         std::vector<string>* workers) const override {
47     wrapped_->ListWorkersInJob(job_name, workers);
48   }
49 
GetOrCreateWorker(const string & target)50   WorkerInterface* GetOrCreateWorker(const string& target) override {
51     mutex_lock l(mu_);
52     auto p = workers_.find(target);
53     if (p != workers_.end()) {
54       return p->second.worker;
55     }
56     WorkerState state;
57     state.worker = wrapped_->GetOrCreateWorker(target);
58     if (state.worker != nullptr) {
59       workers_.insert(std::make_pair(target, state));
60     }
61     return state.worker;
62   }
63 
GetEagerClientCache(std::unique_ptr<eager::EagerClientCache> * eager_client_cache)64   Status GetEagerClientCache(
65       std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
66     return wrapped_->GetEagerClientCache(eager_client_cache);
67   }
68 
GetCoordinationClientCache(std::unique_ptr<CoordinationClientCache> * coordination_client_cache)69   Status GetCoordinationClientCache(std::unique_ptr<CoordinationClientCache>*
70                                         coordination_client_cache) override {
71     return wrapped_->GetCoordinationClientCache(coordination_client_cache);
72   }
73 
ReleaseWorker(const string & target,WorkerInterface * worker)74   void ReleaseWorker(const string& target, WorkerInterface* worker) override {
75     // TODO(jeff,sanjay): Should decrement ref-count when we implement eviction.
76   }
77 
GetDeviceLocalityNonBlocking(const string & device,DeviceLocality * locality)78   bool GetDeviceLocalityNonBlocking(const string& device,
79                                     DeviceLocality* locality) override {
80     return wrapped_->GetDeviceLocalityNonBlocking(device, locality);
81   }
82 
GetDeviceLocalityAsync(const string & device,DeviceLocality * locality,StatusCallback done)83   void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
84                               StatusCallback done) override {
85     wrapped_->GetDeviceLocalityAsync(device, locality, done);
86   }
87 
SetLogging(bool active)88   void SetLogging(bool active) override { wrapped_->SetLogging(active); }
89 
ClearLogs()90   void ClearLogs() override { wrapped_->ClearLogs(); }
91 
RetrieveLogs(int64_t step_id,StepStats * ss)92   bool RetrieveLogs(int64_t step_id, StepStats* ss) override {
93     return wrapped_->RetrieveLogs(step_id, ss);
94   }
95 
96  private:
97   std::unique_ptr<WorkerCacheInterface> wrapped_;
98 
99   // Information kept per created WorkerInterface.
100   struct WorkerState {
101     WorkerInterface* worker;
102     // TODO(jeff,sanjay): Add reference count if we support eviction.
103   };
104 
105   // TODO(jeff,sanjay): Eviction when the map becomes too big.
106   mutex mu_;
107   std::unordered_map<string, WorkerState> workers_ TF_GUARDED_BY(mu_);
108 };
109 
110 }  // namespace
111 
WorkerSession(const string & session_name,const string & worker_name,std::unique_ptr<WorkerCacheInterface> worker_cache,std::unique_ptr<DeviceMgr> device_mgr,std::unique_ptr<GraphMgr> graph_mgr,std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)112 WorkerSession::WorkerSession(
113     const string& session_name, const string& worker_name,
114     std::unique_ptr<WorkerCacheInterface> worker_cache,
115     std::unique_ptr<DeviceMgr> device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
116     std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)
117     : session_name_(session_name),
118       worker_name_(worker_name),
119       worker_cache_(new WorkerFreeListCache(std::move(worker_cache))),
120       graph_mgr_(std::move(graph_mgr)),
121       cluster_flr_(new ClusterFunctionLibraryRuntime(
122           this, !session_name.empty(),
123           remote_device_mgr ? remote_device_mgr.get() : nullptr)),
124       device_mgr_(std::move(device_mgr)),
125       borrowed_device_mgr_(nullptr),
126       remote_device_mgr_(std::move(remote_device_mgr)) {
127   // Starts exporting metrics through a platform-specific monitoring API (if
128   // provided). For builds using "tensorflow/core/platform/default", this is
129   // currently a no-op.
130   worker_session_created->GetCell()->Set(true);
131 }
132 
UpdateWorkerCacheAndDevices(std::unique_ptr<WorkerCacheInterface> new_worker_cache,std::vector<std::unique_ptr<Device>> added_remote_devices,const std::vector<Device * > & removed_remote_devices)133 Status WorkerSession::UpdateWorkerCacheAndDevices(
134     std::unique_ptr<WorkerCacheInterface> new_worker_cache,
135     std::vector<std::unique_ptr<Device>> added_remote_devices,
136     const std::vector<Device*>& removed_remote_devices) {
137   {
138     mutex_lock l(worker_session_state_mu_);
139     worker_cache_ = std::shared_ptr<WorkerCacheInterface>(
140         new WorkerFreeListCache(std::move(new_worker_cache)));
141   }
142   TF_RETURN_IF_ERROR(remote_device_mgr_->RemoveDevices(removed_remote_devices));
143   TF_RETURN_IF_ERROR(
144       remote_device_mgr_->AddDevices(std::move(added_remote_devices)));
145   return Status::OK();
146 }
147 
148 /* static */
CreateWithBorrowedDeviceMgr(const string & session_name,const string & worker_name,std::unique_ptr<WorkerCacheInterface> worker_cache,DeviceMgr * borrowed_device_mgr,std::unique_ptr<GraphMgr> graph_mgr,std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)149 std::shared_ptr<WorkerSession> WorkerSession::CreateWithBorrowedDeviceMgr(
150     const string& session_name, const string& worker_name,
151     std::unique_ptr<WorkerCacheInterface> worker_cache,
152     DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
153     std::unique_ptr<DynamicDeviceMgr> remote_device_mgr) {
154   return std::shared_ptr<WorkerSession>(new WorkerSession(
155       session_name, worker_name, std::move(worker_cache), borrowed_device_mgr,
156       std::move(graph_mgr), std::move(remote_device_mgr)));
157 }
158 
WorkerSession(const string & session_name,const string & worker_name,std::unique_ptr<WorkerCacheInterface> worker_cache,DeviceMgr * borrowed_device_mgr,std::unique_ptr<GraphMgr> graph_mgr,std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)159 WorkerSession::WorkerSession(
160     const string& session_name, const string& worker_name,
161     std::unique_ptr<WorkerCacheInterface> worker_cache,
162     DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
163     std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)
164     : session_name_(session_name),
165       worker_name_(worker_name),
166       worker_cache_(new WorkerFreeListCache(std::move(worker_cache))),
167       graph_mgr_(std::move(graph_mgr)),
168       cluster_flr_(new ClusterFunctionLibraryRuntime(
169           this, !session_name.empty(), remote_device_mgr.get())),
170       device_mgr_(nullptr),
171       borrowed_device_mgr_(borrowed_device_mgr),
172       remote_device_mgr_(std::move(remote_device_mgr)) {
173   // Starts exporting metrics through a platform-specific monitoring API (if
174   // provided). For builds using "tensorflow/core/platform/default", this is
175   // currently a no-op.
176   worker_session_created->GetCell()->Set(true);
177 }
178 
~WorkerSession()179 WorkerSession::~WorkerSession() {
180   if (graph_mgr_) {
181     Status s = graph_mgr_->DeregisterAll();
182     if (!s.ok()) {
183       LOG(WARNING) << "Error during worker session deletion: " << s;
184     }
185   }
186 }
187 
188 }  // namespace tensorflow
189