1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/distributed_runtime/worker_session.h"
16
17 #include "tensorflow/core/lib/monitoring/gauge.h"
18 #include "tensorflow/core/platform/monitoring.h"
19
20 namespace tensorflow {
21
22 namespace {
23
24 auto* worker_session_created =
25 monitoring::Gauge<bool, 0>::New("/tensorflow/core/worker_session_created",
26 "True if a worker session was created.");
27
28 // A private cache that wraps worker_cache and allows reuse of
29 // WorkerInterface objects.
30 class WorkerFreeListCache : public WorkerCacheInterface {
31 public:
WorkerFreeListCache(std::unique_ptr<WorkerCacheInterface> w)32 explicit WorkerFreeListCache(std::unique_ptr<WorkerCacheInterface> w)
33 : wrapped_(std::move(w)) {}
34
~WorkerFreeListCache()35 ~WorkerFreeListCache() final {
36 for (auto& p : workers_) {
37 wrapped_->ReleaseWorker(p.first, p.second.worker);
38 }
39 }
40
ListWorkers(std::vector<string> * workers) const41 void ListWorkers(std::vector<string>* workers) const override {
42 wrapped_->ListWorkers(workers);
43 }
44
ListWorkersInJob(const string & job_name,std::vector<string> * workers) const45 void ListWorkersInJob(const string& job_name,
46 std::vector<string>* workers) const override {
47 wrapped_->ListWorkersInJob(job_name, workers);
48 }
49
GetOrCreateWorker(const string & target)50 WorkerInterface* GetOrCreateWorker(const string& target) override {
51 mutex_lock l(mu_);
52 auto p = workers_.find(target);
53 if (p != workers_.end()) {
54 return p->second.worker;
55 }
56 WorkerState state;
57 state.worker = wrapped_->GetOrCreateWorker(target);
58 if (state.worker != nullptr) {
59 workers_.insert(std::make_pair(target, state));
60 }
61 return state.worker;
62 }
63
GetEagerClientCache(std::unique_ptr<eager::EagerClientCache> * eager_client_cache)64 Status GetEagerClientCache(
65 std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
66 return wrapped_->GetEagerClientCache(eager_client_cache);
67 }
68
ReleaseWorker(const string & target,WorkerInterface * worker)69 void ReleaseWorker(const string& target, WorkerInterface* worker) override {
70 // TODO(jeff,sanjay): Should decrement ref-count when we implement eviction.
71 }
72
GetDeviceLocalityNonBlocking(const string & device,DeviceLocality * locality)73 bool GetDeviceLocalityNonBlocking(const string& device,
74 DeviceLocality* locality) override {
75 return wrapped_->GetDeviceLocalityNonBlocking(device, locality);
76 }
77
GetDeviceLocalityAsync(const string & device,DeviceLocality * locality,StatusCallback done)78 void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
79 StatusCallback done) override {
80 wrapped_->GetDeviceLocalityAsync(device, locality, done);
81 }
82
SetLogging(bool active)83 void SetLogging(bool active) override { wrapped_->SetLogging(active); }
84
ClearLogs()85 void ClearLogs() override { wrapped_->ClearLogs(); }
86
RetrieveLogs(int64 step_id,StepStats * ss)87 bool RetrieveLogs(int64 step_id, StepStats* ss) override {
88 return wrapped_->RetrieveLogs(step_id, ss);
89 }
90
91 private:
92 std::unique_ptr<WorkerCacheInterface> wrapped_;
93
94 // Information kept per created WorkerInterface.
95 struct WorkerState {
96 WorkerInterface* worker;
97 // TODO(jeff,sanjay): Add reference count if we support eviction.
98 };
99
100 // TODO(jeff,sanjay): Eviction when the map becomes too big.
101 mutex mu_;
102 std::unordered_map<string, WorkerState> workers_ GUARDED_BY(mu_);
103 };
104
105 } // namespace
106
WorkerSession(const string & session_name,const string & worker_name,std::unique_ptr<WorkerCacheInterface> worker_cache,std::unique_ptr<DeviceMgr> device_mgr,std::unique_ptr<GraphMgr> graph_mgr,std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)107 WorkerSession::WorkerSession(
108 const string& session_name, const string& worker_name,
109 std::unique_ptr<WorkerCacheInterface> worker_cache,
110 std::unique_ptr<DeviceMgr> device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
111 std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)
112 : session_name_(session_name),
113 worker_name_(worker_name),
114 worker_cache_(new WorkerFreeListCache(std::move(worker_cache))),
115 graph_mgr_(std::move(graph_mgr)),
116 cluster_flr_(new ClusterFunctionLibraryRuntime(
117 this, !session_name.empty(),
118 remote_device_mgr ? remote_device_mgr.get() : nullptr)),
119 device_mgr_(std::move(device_mgr)),
120 borrowed_device_mgr_(nullptr),
121 remote_device_mgr_(std::move(remote_device_mgr)) {
122 // Starts exporting metrics through a platform-specific monitoring API (if
123 // provided). For builds using "tensorflow/core/platform/default", this is
124 // currently a no-op.
125 worker_session_created->GetCell()->Set(true);
126 monitoring::StartExporter();
127 }
128
UpdateWorkerCacheAndDevices(std::unique_ptr<WorkerCacheInterface> new_worker_cache,std::vector<std::unique_ptr<Device>> added_remote_devices,const std::vector<Device * > & removed_remote_devices)129 Status WorkerSession::UpdateWorkerCacheAndDevices(
130 std::unique_ptr<WorkerCacheInterface> new_worker_cache,
131 std::vector<std::unique_ptr<Device>> added_remote_devices,
132 const std::vector<Device*>& removed_remote_devices) {
133 worker_cache_ = std::unique_ptr<WorkerCacheInterface>(
134 new WorkerFreeListCache(std::move(new_worker_cache)));
135 TF_RETURN_IF_ERROR(remote_device_mgr_->RemoveDevices(removed_remote_devices));
136 TF_RETURN_IF_ERROR(
137 remote_device_mgr_->AddDevices(std::move(added_remote_devices)));
138 cluster_flr_ = std::unique_ptr<ClusterFunctionLibraryRuntime>(
139 new ClusterFunctionLibraryRuntime(this, !session_name_.empty(),
140 remote_device_mgr()));
141 return Status::OK();
142 }
143
144 /* static */
CreateWithBorrowedDeviceMgr(const string & session_name,const string & worker_name,std::unique_ptr<WorkerCacheInterface> worker_cache,DeviceMgr * borrowed_device_mgr,std::unique_ptr<GraphMgr> graph_mgr,std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)145 std::shared_ptr<WorkerSession> WorkerSession::CreateWithBorrowedDeviceMgr(
146 const string& session_name, const string& worker_name,
147 std::unique_ptr<WorkerCacheInterface> worker_cache,
148 DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
149 std::unique_ptr<DynamicDeviceMgr> remote_device_mgr) {
150 return std::shared_ptr<WorkerSession>(new WorkerSession(
151 session_name, worker_name, std::move(worker_cache), borrowed_device_mgr,
152 std::move(graph_mgr), std::move(remote_device_mgr)));
153 }
154
WorkerSession(const string & session_name,const string & worker_name,std::unique_ptr<WorkerCacheInterface> worker_cache,DeviceMgr * borrowed_device_mgr,std::unique_ptr<GraphMgr> graph_mgr,std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)155 WorkerSession::WorkerSession(
156 const string& session_name, const string& worker_name,
157 std::unique_ptr<WorkerCacheInterface> worker_cache,
158 DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
159 std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)
160 : session_name_(session_name),
161 worker_name_(worker_name),
162 worker_cache_(new WorkerFreeListCache(std::move(worker_cache))),
163 graph_mgr_(std::move(graph_mgr)),
164 cluster_flr_(new ClusterFunctionLibraryRuntime(
165 this, !session_name.empty(), remote_device_mgr.get())),
166 device_mgr_(nullptr),
167 borrowed_device_mgr_(borrowed_device_mgr),
168 remote_device_mgr_(std::move(remote_device_mgr)) {
169 // Starts exporting metrics through a platform-specific monitoring API (if
170 // provided). For builds using "tensorflow/core/platform/default", this is
171 // currently a no-op.
172 worker_session_created->GetCell()->Set(true);
173 monitoring::StartExporter();
174 }
175
~WorkerSession()176 WorkerSession::~WorkerSession() {
177 if (graph_mgr_) {
178 Status s = graph_mgr_->DeregisterAll();
179 if (!s.ok()) {
180 LOG(WARNING) << "Error during worker session deletion: " << s;
181 }
182 }
183 }
184
185 } // namespace tensorflow
186