1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/session_mgr.h"
17
18 #include <utility>
19
20 #include "tensorflow/core/common_runtime/device_mgr.h"
21 #include "tensorflow/core/common_runtime/renamed_device.h"
22 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
23 #include "tensorflow/core/distributed_runtime/remote_device.h"
24 #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 #include "tensorflow/core/protobuf/cluster.pb.h"
27 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
28 #include "tensorflow/core/util/ptr_util.h"
29
30 namespace tensorflow {
31
SessionMgr(WorkerEnv * worker_env,const string & default_worker_name,std::unique_ptr<WorkerCacheInterface> default_worker_cache,WorkerCacheFactory worker_cache_factory)32 SessionMgr::SessionMgr(
33 WorkerEnv* worker_env, const string& default_worker_name,
34 std::unique_ptr<WorkerCacheInterface> default_worker_cache,
35 WorkerCacheFactory worker_cache_factory)
36 : worker_env_(worker_env),
37 default_worker_cache_(std::move(default_worker_cache)),
38 legacy_session_(WorkerSession::CreateWithBorrowedDeviceMgr(
39 "", default_worker_name,
40 std::unique_ptr<WorkerCacheInterface>(
41 new WorkerCacheWrapper(default_worker_cache_.get())),
42 worker_env->device_mgr,
43 std::unique_ptr<GraphMgr>(
44 new GraphMgr(worker_env, worker_env->device_mgr)),
45 nullptr)),
46 worker_cache_factory_(std::move(worker_cache_factory)) {}
47
48 /* static */
WorkerNameFromServerDef(const ServerDef & server_def)49 string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
50 return strings::StrCat("/job:", server_def.job_name(),
51 "/replica:0/task:", server_def.task_index());
52 }
53
CreateSession(const string & session,const ServerDef & server_def,bool isolate_session_state)54 Status SessionMgr::CreateSession(const string& session,
55 const ServerDef& server_def,
56 bool isolate_session_state) {
57 return CreateSession(session, server_def, {}, isolate_session_state);
58 }
59
CreateSession(const string & session,const ServerDef & server_def,const protobuf::RepeatedPtrField<DeviceAttributes> & cluster_device_attributes,bool isolate_session_state)60 Status SessionMgr::CreateSession(
61 const string& session, const ServerDef& server_def,
62 const protobuf::RepeatedPtrField<DeviceAttributes>&
63 cluster_device_attributes,
64 bool isolate_session_state) {
65 mutex_lock l(mu_);
66 if (session.empty()) {
67 return errors::InvalidArgument("Session must be non-empty.");
68 }
69
70 WorkerCacheInterface* worker_cache = nullptr;
71 string worker_name;
72 if (server_def.cluster().job().empty()) {
73 worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
74 worker_name = legacy_session_->worker_name();
75 } else {
76 TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
77 worker_name = WorkerNameFromServerDef(server_def);
78 }
79
80 if (worker_cache != nullptr && default_worker_cache_ != nullptr) {
81 worker_cache->SetLogging(this->is_logging_active_);
82 }
83
84 CHECK(!worker_env_->local_devices.empty())
85 << "The WorkerEnv must have at least one device in `local_devices`.";
86
87 std::shared_ptr<WorkerSession> worker_session;
88 std::vector<std::unique_ptr<Device>> cluster_devices;
89
90 if (isolate_session_state || server_def.cluster().job_size()) {
91 if (server_def.cluster().job_size()) {
92 VLOG(1) << "ClusterSpec propagation is enabled.";
93 }
94 if (!isolate_session_state) {
95 VLOG(1) << "Session state isolation is disabled.";
96 }
97
98 // Create a private copy of the DeviceMgr for the WorkerSession.
99 std::vector<std::unique_ptr<Device>> renamed_devices;
100 for (Device* d : worker_env_->local_devices) {
101 renamed_devices.push_back(RenamedDevice::NewRenamedDevice(
102 worker_name, d, false, isolate_session_state));
103 }
104 auto device_mgr = MakeUnique<StaticDeviceMgr>(std::move(renamed_devices));
105 LookupLocalDevice cb = [&device_mgr](StringPiece name, Device** device) {
106 return device_mgr->LookupDevice(name, device);
107 };
108 AsRemoteDevices(worker_env_->env, cluster_device_attributes, cb,
109 &cluster_devices);
110 std::unique_ptr<DynamicDeviceMgr> remote_devices;
111 if (!cluster_device_attributes.empty()) {
112 remote_devices = MakeUnique<DynamicDeviceMgr>();
113 TF_RETURN_IF_ERROR(
114 remote_devices->AddDevices(std::move(cluster_devices)));
115 }
116
117 auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, device_mgr.get());
118 worker_session.reset(
119 new WorkerSession(session, worker_name,
120 std::unique_ptr<WorkerCacheInterface>(worker_cache),
121 std::move(device_mgr), std::move(graph_mgr),
122 std::move(remote_devices)));
123 } else {
124 AsRemoteDevices(worker_env_->env, cluster_device_attributes, nullptr,
125 &cluster_devices);
126 std::unique_ptr<DynamicDeviceMgr> remote_devices;
127 if (!cluster_device_attributes.empty()) {
128 remote_devices = MakeUnique<DynamicDeviceMgr>();
129 TF_RETURN_IF_ERROR(
130 remote_devices->AddDevices(std::move(cluster_devices)));
131 }
132 // Borrow the WorkerEnv's DeviceMgr for the WorkerSession, so
133 // that resources using it can use its devices after the
134 // WorkerSession has been deleted.
135 auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, worker_env_->device_mgr);
136 worker_session = WorkerSession::CreateWithBorrowedDeviceMgr(
137 session, worker_name,
138 std::unique_ptr<WorkerCacheInterface>(worker_cache),
139 worker_env_->device_mgr, std::move(graph_mgr),
140 std::move(remote_devices));
141 }
142
143 sessions_.insert(std::make_pair(session, std::move(worker_session)));
144 return Status::OK();
145 }
146
UpdateSession(const string & session,const ServerDef & server_def,const protobuf::RepeatedPtrField<DeviceAttributes> & cluster_device_attributes,bool isolate_session_state)147 Status SessionMgr::UpdateSession(
148 const string& session, const ServerDef& server_def,
149 const protobuf::RepeatedPtrField<DeviceAttributes>&
150 cluster_device_attributes,
151 bool isolate_session_state) {
152 mutex_lock l(mu_);
153 if (session.empty()) {
154 return errors::InvalidArgument("Session must be non-empty.");
155 }
156 auto it = sessions_.find(session);
157 if (it == sessions_.end()) {
158 return errors::InvalidArgument("Cannot update session ", session,
159 " because it does not exist.");
160 }
161 std::shared_ptr<WorkerSession> worker_session = it->second;
162
163 WorkerCacheInterface* worker_cache = nullptr;
164 if (server_def.cluster().job().empty()) {
165 worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
166 } else {
167 TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
168 }
169 std::vector<string> updated_remote_workers;
170 worker_cache->ListWorkers(&updated_remote_workers);
171
172 std::vector<std::unique_ptr<Device>> cluster_devices;
173
174 DeviceMgr* local_device_mgr = worker_session->device_mgr();
175 DeviceMgr* remote_device_mgr = worker_session->remote_device_mgr();
176 std::vector<Device*> curr_remote_devices = remote_device_mgr->ListDevices();
177 std::vector<std::unique_ptr<Device>> added_remote_devices;
178 std::vector<Device*> removed_remote_devices;
179
180 std::vector<DeviceAttributes> added_cluster_device_attrs;
181 for (const auto& da : cluster_device_attributes) {
182 Device* device;
183 if (!local_device_mgr->LookupDevice(da.name(), &device).ok() &&
184 !remote_device_mgr->LookupDevice(da.name(), &device).ok()) {
185 added_cluster_device_attrs.emplace_back(da);
186 } else if (device != nullptr &&
187 device->attributes().incarnation() != da.incarnation()) {
188 removed_remote_devices.emplace_back(device);
189 added_cluster_device_attrs.emplace_back(da);
190 }
191 }
192 for (Device* device : curr_remote_devices) {
193 string task_name;
194 DeviceNameUtils::GetTaskName(device->parsed_name(), &task_name);
195 if (std::find(updated_remote_workers.begin(), updated_remote_workers.end(),
196 task_name) == updated_remote_workers.end()) {
197 removed_remote_devices.emplace_back(device);
198 }
199 }
200 protobuf::RepeatedPtrField<DeviceAttributes> added_cluster_device_attrs_pb(
201 added_cluster_device_attrs.begin(), added_cluster_device_attrs.end());
202 AsRemoteDevices(worker_env_->env, added_cluster_device_attrs_pb, nullptr,
203 &added_remote_devices);
204
205 TF_RETURN_IF_ERROR(worker_session->UpdateWorkerCacheAndDevices(
206 std::unique_ptr<WorkerCacheInterface>(worker_cache),
207 std::move(added_remote_devices), removed_remote_devices));
208 return Status::OK();
209 }
210
DeleteSession(const string & session)211 Status SessionMgr::DeleteSession(const string& session) {
212 mutex_lock l(mu_);
213 auto it = sessions_.find(session);
214 if (it != sessions_.end()) {
215 sessions_.erase(it);
216 }
217 return Status::OK();
218 }
219
WorkerSessionForSessionLocked(const string & session_handle,std::shared_ptr<WorkerSession> * out_session)220 Status SessionMgr::WorkerSessionForSessionLocked(
221 const string& session_handle, std::shared_ptr<WorkerSession>* out_session) {
222 if (session_handle.empty()) {
223 *out_session = legacy_session_;
224 } else {
225 auto it = sessions_.find(session_handle);
226 if (it == sessions_.end()) {
227 return errors::Aborted("Session handle is not found: ", session_handle,
228 ". Possibly this worker (\"",
229 legacy_session_->worker_name(),
230 "\") just restarted.");
231 } else {
232 *out_session = it->second;
233 }
234 }
235 return Status::OK();
236 }
237
WorkerSessionForSession(const string & session_handle,std::shared_ptr<WorkerSession> * out_session)238 Status SessionMgr::WorkerSessionForSession(
239 const string& session_handle, std::shared_ptr<WorkerSession>* out_session) {
240 mutex_lock l(mu_);
241 return WorkerSessionForSessionLocked(session_handle, out_session);
242 }
243
LegacySession()244 std::shared_ptr<WorkerSession> SessionMgr::LegacySession() {
245 return legacy_session_;
246 }
247
SetLogging(bool active)248 void SessionMgr::SetLogging(bool active) {
249 mutex_lock l(mu_);
250 this->is_logging_active_ = active;
251 // Legacy Session
252 if (legacy_session_) {
253 auto* worker_cache = legacy_session_->worker_cache();
254 if (worker_cache) {
255 worker_cache->SetLogging(active);
256 }
257 }
258
259 for (const auto& session_kv : sessions_) {
260 auto session = session_kv.second.get();
261 if (session) {
262 auto* worker_cache = session->worker_cache();
263 if (worker_cache) {
264 worker_cache->SetLogging(active);
265 }
266 }
267 }
268 }
269
RetrieveLogs(tensorflow::int64 step_id,LoggingResponse * response)270 void SessionMgr::RetrieveLogs(tensorflow::int64 step_id,
271 LoggingResponse* response) {
272 mutex_lock l(mu_);
273 // Legacy Session
274 if (legacy_session_) {
275 auto* worker_cache = legacy_session_->worker_cache();
276 if (worker_cache) {
277 auto step_stats = StepStats();
278 if (worker_cache->RetrieveLogs(step_id, &step_stats)) {
279 auto* labeled_step_stats = response->add_step();
280 labeled_step_stats->set_step_id(step_id);
281 labeled_step_stats->mutable_step_stats()->Swap(&step_stats);
282 }
283 }
284 }
285 for (const auto& session_kv : sessions_) {
286 auto session = session_kv.second.get();
287 if (session) {
288 auto* worker_cache = session->worker_cache();
289 if (worker_cache) {
290 auto step_stats = StepStats();
291 if (worker_cache->RetrieveLogs(step_id, &step_stats)) {
292 auto* labeled_step_stats = response->add_step();
293 labeled_step_stats->set_step_id(step_id);
294 labeled_step_stats->mutable_step_stats()->Swap(&step_stats);
295 }
296 }
297 }
298 }
299 }
300
ClearLogs()301 void SessionMgr::ClearLogs() {
302 mutex_lock l(mu_);
303 // Legacy Session
304 if (legacy_session_) {
305 auto* worker_cache = legacy_session_->worker_cache();
306 if (worker_cache) {
307 worker_cache->ClearLogs();
308 }
309 }
310
311 for (const auto& session_kv : sessions_) {
312 auto session = session_kv.second.get();
313 if (session) {
314 auto* worker_cache = session->worker_cache();
315 if (worker_cache) {
316 worker_cache->ClearLogs();
317 }
318 }
319 }
320 }
321 } // namespace tensorflow
322