• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/distributed_runtime/worker_session.h"
16 
17 #include "tensorflow/core/lib/monitoring/gauge.h"
18 #include "tensorflow/core/platform/monitoring.h"
19 
20 namespace tensorflow {
21 
22 namespace {
23 
24 auto* worker_session_created =
25     monitoring::Gauge<bool, 0>::New("/tensorflow/core/worker_session_created",
26                                     "True if a worker session was created.");
27 
28 // A private cache that wraps worker_cache and allows reuse of
29 // WorkerInterface objects.
30 class WorkerFreeListCache : public WorkerCacheInterface {
31  public:
WorkerFreeListCache(std::unique_ptr<WorkerCacheInterface> w)32   explicit WorkerFreeListCache(std::unique_ptr<WorkerCacheInterface> w)
33       : wrapped_(std::move(w)) {}
34 
~WorkerFreeListCache()35   ~WorkerFreeListCache() final {
36     for (auto& p : workers_) {
37       wrapped_->ReleaseWorker(p.first, p.second.worker);
38     }
39   }
40 
ListWorkers(std::vector<string> * workers) const41   void ListWorkers(std::vector<string>* workers) const override {
42     wrapped_->ListWorkers(workers);
43   }
44 
ListWorkersInJob(const string & job_name,std::vector<string> * workers) const45   void ListWorkersInJob(const string& job_name,
46                         std::vector<string>* workers) const override {
47     wrapped_->ListWorkersInJob(job_name, workers);
48   }
49 
GetOrCreateWorker(const string & target)50   WorkerInterface* GetOrCreateWorker(const string& target) override {
51     mutex_lock l(mu_);
52     auto p = workers_.find(target);
53     if (p != workers_.end()) {
54       return p->second.worker;
55     }
56     WorkerState state;
57     state.worker = wrapped_->GetOrCreateWorker(target);
58     if (state.worker != nullptr) {
59       workers_.insert(std::make_pair(target, state));
60     }
61     return state.worker;
62   }
63 
GetEagerClientCache(std::unique_ptr<eager::EagerClientCache> * eager_client_cache)64   Status GetEagerClientCache(
65       std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
66     return wrapped_->GetEagerClientCache(eager_client_cache);
67   }
68 
ReleaseWorker(const string & target,WorkerInterface * worker)69   void ReleaseWorker(const string& target, WorkerInterface* worker) override {
70     // TODO(jeff,sanjay): Should decrement ref-count when we implement eviction.
71   }
72 
GetDeviceLocalityNonBlocking(const string & device,DeviceLocality * locality)73   bool GetDeviceLocalityNonBlocking(const string& device,
74                                     DeviceLocality* locality) override {
75     return wrapped_->GetDeviceLocalityNonBlocking(device, locality);
76   }
77 
GetDeviceLocalityAsync(const string & device,DeviceLocality * locality,StatusCallback done)78   void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
79                               StatusCallback done) override {
80     wrapped_->GetDeviceLocalityAsync(device, locality, done);
81   }
82 
SetLogging(bool active)83   void SetLogging(bool active) override { wrapped_->SetLogging(active); }
84 
ClearLogs()85   void ClearLogs() override { wrapped_->ClearLogs(); }
86 
RetrieveLogs(int64 step_id,StepStats * ss)87   bool RetrieveLogs(int64 step_id, StepStats* ss) override {
88     return wrapped_->RetrieveLogs(step_id, ss);
89   }
90 
91  private:
92   std::unique_ptr<WorkerCacheInterface> wrapped_;
93 
94   // Information kept per created WorkerInterface.
95   struct WorkerState {
96     WorkerInterface* worker;
97     // TODO(jeff,sanjay): Add reference count if we support eviction.
98   };
99 
100   // TODO(jeff,sanjay): Eviction when the map becomes too big.
101   mutex mu_;
102   std::unordered_map<string, WorkerState> workers_ GUARDED_BY(mu_);
103 };
104 
105 }  // namespace
106 
WorkerSession(const string & session_name,const string & worker_name,std::unique_ptr<WorkerCacheInterface> worker_cache,std::unique_ptr<DeviceMgr> device_mgr,std::unique_ptr<GraphMgr> graph_mgr,std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)107 WorkerSession::WorkerSession(
108     const string& session_name, const string& worker_name,
109     std::unique_ptr<WorkerCacheInterface> worker_cache,
110     std::unique_ptr<DeviceMgr> device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
111     std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)
112     : session_name_(session_name),
113       worker_name_(worker_name),
114       worker_cache_(new WorkerFreeListCache(std::move(worker_cache))),
115       graph_mgr_(std::move(graph_mgr)),
116       cluster_flr_(new ClusterFunctionLibraryRuntime(
117           this, !session_name.empty(),
118           remote_device_mgr ? remote_device_mgr.get() : nullptr)),
119       device_mgr_(std::move(device_mgr)),
120       borrowed_device_mgr_(nullptr),
121       remote_device_mgr_(std::move(remote_device_mgr)) {
122   // Starts exporting metrics through a platform-specific monitoring API (if
123   // provided). For builds using "tensorflow/core/platform/default", this is
124   // currently a no-op.
125   worker_session_created->GetCell()->Set(true);
126   monitoring::StartExporter();
127 }
128 
UpdateWorkerCacheAndDevices(std::unique_ptr<WorkerCacheInterface> new_worker_cache,std::vector<std::unique_ptr<Device>> added_remote_devices,const std::vector<Device * > & removed_remote_devices)129 Status WorkerSession::UpdateWorkerCacheAndDevices(
130     std::unique_ptr<WorkerCacheInterface> new_worker_cache,
131     std::vector<std::unique_ptr<Device>> added_remote_devices,
132     const std::vector<Device*>& removed_remote_devices) {
133   worker_cache_ = std::unique_ptr<WorkerCacheInterface>(
134       new WorkerFreeListCache(std::move(new_worker_cache)));
135   TF_RETURN_IF_ERROR(remote_device_mgr_->RemoveDevices(removed_remote_devices));
136   TF_RETURN_IF_ERROR(
137       remote_device_mgr_->AddDevices(std::move(added_remote_devices)));
138   cluster_flr_ = std::unique_ptr<ClusterFunctionLibraryRuntime>(
139       new ClusterFunctionLibraryRuntime(this, !session_name_.empty(),
140                                         remote_device_mgr()));
141   return Status::OK();
142 }
143 
144 /* static */
CreateWithBorrowedDeviceMgr(const string & session_name,const string & worker_name,std::unique_ptr<WorkerCacheInterface> worker_cache,DeviceMgr * borrowed_device_mgr,std::unique_ptr<GraphMgr> graph_mgr,std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)145 std::shared_ptr<WorkerSession> WorkerSession::CreateWithBorrowedDeviceMgr(
146     const string& session_name, const string& worker_name,
147     std::unique_ptr<WorkerCacheInterface> worker_cache,
148     DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
149     std::unique_ptr<DynamicDeviceMgr> remote_device_mgr) {
150   return std::shared_ptr<WorkerSession>(new WorkerSession(
151       session_name, worker_name, std::move(worker_cache), borrowed_device_mgr,
152       std::move(graph_mgr), std::move(remote_device_mgr)));
153 }
154 
WorkerSession(const string & session_name,const string & worker_name,std::unique_ptr<WorkerCacheInterface> worker_cache,DeviceMgr * borrowed_device_mgr,std::unique_ptr<GraphMgr> graph_mgr,std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)155 WorkerSession::WorkerSession(
156     const string& session_name, const string& worker_name,
157     std::unique_ptr<WorkerCacheInterface> worker_cache,
158     DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
159     std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)
160     : session_name_(session_name),
161       worker_name_(worker_name),
162       worker_cache_(new WorkerFreeListCache(std::move(worker_cache))),
163       graph_mgr_(std::move(graph_mgr)),
164       cluster_flr_(new ClusterFunctionLibraryRuntime(
165           this, !session_name.empty(), remote_device_mgr.get())),
166       device_mgr_(nullptr),
167       borrowed_device_mgr_(borrowed_device_mgr),
168       remote_device_mgr_(std::move(remote_device_mgr)) {
169   // Starts exporting metrics through a platform-specific monitoring API (if
170   // provided). For builds using "tensorflow/core/platform/default", this is
171   // currently a no-op.
172   worker_session_created->GetCell()->Set(true);
173   monitoring::StartExporter();
174 }
175 
~WorkerSession()176 WorkerSession::~WorkerSession() {
177   if (graph_mgr_) {
178     Status s = graph_mgr_->DeregisterAll();
179     if (!s.ok()) {
180       LOG(WARNING) << "Error during worker session deletion: " << s;
181     }
182   }
183 }
184 
185 }  // namespace tensorflow
186