1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/session_mgr.h"
17
18 #include <utility>
19
20 #include "tensorflow/core/common_runtime/device_mgr.h"
21 #include "tensorflow/core/common_runtime/renamed_device.h"
22 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
23 #include "tensorflow/core/distributed_runtime/remote_device.h"
24 #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 #include "tensorflow/core/protobuf/cluster.pb.h"
27 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
28 #include "tensorflow/core/util/ptr_util.h"
29
30 namespace tensorflow {
31
SessionMgr(WorkerEnv * worker_env,const string & default_worker_name,std::unique_ptr<WorkerCacheInterface> default_worker_cache,WorkerCacheFactory worker_cache_factory)32 SessionMgr::SessionMgr(
33 WorkerEnv* worker_env, const string& default_worker_name,
34 std::unique_ptr<WorkerCacheInterface> default_worker_cache,
35 WorkerCacheFactory worker_cache_factory)
36 : worker_env_(worker_env),
37 default_worker_cache_(std::move(default_worker_cache)),
38 legacy_session_(WorkerSession::CreateWithBorrowedDeviceMgr(
39 "", default_worker_name,
40 std::unique_ptr<WorkerCacheInterface>(
41 new WorkerCacheWrapper(default_worker_cache_.get())),
42 worker_env->device_mgr,
43 std::unique_ptr<GraphMgr>(
44 new GraphMgr(worker_env, worker_env->device_mgr)),
45 nullptr)),
46 worker_cache_factory_(std::move(worker_cache_factory)) {}
47
48 /* static */
WorkerNameFromServerDef(const ServerDef & server_def)49 string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
50 return strings::StrCat("/job:", server_def.job_name(),
51 "/replica:0/task:", server_def.task_index());
52 }
53
CreateSession(const string & session,const ServerDef & server_def,bool isolate_session_state)54 Status SessionMgr::CreateSession(const string& session,
55 const ServerDef& server_def,
56 bool isolate_session_state) {
57 return CreateSession(session, server_def, {}, isolate_session_state);
58 }
59
CreateSession(const string & session,const ServerDef & server_def,const protobuf::RepeatedPtrField<DeviceAttributes> & cluster_device_attributes,bool isolate_session_state)60 Status SessionMgr::CreateSession(
61 const string& session, const ServerDef& server_def,
62 const protobuf::RepeatedPtrField<DeviceAttributes>&
63 cluster_device_attributes,
64 bool isolate_session_state) {
65 return CreateSession(session, server_def, cluster_device_attributes,
66 isolate_session_state, /*master_task=*/"",
67 /*master_incarnation=*/0);
68 }
69
CreateSession(const string & session,const ServerDef & server_def,const protobuf::RepeatedPtrField<DeviceAttributes> & cluster_device_attributes,bool isolate_session_state,string master_task,int64 master_incarnation)70 Status SessionMgr::CreateSession(
71 const string& session, const ServerDef& server_def,
72 const protobuf::RepeatedPtrField<DeviceAttributes>&
73 cluster_device_attributes,
74 bool isolate_session_state, string master_task, int64 master_incarnation) {
75 mutex_lock l(mu_);
76 if (session.empty()) {
77 return errors::InvalidArgument("Session must be non-empty.");
78 }
79
80 // For given master task name, check if one or more `WorkerSession`s have been
81 // created previously on this worker, and if so garbage collect the expired
82 // `WorkerSession`s. This happens when the master fails before sending
83 // `DeleteSession` requests, which can cause `WorkerSession`s to be leaked.
84 if (!master_task.empty()) {
85 auto it_range = master_to_associated_sessions_.equal_range(master_task);
86 if (it_range.first != it_range.second &&
87 it_range.first->second.master_incarnation != master_incarnation) {
88 LOG(INFO) << "When creating WorkerSession for master task " << master_task
89 << ", found old WorkerSessions created by the same master task "
90 << "with a different incarnation. These sessions will "
91 << "be garbage collected. Current WorkerSession count: "
92 << sessions_.size();
93
94 auto it = it_range.first;
95 while (it != it_range.second) {
96 auto session_it = sessions_.find(it->second.session_handle);
97 if (session_it != sessions_.end()) {
98 sessions_.erase(session_it);
99 }
100 it = master_to_associated_sessions_.erase(it);
101 }
102 }
103 }
104
105 WorkerCacheInterface* worker_cache = nullptr;
106 string worker_name;
107 if (server_def.cluster().job().empty()) {
108 worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
109 worker_name = legacy_session_->worker_name();
110 } else {
111 TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
112 worker_name = WorkerNameFromServerDef(server_def);
113 }
114
115 if (worker_cache != nullptr && default_worker_cache_ != nullptr) {
116 worker_cache->SetLogging(this->is_logging_active_);
117 }
118
119 CHECK(!worker_env_->local_devices.empty())
120 << "The WorkerEnv must have at least one device in `local_devices`.";
121
122 std::shared_ptr<WorkerSession> worker_session;
123 std::vector<std::unique_ptr<Device>> cluster_devices;
124
125 if (isolate_session_state || server_def.cluster().job_size()) {
126 if (server_def.cluster().job_size()) {
127 VLOG(1) << "ClusterSpec propagation is enabled.";
128 }
129 if (!isolate_session_state) {
130 VLOG(1) << "Session state isolation is disabled.";
131 }
132
133 // Create a private copy of the DeviceMgr for the WorkerSession.
134 std::vector<std::unique_ptr<Device>> renamed_devices;
135 for (Device* d : worker_env_->local_devices) {
136 renamed_devices.push_back(RenamedDevice::NewRenamedDevice(
137 worker_name, d, false, isolate_session_state));
138 }
139 auto device_mgr = MakeUnique<StaticDeviceMgr>(std::move(renamed_devices));
140 LookupLocalDevice cb = [&device_mgr](StringPiece name, Device** device) {
141 return device_mgr->LookupDevice(name, device);
142 };
143 AsRemoteDevices(worker_env_->env, cluster_device_attributes, cb,
144 &cluster_devices);
145 std::unique_ptr<DynamicDeviceMgr> remote_devices;
146 if (!cluster_device_attributes.empty()) {
147 remote_devices = MakeUnique<DynamicDeviceMgr>();
148 TF_RETURN_IF_ERROR(
149 remote_devices->AddDevices(std::move(cluster_devices)));
150 }
151
152 auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, device_mgr.get());
153 worker_session.reset(
154 new WorkerSession(session, worker_name,
155 std::unique_ptr<WorkerCacheInterface>(worker_cache),
156 std::move(device_mgr), std::move(graph_mgr),
157 std::move(remote_devices)));
158 } else {
159 AsRemoteDevices(worker_env_->env, cluster_device_attributes, nullptr,
160 &cluster_devices);
161 std::unique_ptr<DynamicDeviceMgr> remote_devices;
162 if (!cluster_device_attributes.empty()) {
163 remote_devices = MakeUnique<DynamicDeviceMgr>();
164 TF_RETURN_IF_ERROR(
165 remote_devices->AddDevices(std::move(cluster_devices)));
166 }
167 // Borrow the WorkerEnv's DeviceMgr for the WorkerSession, so
168 // that resources using it can use its devices after the
169 // WorkerSession has been deleted.
170 auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, worker_env_->device_mgr);
171 worker_session = WorkerSession::CreateWithBorrowedDeviceMgr(
172 session, worker_name,
173 std::unique_ptr<WorkerCacheInterface>(worker_cache),
174 worker_env_->device_mgr, std::move(graph_mgr),
175 std::move(remote_devices));
176 }
177
178 sessions_.insert(std::make_pair(session, std::move(worker_session)));
179 if (!master_task.empty()) {
180 MasterAssociatedSession s{master_incarnation, session};
181 master_to_associated_sessions_.emplace(master_task, s);
182 }
183 return Status::OK();
184 }
185
ResetDefaultWorkerCache(WorkerCacheInterface * worker_cache)186 void SessionMgr::ResetDefaultWorkerCache(WorkerCacheInterface* worker_cache) {
187 default_worker_cache_.reset(worker_cache);
188 }
189
UpdateSession(const string & session,const ServerDef & server_def,const protobuf::RepeatedPtrField<DeviceAttributes> & cluster_device_attributes,bool isolate_session_state)190 Status SessionMgr::UpdateSession(
191 const string& session, const ServerDef& server_def,
192 const protobuf::RepeatedPtrField<DeviceAttributes>&
193 cluster_device_attributes,
194 bool isolate_session_state) {
195 mutex_lock l(mu_);
196 if (session.empty()) {
197 return errors::InvalidArgument("Session must be non-empty.");
198 }
199 auto it = sessions_.find(session);
200 if (it == sessions_.end()) {
201 return errors::InvalidArgument("Cannot update session ", session,
202 " because it does not exist.");
203 }
204 std::shared_ptr<WorkerSession> worker_session = it->second;
205
206 WorkerCacheInterface* worker_cache = nullptr;
207 if (server_def.cluster().job().empty()) {
208 worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
209 } else {
210 TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
211 }
212 std::vector<string> updated_remote_workers;
213 worker_cache->ListWorkers(&updated_remote_workers);
214
215 std::vector<std::unique_ptr<Device>> cluster_devices;
216
217 const DeviceMgr* local_device_mgr = worker_session->device_mgr();
218 DeviceMgr* remote_device_mgr = worker_session->remote_device_mgr();
219 std::vector<Device*> curr_remote_devices = remote_device_mgr->ListDevices();
220 std::vector<std::unique_ptr<Device>> added_remote_devices;
221 std::vector<Device*> removed_remote_devices;
222
223 std::vector<DeviceAttributes> added_cluster_device_attrs;
224 for (const auto& da : cluster_device_attributes) {
225 Device* device;
226 if (!local_device_mgr->LookupDevice(da.name(), &device).ok() &&
227 !remote_device_mgr->LookupDevice(da.name(), &device).ok()) {
228 added_cluster_device_attrs.emplace_back(da);
229 } else if (device != nullptr &&
230 device->attributes().incarnation() != da.incarnation()) {
231 removed_remote_devices.emplace_back(device);
232 added_cluster_device_attrs.emplace_back(da);
233 }
234 }
235 for (Device* device : curr_remote_devices) {
236 string task_name;
237 DeviceNameUtils::GetTaskName(device->parsed_name(), &task_name);
238 if (std::find(updated_remote_workers.begin(), updated_remote_workers.end(),
239 task_name) == updated_remote_workers.end()) {
240 removed_remote_devices.emplace_back(device);
241 }
242 }
243 protobuf::RepeatedPtrField<DeviceAttributes> added_cluster_device_attrs_pb(
244 added_cluster_device_attrs.begin(), added_cluster_device_attrs.end());
245 AsRemoteDevices(worker_env_->env, added_cluster_device_attrs_pb, nullptr,
246 &added_remote_devices);
247
248 TF_RETURN_IF_ERROR(worker_session->UpdateWorkerCacheAndDevices(
249 std::unique_ptr<WorkerCacheInterface>(worker_cache),
250 std::move(added_remote_devices), removed_remote_devices));
251 return Status::OK();
252 }
253
DeleteSession(const string & session)254 Status SessionMgr::DeleteSession(const string& session) {
255 mutex_lock l(mu_);
256 auto it = sessions_.find(session);
257 if (it != sessions_.end()) {
258 sessions_.erase(it);
259 }
260 return Status::OK();
261 }
262
WorkerSessionForSessionLocked(const string & session_handle,std::shared_ptr<WorkerSession> * out_session)263 Status SessionMgr::WorkerSessionForSessionLocked(
264 const string& session_handle, std::shared_ptr<WorkerSession>* out_session) {
265 if (session_handle.empty()) {
266 *out_session = legacy_session_;
267 } else {
268 auto it = sessions_.find(session_handle);
269 if (it == sessions_.end()) {
270 return errors::Aborted("Session handle is not found: ", session_handle,
271 ". Possibly this worker (\"",
272 legacy_session_->worker_name(),
273 "\") just restarted.");
274 } else {
275 *out_session = it->second;
276 }
277 }
278 return Status::OK();
279 }
280
WorkerSessionForSession(const string & session_handle,std::shared_ptr<WorkerSession> * out_session)281 Status SessionMgr::WorkerSessionForSession(
282 const string& session_handle, std::shared_ptr<WorkerSession>* out_session) {
283 mutex_lock l(mu_);
284 return WorkerSessionForSessionLocked(session_handle, out_session);
285 }
286
LegacySession()287 std::shared_ptr<WorkerSession> SessionMgr::LegacySession() {
288 return legacy_session_;
289 }
290
SetLogging(bool active)291 void SessionMgr::SetLogging(bool active) {
292 mutex_lock l(mu_);
293 this->is_logging_active_ = active;
294 // Legacy Session
295 if (legacy_session_) {
296 auto* worker_cache = legacy_session_->worker_cache();
297 if (worker_cache) {
298 worker_cache->SetLogging(active);
299 }
300 }
301
302 for (const auto& session_kv : sessions_) {
303 auto session = session_kv.second.get();
304 if (session) {
305 auto* worker_cache = session->worker_cache();
306 if (worker_cache) {
307 worker_cache->SetLogging(active);
308 }
309 }
310 }
311 }
312
RetrieveLogs(tensorflow::int64 step_id,LoggingResponse * response)313 void SessionMgr::RetrieveLogs(tensorflow::int64 step_id,
314 LoggingResponse* response) {
315 mutex_lock l(mu_);
316 // Legacy Session
317 if (legacy_session_) {
318 auto* worker_cache = legacy_session_->worker_cache();
319 if (worker_cache) {
320 auto step_stats = StepStats();
321 if (worker_cache->RetrieveLogs(step_id, &step_stats)) {
322 auto* labeled_step_stats = response->add_step();
323 labeled_step_stats->set_step_id(step_id);
324 labeled_step_stats->mutable_step_stats()->Swap(&step_stats);
325 }
326 }
327 }
328 for (const auto& session_kv : sessions_) {
329 auto session = session_kv.second.get();
330 if (session) {
331 auto* worker_cache = session->worker_cache();
332 if (worker_cache) {
333 auto step_stats = StepStats();
334 if (worker_cache->RetrieveLogs(step_id, &step_stats)) {
335 auto* labeled_step_stats = response->add_step();
336 labeled_step_stats->set_step_id(step_id);
337 labeled_step_stats->mutable_step_stats()->Swap(&step_stats);
338 }
339 }
340 }
341 }
342 }
343
ClearLogs()344 void SessionMgr::ClearLogs() {
345 mutex_lock l(mu_);
346 // Legacy Session
347 if (legacy_session_) {
348 auto* worker_cache = legacy_session_->worker_cache();
349 if (worker_cache) {
350 worker_cache->ClearLogs();
351 }
352 }
353
354 for (const auto& session_kv : sessions_) {
355 auto session = session_kv.second.get();
356 if (session) {
357 auto* worker_cache = session->worker_cache();
358 if (worker_cache) {
359 worker_cache->ClearLogs();
360 }
361 }
362 }
363 }
364 } // namespace tensorflow
365