1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/distributed_runtime/worker_session.h"
16
17 #include "tensorflow/core/lib/monitoring/collection_registry.h"
18 #include "tensorflow/core/lib/monitoring/gauge.h"
19
20 namespace tensorflow {
21
22 namespace {
23
24 auto* worker_session_created =
25 monitoring::Gauge<bool, 0>::New("/tensorflow/core/worker_session_created",
26 "True if a worker session was created.");
27
28 // A private cache that wraps worker_cache and allows reuse of
29 // WorkerInterface objects.
30 class WorkerFreeListCache : public WorkerCacheInterface {
31 public:
WorkerFreeListCache(std::unique_ptr<WorkerCacheInterface> w)32 explicit WorkerFreeListCache(std::unique_ptr<WorkerCacheInterface> w)
33 : wrapped_(std::move(w)) {}
34
~WorkerFreeListCache()35 ~WorkerFreeListCache() final {
36 for (auto& p : workers_) {
37 wrapped_->ReleaseWorker(p.first, p.second.worker);
38 }
39 }
40
ListWorkers(std::vector<string> * workers) const41 void ListWorkers(std::vector<string>* workers) const override {
42 wrapped_->ListWorkers(workers);
43 }
44
ListWorkersInJob(const string & job_name,std::vector<string> * workers) const45 void ListWorkersInJob(const string& job_name,
46 std::vector<string>* workers) const override {
47 wrapped_->ListWorkersInJob(job_name, workers);
48 }
49
GetOrCreateWorker(const string & target)50 WorkerInterface* GetOrCreateWorker(const string& target) override {
51 mutex_lock l(mu_);
52 auto p = workers_.find(target);
53 if (p != workers_.end()) {
54 return p->second.worker;
55 }
56 WorkerState state;
57 state.worker = wrapped_->GetOrCreateWorker(target);
58 if (state.worker != nullptr) {
59 workers_.insert(std::make_pair(target, state));
60 }
61 return state.worker;
62 }
63
GetEagerClientCache(std::unique_ptr<eager::EagerClientCache> * eager_client_cache)64 Status GetEagerClientCache(
65 std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
66 return wrapped_->GetEagerClientCache(eager_client_cache);
67 }
68
ReleaseWorker(const string & target,WorkerInterface * worker)69 void ReleaseWorker(const string& target, WorkerInterface* worker) override {
70 // TODO(jeff,sanjay): Should decrement ref-count when we implement eviction.
71 }
72
GetDeviceLocalityNonBlocking(const string & device,DeviceLocality * locality)73 bool GetDeviceLocalityNonBlocking(const string& device,
74 DeviceLocality* locality) override {
75 return wrapped_->GetDeviceLocalityNonBlocking(device, locality);
76 }
77
GetDeviceLocalityAsync(const string & device,DeviceLocality * locality,StatusCallback done)78 void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
79 StatusCallback done) override {
80 wrapped_->GetDeviceLocalityAsync(device, locality, done);
81 }
82
SetLogging(bool active)83 void SetLogging(bool active) override { wrapped_->SetLogging(active); }
84
ClearLogs()85 void ClearLogs() override { wrapped_->ClearLogs(); }
86
RetrieveLogs(int64 step_id,StepStats * ss)87 bool RetrieveLogs(int64 step_id, StepStats* ss) override {
88 return wrapped_->RetrieveLogs(step_id, ss);
89 }
90
91 private:
92 std::unique_ptr<WorkerCacheInterface> wrapped_;
93
94 // Information kept per created WorkerInterface.
95 struct WorkerState {
96 WorkerInterface* worker;
97 // TODO(jeff,sanjay): Add reference count if we support eviction.
98 };
99
100 // TODO(jeff,sanjay): Eviction when the map becomes too big.
101 mutex mu_;
102 std::unordered_map<string, WorkerState> workers_ TF_GUARDED_BY(mu_);
103 };
104
105 } // namespace
106
WorkerSession(const string & session_name,const string & worker_name,std::unique_ptr<WorkerCacheInterface> worker_cache,std::unique_ptr<DeviceMgr> device_mgr,std::unique_ptr<GraphMgr> graph_mgr,std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)107 WorkerSession::WorkerSession(
108 const string& session_name, const string& worker_name,
109 std::unique_ptr<WorkerCacheInterface> worker_cache,
110 std::unique_ptr<DeviceMgr> device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
111 std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)
112 : session_name_(session_name),
113 worker_name_(worker_name),
114 worker_cache_(new WorkerFreeListCache(std::move(worker_cache))),
115 graph_mgr_(std::move(graph_mgr)),
116 cluster_flr_(new ClusterFunctionLibraryRuntime(
117 this, !session_name.empty(),
118 remote_device_mgr ? remote_device_mgr.get() : nullptr)),
119 device_mgr_(std::move(device_mgr)),
120 borrowed_device_mgr_(nullptr),
121 remote_device_mgr_(std::move(remote_device_mgr)) {
122 // Starts exporting metrics through a platform-specific monitoring API (if
123 // provided). For builds using "tensorflow/core/platform/default", this is
124 // currently a no-op.
125 worker_session_created->GetCell()->Set(true);
126 }
127
UpdateWorkerCacheAndDevices(std::unique_ptr<WorkerCacheInterface> new_worker_cache,std::vector<std::unique_ptr<Device>> added_remote_devices,const std::vector<Device * > & removed_remote_devices)128 Status WorkerSession::UpdateWorkerCacheAndDevices(
129 std::unique_ptr<WorkerCacheInterface> new_worker_cache,
130 std::vector<std::unique_ptr<Device>> added_remote_devices,
131 const std::vector<Device*>& removed_remote_devices) {
132 {
133 mutex_lock l(worker_session_state_mu_);
134 worker_cache_ = std::shared_ptr<WorkerCacheInterface>(
135 new WorkerFreeListCache(std::move(new_worker_cache)));
136 }
137 TF_RETURN_IF_ERROR(remote_device_mgr_->RemoveDevices(removed_remote_devices));
138 TF_RETURN_IF_ERROR(
139 remote_device_mgr_->AddDevices(std::move(added_remote_devices)));
140 return Status::OK();
141 }
142
143 /* static */
CreateWithBorrowedDeviceMgr(const string & session_name,const string & worker_name,std::unique_ptr<WorkerCacheInterface> worker_cache,const DeviceMgr * borrowed_device_mgr,std::unique_ptr<GraphMgr> graph_mgr,std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)144 std::shared_ptr<WorkerSession> WorkerSession::CreateWithBorrowedDeviceMgr(
145 const string& session_name, const string& worker_name,
146 std::unique_ptr<WorkerCacheInterface> worker_cache,
147 const DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
148 std::unique_ptr<DynamicDeviceMgr> remote_device_mgr) {
149 return std::shared_ptr<WorkerSession>(new WorkerSession(
150 session_name, worker_name, std::move(worker_cache), borrowed_device_mgr,
151 std::move(graph_mgr), std::move(remote_device_mgr)));
152 }
153
WorkerSession(const string & session_name,const string & worker_name,std::unique_ptr<WorkerCacheInterface> worker_cache,const DeviceMgr * borrowed_device_mgr,std::unique_ptr<GraphMgr> graph_mgr,std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)154 WorkerSession::WorkerSession(
155 const string& session_name, const string& worker_name,
156 std::unique_ptr<WorkerCacheInterface> worker_cache,
157 const DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
158 std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)
159 : session_name_(session_name),
160 worker_name_(worker_name),
161 worker_cache_(new WorkerFreeListCache(std::move(worker_cache))),
162 graph_mgr_(std::move(graph_mgr)),
163 cluster_flr_(new ClusterFunctionLibraryRuntime(
164 this, !session_name.empty(), remote_device_mgr.get())),
165 device_mgr_(nullptr),
166 borrowed_device_mgr_(borrowed_device_mgr),
167 remote_device_mgr_(std::move(remote_device_mgr)) {
168 // Starts exporting metrics through a platform-specific monitoring API (if
169 // provided). For builds using "tensorflow/core/platform/default", this is
170 // currently a no-op.
171 worker_session_created->GetCell()->Set(true);
172 }
173
~WorkerSession()174 WorkerSession::~WorkerSession() {
175 if (graph_mgr_) {
176 Status s = graph_mgr_->DeregisterAll();
177 if (!s.ok()) {
178 LOG(WARNING) << "Error during worker session deletion: " << s;
179 }
180 }
181 }
182
183 } // namespace tensorflow
184