• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/session_mgr.h"
17 
18 #include <utility>
19 
20 #include "tensorflow/core/common_runtime/device_mgr.h"
21 #include "tensorflow/core/common_runtime/renamed_device.h"
22 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
23 #include "tensorflow/core/distributed_runtime/remote_device.h"
24 #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 #include "tensorflow/core/protobuf/cluster.pb.h"
27 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
28 #include "tensorflow/core/util/ptr_util.h"
29 
30 namespace tensorflow {
31 
SessionMgr(WorkerEnv * worker_env,const string & default_worker_name,std::unique_ptr<WorkerCacheInterface> default_worker_cache,WorkerCacheFactory worker_cache_factory)32 SessionMgr::SessionMgr(
33     WorkerEnv* worker_env, const string& default_worker_name,
34     std::unique_ptr<WorkerCacheInterface> default_worker_cache,
35     WorkerCacheFactory worker_cache_factory)
36     : worker_env_(worker_env),
37       default_worker_cache_(std::move(default_worker_cache)),
38       legacy_session_(WorkerSession::CreateWithBorrowedDeviceMgr(
39           "", default_worker_name,
40           std::unique_ptr<WorkerCacheInterface>(
41               new WorkerCacheWrapper(default_worker_cache_.get())),
42           worker_env->device_mgr,
43           std::unique_ptr<GraphMgr>(
44               new GraphMgr(worker_env, worker_env->device_mgr)),
45           nullptr)),
46       worker_cache_factory_(std::move(worker_cache_factory)) {}
47 
48 /* static */
WorkerNameFromServerDef(const ServerDef & server_def)49 string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
50   return strings::StrCat("/job:", server_def.job_name(),
51                          "/replica:0/task:", server_def.task_index());
52 }
53 
CreateSession(const string & session,const ServerDef & server_def,bool isolate_session_state)54 Status SessionMgr::CreateSession(const string& session,
55                                  const ServerDef& server_def,
56                                  bool isolate_session_state) {
57   return CreateSession(session, server_def, {}, isolate_session_state);
58 }
59 
CreateSession(const string & session,const ServerDef & server_def,const protobuf::RepeatedPtrField<DeviceAttributes> & cluster_device_attributes,bool isolate_session_state)60 Status SessionMgr::CreateSession(
61     const string& session, const ServerDef& server_def,
62     const protobuf::RepeatedPtrField<DeviceAttributes>&
63         cluster_device_attributes,
64     bool isolate_session_state) {
65   return CreateSession(session, server_def, cluster_device_attributes,
66                        isolate_session_state, /*master_task=*/"",
67                        /*master_incarnation=*/0);
68 }
69 
CreateSession(const string & session,const ServerDef & server_def,const protobuf::RepeatedPtrField<DeviceAttributes> & cluster_device_attributes,bool isolate_session_state,string master_task,int64 master_incarnation)70 Status SessionMgr::CreateSession(
71     const string& session, const ServerDef& server_def,
72     const protobuf::RepeatedPtrField<DeviceAttributes>&
73         cluster_device_attributes,
74     bool isolate_session_state, string master_task, int64 master_incarnation) {
75   mutex_lock l(mu_);
76   if (session.empty()) {
77     return errors::InvalidArgument("Session must be non-empty.");
78   }
79 
80   // For given master task name, check if one or more `WorkerSession`s have been
81   // created previously on this worker, and if so garbage collect the expired
82   // `WorkerSession`s. This happens when the master fails before sending
83   // `DeleteSession` requests, which can cause `WorkerSession`s to be leaked.
84   if (!master_task.empty()) {
85     auto it_range = master_to_associated_sessions_.equal_range(master_task);
86     if (it_range.first != it_range.second &&
87         it_range.first->second.master_incarnation != master_incarnation) {
88       LOG(INFO) << "When creating WorkerSession for master task " << master_task
89                 << ", found old WorkerSessions created by the same master task "
90                 << "with a different incarnation. These sessions will "
91                 << "be garbage collected. Current WorkerSession count: "
92                 << sessions_.size();
93 
94       auto it = it_range.first;
95       while (it != it_range.second) {
96         auto session_it = sessions_.find(it->second.session_handle);
97         if (session_it != sessions_.end()) {
98           sessions_.erase(session_it);
99         }
100         it = master_to_associated_sessions_.erase(it);
101       }
102     }
103   }
104 
105   WorkerCacheInterface* worker_cache = nullptr;
106   string worker_name;
107   if (server_def.cluster().job().empty()) {
108     worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
109     worker_name = legacy_session_->worker_name();
110   } else {
111     TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
112     worker_name = WorkerNameFromServerDef(server_def);
113   }
114 
115   if (worker_cache != nullptr && default_worker_cache_ != nullptr) {
116     worker_cache->SetLogging(this->is_logging_active_);
117   }
118 
119   CHECK(!worker_env_->local_devices.empty())
120       << "The WorkerEnv must have at least one device in `local_devices`.";
121 
122   std::shared_ptr<WorkerSession> worker_session;
123   std::vector<std::unique_ptr<Device>> cluster_devices;
124 
125   if (isolate_session_state || server_def.cluster().job_size()) {
126     if (server_def.cluster().job_size()) {
127       VLOG(1) << "ClusterSpec propagation is enabled.";
128     }
129     if (!isolate_session_state) {
130       VLOG(1) << "Session state isolation is disabled.";
131     }
132 
133     // Create a private copy of the DeviceMgr for the WorkerSession.
134     std::vector<std::unique_ptr<Device>> renamed_devices;
135     for (Device* d : worker_env_->local_devices) {
136       renamed_devices.push_back(RenamedDevice::NewRenamedDevice(
137           worker_name, d, false, isolate_session_state));
138     }
139     auto device_mgr = MakeUnique<StaticDeviceMgr>(std::move(renamed_devices));
140     LookupLocalDevice cb = [&device_mgr](StringPiece name, Device** device) {
141       return device_mgr->LookupDevice(name, device);
142     };
143     AsRemoteDevices(worker_env_->env, cluster_device_attributes, cb,
144                     &cluster_devices);
145     std::unique_ptr<DynamicDeviceMgr> remote_devices;
146     if (!cluster_device_attributes.empty()) {
147       remote_devices = MakeUnique<DynamicDeviceMgr>();
148       TF_RETURN_IF_ERROR(
149           remote_devices->AddDevices(std::move(cluster_devices)));
150     }
151 
152     auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, device_mgr.get());
153     worker_session.reset(
154         new WorkerSession(session, worker_name,
155                           std::unique_ptr<WorkerCacheInterface>(worker_cache),
156                           std::move(device_mgr), std::move(graph_mgr),
157                           std::move(remote_devices)));
158   } else {
159     AsRemoteDevices(worker_env_->env, cluster_device_attributes, nullptr,
160                     &cluster_devices);
161     std::unique_ptr<DynamicDeviceMgr> remote_devices;
162     if (!cluster_device_attributes.empty()) {
163       remote_devices = MakeUnique<DynamicDeviceMgr>();
164       TF_RETURN_IF_ERROR(
165           remote_devices->AddDevices(std::move(cluster_devices)));
166     }
167     // Borrow the WorkerEnv's DeviceMgr for the WorkerSession, so
168     // that resources using it can use its devices after the
169     // WorkerSession has been deleted.
170     auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, worker_env_->device_mgr);
171     worker_session = WorkerSession::CreateWithBorrowedDeviceMgr(
172         session, worker_name,
173         std::unique_ptr<WorkerCacheInterface>(worker_cache),
174         worker_env_->device_mgr, std::move(graph_mgr),
175         std::move(remote_devices));
176   }
177 
178   sessions_.insert(std::make_pair(session, std::move(worker_session)));
179   if (!master_task.empty()) {
180     MasterAssociatedSession s{master_incarnation, session};
181     master_to_associated_sessions_.emplace(master_task, s);
182   }
183   return Status::OK();
184 }
185 
ResetDefaultWorkerCache(WorkerCacheInterface * worker_cache)186 void SessionMgr::ResetDefaultWorkerCache(WorkerCacheInterface* worker_cache) {
187   default_worker_cache_.reset(worker_cache);
188 }
189 
UpdateSession(const string & session,const ServerDef & server_def,const protobuf::RepeatedPtrField<DeviceAttributes> & cluster_device_attributes,bool isolate_session_state)190 Status SessionMgr::UpdateSession(
191     const string& session, const ServerDef& server_def,
192     const protobuf::RepeatedPtrField<DeviceAttributes>&
193         cluster_device_attributes,
194     bool isolate_session_state) {
195   mutex_lock l(mu_);
196   if (session.empty()) {
197     return errors::InvalidArgument("Session must be non-empty.");
198   }
199   auto it = sessions_.find(session);
200   if (it == sessions_.end()) {
201     return errors::InvalidArgument("Cannot update session ", session,
202                                    " because it does not exist.");
203   }
204   std::shared_ptr<WorkerSession> worker_session = it->second;
205 
206   WorkerCacheInterface* worker_cache = nullptr;
207   if (server_def.cluster().job().empty()) {
208     worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
209   } else {
210     TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
211   }
212   std::vector<string> updated_remote_workers;
213   worker_cache->ListWorkers(&updated_remote_workers);
214 
215   std::vector<std::unique_ptr<Device>> cluster_devices;
216 
217   const DeviceMgr* local_device_mgr = worker_session->device_mgr();
218   DeviceMgr* remote_device_mgr = worker_session->remote_device_mgr();
219   std::vector<Device*> curr_remote_devices = remote_device_mgr->ListDevices();
220   std::vector<std::unique_ptr<Device>> added_remote_devices;
221   std::vector<Device*> removed_remote_devices;
222 
223   std::vector<DeviceAttributes> added_cluster_device_attrs;
224   for (const auto& da : cluster_device_attributes) {
225     Device* device;
226     if (!local_device_mgr->LookupDevice(da.name(), &device).ok() &&
227         !remote_device_mgr->LookupDevice(da.name(), &device).ok()) {
228       added_cluster_device_attrs.emplace_back(da);
229     } else if (device != nullptr &&
230                device->attributes().incarnation() != da.incarnation()) {
231       removed_remote_devices.emplace_back(device);
232       added_cluster_device_attrs.emplace_back(da);
233     }
234   }
235   for (Device* device : curr_remote_devices) {
236     string task_name;
237     DeviceNameUtils::GetTaskName(device->parsed_name(), &task_name);
238     if (std::find(updated_remote_workers.begin(), updated_remote_workers.end(),
239                   task_name) == updated_remote_workers.end()) {
240       removed_remote_devices.emplace_back(device);
241     }
242   }
243   protobuf::RepeatedPtrField<DeviceAttributes> added_cluster_device_attrs_pb(
244       added_cluster_device_attrs.begin(), added_cluster_device_attrs.end());
245   AsRemoteDevices(worker_env_->env, added_cluster_device_attrs_pb, nullptr,
246                   &added_remote_devices);
247 
248   TF_RETURN_IF_ERROR(worker_session->UpdateWorkerCacheAndDevices(
249       std::unique_ptr<WorkerCacheInterface>(worker_cache),
250       std::move(added_remote_devices), removed_remote_devices));
251   return Status::OK();
252 }
253 
DeleteSession(const string & session)254 Status SessionMgr::DeleteSession(const string& session) {
255   mutex_lock l(mu_);
256   auto it = sessions_.find(session);
257   if (it != sessions_.end()) {
258     sessions_.erase(it);
259   }
260   return Status::OK();
261 }
262 
WorkerSessionForSessionLocked(const string & session_handle,std::shared_ptr<WorkerSession> * out_session)263 Status SessionMgr::WorkerSessionForSessionLocked(
264     const string& session_handle, std::shared_ptr<WorkerSession>* out_session) {
265   if (session_handle.empty()) {
266     *out_session = legacy_session_;
267   } else {
268     auto it = sessions_.find(session_handle);
269     if (it == sessions_.end()) {
270       return errors::Aborted("Session handle is not found: ", session_handle,
271                              ". Possibly this worker (\"",
272                              legacy_session_->worker_name(),
273                              "\") just restarted.");
274     } else {
275       *out_session = it->second;
276     }
277   }
278   return Status::OK();
279 }
280 
WorkerSessionForSession(const string & session_handle,std::shared_ptr<WorkerSession> * out_session)281 Status SessionMgr::WorkerSessionForSession(
282     const string& session_handle, std::shared_ptr<WorkerSession>* out_session) {
283   mutex_lock l(mu_);
284   return WorkerSessionForSessionLocked(session_handle, out_session);
285 }
286 
LegacySession()287 std::shared_ptr<WorkerSession> SessionMgr::LegacySession() {
288   return legacy_session_;
289 }
290 
SetLogging(bool active)291 void SessionMgr::SetLogging(bool active) {
292   mutex_lock l(mu_);
293   this->is_logging_active_ = active;
294   // Legacy Session
295   if (legacy_session_) {
296     auto* worker_cache = legacy_session_->worker_cache();
297     if (worker_cache) {
298       worker_cache->SetLogging(active);
299     }
300   }
301 
302   for (const auto& session_kv : sessions_) {
303     auto session = session_kv.second.get();
304     if (session) {
305       auto* worker_cache = session->worker_cache();
306       if (worker_cache) {
307         worker_cache->SetLogging(active);
308       }
309     }
310   }
311 }
312 
RetrieveLogs(tensorflow::int64 step_id,LoggingResponse * response)313 void SessionMgr::RetrieveLogs(tensorflow::int64 step_id,
314                               LoggingResponse* response) {
315   mutex_lock l(mu_);
316   // Legacy Session
317   if (legacy_session_) {
318     auto* worker_cache = legacy_session_->worker_cache();
319     if (worker_cache) {
320       auto step_stats = StepStats();
321       if (worker_cache->RetrieveLogs(step_id, &step_stats)) {
322         auto* labeled_step_stats = response->add_step();
323         labeled_step_stats->set_step_id(step_id);
324         labeled_step_stats->mutable_step_stats()->Swap(&step_stats);
325       }
326     }
327   }
328   for (const auto& session_kv : sessions_) {
329     auto session = session_kv.second.get();
330     if (session) {
331       auto* worker_cache = session->worker_cache();
332       if (worker_cache) {
333         auto step_stats = StepStats();
334         if (worker_cache->RetrieveLogs(step_id, &step_stats)) {
335           auto* labeled_step_stats = response->add_step();
336           labeled_step_stats->set_step_id(step_id);
337           labeled_step_stats->mutable_step_stats()->Swap(&step_stats);
338         }
339       }
340     }
341   }
342 }
343 
ClearLogs()344 void SessionMgr::ClearLogs() {
345   mutex_lock l(mu_);
346   // Legacy Session
347   if (legacy_session_) {
348     auto* worker_cache = legacy_session_->worker_cache();
349     if (worker_cache) {
350       worker_cache->ClearLogs();
351     }
352   }
353 
354   for (const auto& session_kv : sessions_) {
355     auto session = session_kv.second.get();
356     if (session) {
357       auto* worker_cache = session->worker_cache();
358       if (worker_cache) {
359         worker_cache->ClearLogs();
360       }
361     }
362   }
363 }
364 }  // namespace tensorflow
365