1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/distributed_runtime/worker_session.h"
16
17 #include "tensorflow/core/lib/monitoring/collection_registry.h"
18 #include "tensorflow/core/lib/monitoring/gauge.h"
19
20 namespace tensorflow {
21
22 namespace {
23
24 auto* worker_session_created =
25 monitoring::Gauge<bool, 0>::New("/tensorflow/core/worker_session_created",
26 "True if a worker session was created.");
27
28 // A private cache that wraps worker_cache and allows reuse of
29 // WorkerInterface objects.
30 class WorkerFreeListCache : public WorkerCacheInterface {
31 public:
WorkerFreeListCache(std::unique_ptr<WorkerCacheInterface> w)32 explicit WorkerFreeListCache(std::unique_ptr<WorkerCacheInterface> w)
33 : wrapped_(std::move(w)) {}
34
~WorkerFreeListCache()35 ~WorkerFreeListCache() final {
36 for (auto& p : workers_) {
37 wrapped_->ReleaseWorker(p.first, p.second.worker);
38 }
39 }
40
ListWorkers(std::vector<string> * workers) const41 void ListWorkers(std::vector<string>* workers) const override {
42 wrapped_->ListWorkers(workers);
43 }
44
ListWorkersInJob(const string & job_name,std::vector<string> * workers) const45 void ListWorkersInJob(const string& job_name,
46 std::vector<string>* workers) const override {
47 wrapped_->ListWorkersInJob(job_name, workers);
48 }
49
GetOrCreateWorker(const string & target)50 WorkerInterface* GetOrCreateWorker(const string& target) override {
51 mutex_lock l(mu_);
52 auto p = workers_.find(target);
53 if (p != workers_.end()) {
54 return p->second.worker;
55 }
56 WorkerState state;
57 state.worker = wrapped_->GetOrCreateWorker(target);
58 if (state.worker != nullptr) {
59 workers_.insert(std::make_pair(target, state));
60 }
61 return state.worker;
62 }
63
GetEagerClientCache(std::unique_ptr<eager::EagerClientCache> * eager_client_cache)64 Status GetEagerClientCache(
65 std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
66 return wrapped_->GetEagerClientCache(eager_client_cache);
67 }
68
GetCoordinationClientCache(std::unique_ptr<CoordinationClientCache> * coordination_client_cache)69 Status GetCoordinationClientCache(std::unique_ptr<CoordinationClientCache>*
70 coordination_client_cache) override {
71 return wrapped_->GetCoordinationClientCache(coordination_client_cache);
72 }
73
ReleaseWorker(const string & target,WorkerInterface * worker)74 void ReleaseWorker(const string& target, WorkerInterface* worker) override {
75 // TODO(jeff,sanjay): Should decrement ref-count when we implement eviction.
76 }
77
GetDeviceLocalityNonBlocking(const string & device,DeviceLocality * locality)78 bool GetDeviceLocalityNonBlocking(const string& device,
79 DeviceLocality* locality) override {
80 return wrapped_->GetDeviceLocalityNonBlocking(device, locality);
81 }
82
GetDeviceLocalityAsync(const string & device,DeviceLocality * locality,StatusCallback done)83 void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
84 StatusCallback done) override {
85 wrapped_->GetDeviceLocalityAsync(device, locality, done);
86 }
87
SetLogging(bool active)88 void SetLogging(bool active) override { wrapped_->SetLogging(active); }
89
ClearLogs()90 void ClearLogs() override { wrapped_->ClearLogs(); }
91
RetrieveLogs(int64_t step_id,StepStats * ss)92 bool RetrieveLogs(int64_t step_id, StepStats* ss) override {
93 return wrapped_->RetrieveLogs(step_id, ss);
94 }
95
96 private:
97 std::unique_ptr<WorkerCacheInterface> wrapped_;
98
99 // Information kept per created WorkerInterface.
100 struct WorkerState {
101 WorkerInterface* worker;
102 // TODO(jeff,sanjay): Add reference count if we support eviction.
103 };
104
105 // TODO(jeff,sanjay): Eviction when the map becomes too big.
106 mutex mu_;
107 std::unordered_map<string, WorkerState> workers_ TF_GUARDED_BY(mu_);
108 };
109
110 } // namespace
111
WorkerSession(const string & session_name,const string & worker_name,std::unique_ptr<WorkerCacheInterface> worker_cache,std::unique_ptr<DeviceMgr> device_mgr,std::unique_ptr<GraphMgr> graph_mgr,std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)112 WorkerSession::WorkerSession(
113 const string& session_name, const string& worker_name,
114 std::unique_ptr<WorkerCacheInterface> worker_cache,
115 std::unique_ptr<DeviceMgr> device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
116 std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)
117 : session_name_(session_name),
118 worker_name_(worker_name),
119 worker_cache_(new WorkerFreeListCache(std::move(worker_cache))),
120 graph_mgr_(std::move(graph_mgr)),
121 cluster_flr_(new ClusterFunctionLibraryRuntime(
122 this, !session_name.empty(),
123 remote_device_mgr ? remote_device_mgr.get() : nullptr)),
124 device_mgr_(std::move(device_mgr)),
125 borrowed_device_mgr_(nullptr),
126 remote_device_mgr_(std::move(remote_device_mgr)) {
127 // Starts exporting metrics through a platform-specific monitoring API (if
128 // provided). For builds using "tensorflow/core/platform/default", this is
129 // currently a no-op.
130 worker_session_created->GetCell()->Set(true);
131 }
132
UpdateWorkerCacheAndDevices(std::unique_ptr<WorkerCacheInterface> new_worker_cache,std::vector<std::unique_ptr<Device>> added_remote_devices,const std::vector<Device * > & removed_remote_devices)133 Status WorkerSession::UpdateWorkerCacheAndDevices(
134 std::unique_ptr<WorkerCacheInterface> new_worker_cache,
135 std::vector<std::unique_ptr<Device>> added_remote_devices,
136 const std::vector<Device*>& removed_remote_devices) {
137 {
138 mutex_lock l(worker_session_state_mu_);
139 worker_cache_ = std::shared_ptr<WorkerCacheInterface>(
140 new WorkerFreeListCache(std::move(new_worker_cache)));
141 }
142 TF_RETURN_IF_ERROR(remote_device_mgr_->RemoveDevices(removed_remote_devices));
143 TF_RETURN_IF_ERROR(
144 remote_device_mgr_->AddDevices(std::move(added_remote_devices)));
145 return Status::OK();
146 }
147
148 /* static */
CreateWithBorrowedDeviceMgr(const string & session_name,const string & worker_name,std::unique_ptr<WorkerCacheInterface> worker_cache,DeviceMgr * borrowed_device_mgr,std::unique_ptr<GraphMgr> graph_mgr,std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)149 std::shared_ptr<WorkerSession> WorkerSession::CreateWithBorrowedDeviceMgr(
150 const string& session_name, const string& worker_name,
151 std::unique_ptr<WorkerCacheInterface> worker_cache,
152 DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
153 std::unique_ptr<DynamicDeviceMgr> remote_device_mgr) {
154 return std::shared_ptr<WorkerSession>(new WorkerSession(
155 session_name, worker_name, std::move(worker_cache), borrowed_device_mgr,
156 std::move(graph_mgr), std::move(remote_device_mgr)));
157 }
158
WorkerSession(const string & session_name,const string & worker_name,std::unique_ptr<WorkerCacheInterface> worker_cache,DeviceMgr * borrowed_device_mgr,std::unique_ptr<GraphMgr> graph_mgr,std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)159 WorkerSession::WorkerSession(
160 const string& session_name, const string& worker_name,
161 std::unique_ptr<WorkerCacheInterface> worker_cache,
162 DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
163 std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)
164 : session_name_(session_name),
165 worker_name_(worker_name),
166 worker_cache_(new WorkerFreeListCache(std::move(worker_cache))),
167 graph_mgr_(std::move(graph_mgr)),
168 cluster_flr_(new ClusterFunctionLibraryRuntime(
169 this, !session_name.empty(), remote_device_mgr.get())),
170 device_mgr_(nullptr),
171 borrowed_device_mgr_(borrowed_device_mgr),
172 remote_device_mgr_(std::move(remote_device_mgr)) {
173 // Starts exporting metrics through a platform-specific monitoring API (if
174 // provided). For builds using "tensorflow/core/platform/default", this is
175 // currently a no-op.
176 worker_session_created->GetCell()->Set(true);
177 }
178
~WorkerSession()179 WorkerSession::~WorkerSession() {
180 if (graph_mgr_) {
181 Status s = graph_mgr_->DeregisterAll();
182 if (!s.ok()) {
183 LOG(WARNING) << "Error during worker session deletion: " << s;
184 }
185 }
186 }
187
188 } // namespace tensorflow
189