1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/python/client/session_ref.h"
16
17 #include <stdlib.h>
18 #include <memory>
19 #include <utility>
20
21 #include "tensorflow/core/lib/io/path.h"
22 #include "tensorflow/core/lib/io/record_writer.h"
23 #include "tensorflow/core/lib/strings/stringprintf.h"
24 #include "tensorflow/core/protobuf/master.pb.h"
25 #include "tensorflow/core/protobuf/named_tensor.pb.h"
26 #include "tensorflow/core/protobuf/replay_log.pb.h"
27
28 namespace tensorflow {
29
30 namespace {
31
32 // Scope helper to track active calls and manage session lifetime.
33 // SessionRef blocks closing until all active calls complete or are cancelled.
34 struct RunCounter {
35 std::shared_ptr<Session> session;
36 uint64* value;
37 mutex* m;
38 condition_variable* cv;
39
RunCountertensorflow::__anone594e64f0111::RunCounter40 explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m,
41 condition_variable* cv)
42 : session(std::move(s)), value(v), m(m), cv(cv) {
43 mutex_lock l(*m);
44 ++*value;
45 }
46
~RunCountertensorflow::__anone594e64f0111::RunCounter47 ~RunCounter() {
48 mutex_lock l(*m);
49 if (--*value == 0) {
50 cv->notify_all();
51 }
52 }
53 };
54
SessionToHandle(Session * session)55 std::string SessionToHandle(Session* session) {
56 return strings::Printf("%llu", static_cast<unsigned long long>(
57 reinterpret_cast<uintptr_t>(session)));
58 }
59
60 // The Session interface has many methods of the form:
61 //
62 // X(a, b);
63 // X(RunOptions, a, b);
64 //
65 // Not all sessions support the second case (with an empty RunOptions()).
66 // We use this variable as a sentinel to dispatch to the correct call.
kEmptyRunOptions()67 RunOptions* kEmptyRunOptions() {
68 static RunOptions* options = new RunOptions();
69 return options;
70 }
71
72 } // namespace
73
// Run the given session operation, recording start and end timestamps.
// If the operation returns a bad status, return after flushing the current
// log request. This should be run _after_ all request information has been
// added to the current op.
//
// Expands in a scope where `op` (a ReplayOp) and `session` are visible, and
// declares a local `Status status` that falls through to the caller on
// success.
#define RUN_WITH_TIMESTAMP(OpName, ...)              \
  op.set_start_time_us(Env::Default()->NowMicros()); \
  Status status = session->OpName(__VA_ARGS__);      \
  op.set_end_time_us(Env::Default()->NowMicros());   \
  if (!status.ok()) {                                \
    Flush(op).IgnoreError();                         \
    return status;                                   \
  }
86
87 // Records requests (and optionally responses) performed against a session.
88 // The resulting replay log can be used with the `tf_replay` tool to replicate
89 // the operations against a simulated environment, without requiring the
90 // original code or cluster setup.
91 //
// Session logging is enabled by setting the TF_REPLAY_LOG_FILE environment
// variable.
93 class SessionLogger {
94 public:
SessionLogger()95 SessionLogger() {
96 std::string log_name = getenv("TF_REPLAY_LOG_FILE");
97 LOG(INFO) << "Constructing new session logger for " << log_name;
98 TF_CHECK_OK(
99 Env::Default()->RecursivelyCreateDir(string(io::Dirname(log_name))));
100 Env::Default()->DeleteFile(log_name).IgnoreError();
101
102 TF_CHECK_OK(Env::Default()->NewWritableFile(log_name, &log_file_));
103 log_writer_ = absl::make_unique<io::RecordWriter>(log_file_.get());
104 }
105
~SessionLogger()106 ~SessionLogger() {
107 log_writer_->Close().IgnoreError();
108 log_writer_.release();
109 log_file_->Close().IgnoreError();
110 }
111
RecordNewSession(Session * session)112 Status RecordNewSession(Session* session) {
113 ReplayOp op;
114 NewReplaySession* req = op.mutable_new_replay_session();
115 req->set_session_handle(SessionToHandle(session));
116 return Flush(op);
117 }
118
RecordRun(Session * session,const std::vector<std::pair<string,Tensor>> & inputs,const std::vector<string> & output_tensor_names,const std::vector<string> & target_node_names,std::vector<Tensor> * outputs)119 Status RecordRun(Session* session,
120 const std::vector<std::pair<string, Tensor> >& inputs,
121 const std::vector<string>& output_tensor_names,
122 const std::vector<string>& target_node_names,
123 std::vector<Tensor>* outputs) {
124 return RecordRun(session, *kEmptyRunOptions(), inputs, output_tensor_names,
125 target_node_names, outputs, nullptr);
126 }
127
RecordRun(Session * session,const RunOptions & run_options,const std::vector<std::pair<string,Tensor>> & inputs,const std::vector<string> & output_tensor_names,const std::vector<string> & target_node_names,std::vector<Tensor> * outputs,RunMetadata * run_metadata)128 Status RecordRun(Session* session, const RunOptions& run_options,
129 const std::vector<std::pair<string, Tensor> >& inputs,
130 const std::vector<string>& output_tensor_names,
131 const std::vector<string>& target_node_names,
132 std::vector<Tensor>* outputs, RunMetadata* run_metadata) {
133 ReplayOp op;
134 RunStepRequest* req = op.mutable_run_step();
135 RunStepResponse* resp = op.mutable_run_step_response();
136
137 req->set_session_handle(SessionToHandle(session));
138 *req->mutable_options() = run_options;
139
140 for (const auto& it : inputs) {
141 NamedTensorProto* feed = req->add_feed();
142 feed->set_name(it.first);
143 it.second.AsProtoField(feed->mutable_tensor());
144 }
145
146 // Build an index from fetch tensor name to first index in
147 // output_tensor_names.
148 std::unordered_map<string, int> output_name_to_offset;
149 for (int i = 0, end = output_tensor_names.size(); i < end; ++i) {
150 const string& name = output_tensor_names[i];
151 if (output_name_to_offset.insert(std::make_pair(name, i)).second) {
152 req->add_fetch(name);
153 }
154 }
155 for (const string& target : target_node_names) {
156 req->add_target(target);
157 }
158
159 if (&run_options == kEmptyRunOptions()) {
160 RUN_WITH_TIMESTAMP(Run, inputs, output_tensor_names, target_node_names,
161 outputs);
162 } else {
163 RUN_WITH_TIMESTAMP(Run, run_options, inputs, output_tensor_names,
164 target_node_names, outputs, run_metadata);
165 }
166
167 for (size_t i = 0; i < outputs->size(); ++i) {
168 const Tensor& tensor = (*outputs)[i];
169 NamedTensorProto* tproto = resp->add_tensor();
170 tensor.AsProtoField(tproto->mutable_tensor());
171 tproto->set_name(output_tensor_names[i]);
172 }
173
174 if (run_metadata) {
175 *resp->mutable_metadata() = *run_metadata;
176 }
177
178 return Flush(op);
179 }
180
RecordCreate(Session * session,const GraphDef & graph)181 Status RecordCreate(Session* session, const GraphDef& graph) {
182 return RecordCreate(session, *kEmptyRunOptions(), graph);
183 }
184
185 // N.B. RunOptions is not stored (it has no entry in CreateRequest)
RecordCreate(Session * session,const RunOptions & run_options,const GraphDef & graph)186 Status RecordCreate(Session* session, const RunOptions& run_options,
187 const GraphDef& graph) {
188 ReplayOp op;
189 CreateSessionRequest* req = op.mutable_create_session();
190 *req->mutable_graph_def() = graph;
191
192 CreateSessionResponse* resp = op.mutable_create_session_response();
193 if (&run_options == kEmptyRunOptions()) {
194 RUN_WITH_TIMESTAMP(Create, graph);
195 } else {
196 RUN_WITH_TIMESTAMP(Create, run_options, graph);
197 }
198 resp->set_session_handle(SessionToHandle(session));
199 return Flush(op);
200 }
201
RecordExtend(Session * session,const GraphDef & graph)202 Status RecordExtend(Session* session, const GraphDef& graph) {
203 return RecordExtend(session, *kEmptyRunOptions(), graph);
204 }
205
206 // N.B. RunOptions is not stored (it has no entry in ExtendRequest)
RecordExtend(Session * session,const RunOptions & run_options,const GraphDef & graph)207 Status RecordExtend(Session* session, const RunOptions& run_options,
208 const GraphDef& graph) {
209 ReplayOp op;
210 ExtendSessionRequest* req = op.mutable_extend_session();
211 op.mutable_extend_session_response();
212 req->set_session_handle(SessionToHandle(session));
213 *req->mutable_graph_def() = graph;
214 if (&run_options == kEmptyRunOptions()) {
215 RUN_WITH_TIMESTAMP(Extend, graph);
216 } else {
217 RUN_WITH_TIMESTAMP(Extend, run_options, graph);
218 }
219
220 return Flush(op);
221 }
222
RecordClose(Session * session)223 Status RecordClose(Session* session) {
224 return RecordClose(session, *kEmptyRunOptions());
225 }
226
227 // N.B. RunOptions is not stored (it has no entry in CloseRequest)
RecordClose(Session * session,const RunOptions & run_options)228 Status RecordClose(Session* session, const RunOptions& run_options) {
229 ReplayOp op;
230 CloseSessionRequest* req = op.mutable_close_session();
231 req->set_session_handle(SessionToHandle(session));
232 op.mutable_close_session_response();
233 if (&run_options == kEmptyRunOptions()) {
234 RUN_WITH_TIMESTAMP(Close);
235 } else {
236 RUN_WITH_TIMESTAMP(Close, run_options);
237 }
238 return Flush(op);
239 }
240
RecordListDevices(Session * session,std::vector<DeviceAttributes> * response)241 Status RecordListDevices(Session* session,
242 std::vector<DeviceAttributes>* response) {
243 ReplayOp op;
244 ListDevicesRequest* req = op.mutable_list_devices();
245 ListDevicesResponse* resp = op.mutable_list_devices_response();
246 req->set_session_handle(SessionToHandle(session));
247 RUN_WITH_TIMESTAMP(ListDevices, response);
248
249 // TODO(power) -- local vs remote device distinction is lost here!
250 *resp->mutable_local_device() = {response->begin(), response->end()};
251 return Flush(op);
252 }
253
RecordPRunSetup(Session * session,const std::vector<string> & input_names,const std::vector<string> & output_names,const std::vector<string> & target_nodes,string * handle)254 Status RecordPRunSetup(Session* session,
255 const std::vector<string>& input_names,
256 const std::vector<string>& output_names,
257 const std::vector<string>& target_nodes,
258 string* handle) {
259 ReplayOp op;
260 PartialRunSetupRequest* req = op.mutable_partial_run_setup();
261 req->set_session_handle(SessionToHandle(session));
262 for (auto& input : input_names) {
263 req->add_feed(input);
264 }
265 for (auto& output : output_names) {
266 req->add_fetch(output);
267 }
268 for (auto& target : target_nodes) {
269 req->add_target(target);
270 }
271 RUN_WITH_TIMESTAMP(PRunSetup, input_names, output_names, target_nodes,
272 handle);
273 op.mutable_partial_run_setup_response()->set_partial_run_handle(*handle);
274 return Flush(op);
275 }
276
RecordPRun(Session * session,const string & handle,const std::vector<std::pair<string,Tensor>> & inputs,const std::vector<string> & output_names,std::vector<Tensor> * outputs)277 Status RecordPRun(Session* session, const string& handle,
278 const std::vector<std::pair<string, Tensor> >& inputs,
279 const std::vector<string>& output_names,
280 std::vector<Tensor>* outputs) {
281 ReplayOp op;
282 RunStepRequest* req = op.mutable_run_step();
283 RunStepResponse* resp = op.mutable_run_step_response();
284 req->set_session_handle(SessionToHandle(session));
285
286 // Mark this step as a partial run for replay.
287 req->set_partial_run_handle(handle);
288 for (auto& input : inputs) {
289 auto* feed = req->add_feed();
290 feed->set_name(input.first);
291 input.second.AsProtoField(feed->mutable_tensor());
292 }
293
294 for (auto& output : output_names) {
295 req->add_fetch(output);
296 }
297
298 RUN_WITH_TIMESTAMP(PRun, handle, inputs, output_names, outputs);
299
300 for (size_t i = 0; i < outputs->size(); ++i) {
301 const Tensor& tensor = (*outputs)[i];
302 NamedTensorProto* tproto = resp->add_tensor();
303 tensor.AsProtoField(tproto->mutable_tensor());
304 tproto->set_name(output_names[i]);
305 }
306
307 return Flush(op);
308 }
309
RecordMakeCallable(Session * session,const CallableOptions & callable_options,Session::CallableHandle * handle)310 Status RecordMakeCallable(Session* session,
311 const CallableOptions& callable_options,
312 Session::CallableHandle* handle) {
313 ReplayOp op;
314 MakeCallableRequest* req = op.mutable_make_callable();
315 req->set_session_handle(SessionToHandle(session));
316 *req->mutable_options() = callable_options;
317
318 RUN_WITH_TIMESTAMP(MakeCallable, callable_options, handle);
319
320 MakeCallableResponse* resp = op.mutable_make_callable_response();
321 resp->set_handle(*handle);
322
323 return Flush(op);
324 }
325
RecordRunCallable(Session * session,Session::CallableHandle handle,const std::vector<Tensor> & feed_tensors,std::vector<Tensor> * fetch_tensors,RunMetadata * run_metadata)326 Status RecordRunCallable(Session* session, Session::CallableHandle handle,
327 const std::vector<Tensor>& feed_tensors,
328 std::vector<Tensor>* fetch_tensors,
329 RunMetadata* run_metadata) {
330 ReplayOp op;
331 RunCallableRequest* req = op.mutable_run_callable();
332 req->set_session_handle(SessionToHandle(session));
333 req->set_handle(handle);
334 for (auto& tensor : feed_tensors) {
335 tensor.AsProtoField(req->add_feed());
336 }
337 RUN_WITH_TIMESTAMP(RunCallable, handle, feed_tensors, fetch_tensors,
338 run_metadata);
339
340 RunCallableResponse* resp = op.mutable_run_callable_response();
341 if (run_metadata) {
342 *resp->mutable_metadata() = *run_metadata;
343 }
344 for (const Tensor& tensor : *fetch_tensors) {
345 tensor.AsProtoTensorContent(resp->add_fetch());
346 }
347 return Flush(op);
348 }
349
RecordReleaseCallable(Session * session,Session::CallableHandle handle)350 Status RecordReleaseCallable(Session* session,
351 Session::CallableHandle handle) {
352 ReplayOp op;
353 ReleaseCallableRequest* req = op.mutable_release_callable();
354 req->set_session_handle(SessionToHandle(session));
355 req->set_handle(handle);
356 RUN_WITH_TIMESTAMP(ReleaseCallable, handle);
357 return Flush(op);
358 }
359
360 private:
Flush(const ReplayOp & op)361 Status Flush(const ReplayOp& op) {
362 mutex_lock l(log_mutex_);
363
364 string buf;
365 op.SerializeToString(&buf);
366 TF_RETURN_IF_ERROR(log_writer_->WriteRecord(buf));
367
368 // TODO(b/116624106): Not all file-systems respect calls to `Sync()`
369 return log_file_->Sync();
370 }
371
372 std::unique_ptr<WritableFile> log_file_;
373 std::unique_ptr<io::RecordWriter> log_writer_;
374 mutex log_mutex_;
375 };
376
global_session_logger()377 static SessionLogger* global_session_logger() {
378 static SessionLogger* logger = new SessionLogger();
379 return logger;
380 }
381
SessionRef(Session * session)382 SessionRef::SessionRef(Session* session) : session_(session) {
383 if (getenv("TF_REPLAY_LOG_FILE") != nullptr) {
384 logger_ = global_session_logger();
385 logger_->RecordNewSession(this->session_.get()).IgnoreError();
386 } else {
387 logger_ = nullptr;
388 }
389 }
390
391 SessionRef::~SessionRef() = default;
392
CheckNotClosed()393 Status SessionRef::CheckNotClosed() {
394 mutex_lock l(run_lock_);
395 if (session_ == nullptr) return errors::Cancelled("Session has been closed.");
396 return ::tensorflow::Status::OK();
397 }
398
// If logging is active, log the start and end time of the operation along with
// the request and response.
//
// Expands in a SessionRef method: rejects calls on a closed session, then
// registers the call in a RunCounter so Close() blocks until it completes,
// and finally dispatches either directly to the session or through the
// SessionLogger's matching Record##OpName wrapper.
#define LOG_AND_RUN_OPERATION(OpName, ...)                          \
  TF_RETURN_IF_ERROR(CheckNotClosed());                             \
  RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); \
  if (!logger_) {                                                   \
    return rc.session->OpName(__VA_ARGS__);                         \
  }                                                                 \
  return logger_->Record##OpName(rc.session.get(), __VA_ARGS__);
408
// Runs a step with explicit RunOptions, recording it when replay logging is
// active.
Status SessionRef::Run(const RunOptions& run_options,
                       const std::vector<std::pair<string, Tensor> >& inputs,
                       const std::vector<string>& output_tensor_names,
                       const std::vector<string>& target_node_names,
                       std::vector<Tensor>* outputs,
                       RunMetadata* run_metadata) {
  LOG_AND_RUN_OPERATION(Run, run_options, inputs, output_tensor_names,
                        target_node_names, outputs, run_metadata);
}
418
// Runs a step without RunOptions, recording it when replay logging is active.
Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs,
                       const std::vector<string>& output_tensor_names,
                       const std::vector<string>& target_node_names,
                       std::vector<Tensor>* outputs) {
  LOG_AND_RUN_OPERATION(Run, inputs, output_tensor_names, target_node_names,
                        outputs);
}
426
// Creates the session's graph, recording the call when replay is active.
Status SessionRef::Create(const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Create, graph);
}
430
// Creates the session's graph with RunOptions, recording when replay is
// active.
Status SessionRef::Create(const RunOptions& run_options,
                          const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Create, run_options, graph);
}
435
// Extends the session's graph with RunOptions, recording when replay is
// active.
Status SessionRef::Extend(const RunOptions& run_options,
                          const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Extend, run_options, graph);
}
440
// Extends the session's graph, recording the call when replay is active.
Status SessionRef::Extend(const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Extend, graph);
}
444
// Lists the session's devices, recording the call when replay is active.
Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) {
  LOG_AND_RUN_OPERATION(ListDevices, response);
}
448
// Sets up a partial run, recording the call when replay is active.
Status SessionRef::PRunSetup(const std::vector<string>& input_names,
                             const std::vector<string>& output_names,
                             const std::vector<string>& target_nodes,
                             string* handle) {
  LOG_AND_RUN_OPERATION(PRunSetup, input_names, output_names, target_nodes,
                        handle);
}
456
// Executes one partial-run step, recording the call when replay is active.
Status SessionRef::PRun(const string& handle,
                        const std::vector<std::pair<string, Tensor> >& inputs,
                        const std::vector<string>& output_names,
                        std::vector<Tensor>* outputs) {
  LOG_AND_RUN_OPERATION(PRun, handle, inputs, output_names, outputs);
}
463
// Creates a callable, recording the call when replay is active.
Status SessionRef::MakeCallable(const CallableOptions& callable_options,
                                CallableHandle* out_handle) {
  LOG_AND_RUN_OPERATION(MakeCallable, callable_options, out_handle);
}
468
// Invokes a previously-made callable, recording when replay is active.
Status SessionRef::RunCallable(CallableHandle handle,
                               const std::vector<Tensor>& feed_tensors,
                               std::vector<Tensor>* fetch_tensors,
                               RunMetadata* run_metadata) {
  LOG_AND_RUN_OPERATION(RunCallable, handle, feed_tensors, fetch_tensors,
                        run_metadata);
}
476
// Releases a callable handle.  Unlike the other operations, releasing on a
// closed session is treated as a successful no-op rather than a Cancelled
// error, since closing the session already released its callables.
Status SessionRef::ReleaseCallable(CallableHandle handle) {
  {
    mutex_lock l(run_lock_);
    if (session_ == nullptr) {
      // Session already closed. Do nothing.
      return Status::OK();
    }
  }
  LOG_AND_RUN_OPERATION(ReleaseCallable, handle);
}
487
Close(const RunOptions & run_options)488 Status SessionRef::Close(const RunOptions& run_options) {
489 TF_RETURN_IF_ERROR(CheckNotClosed());
490 mutex_lock l(run_lock_);
491 Status status;
492 if (logger_) {
493 status = logger_->RecordClose(session_.get(), run_options);
494 } else {
495 status = session_->Close(run_options);
496 }
497 session_.reset();
498 while (run_count_ > 0) {
499 run_finished_.wait(l);
500 }
501 return status;
502 }
503
Close()504 Status SessionRef::Close() {
505 TF_RETURN_IF_ERROR(CheckNotClosed());
506 mutex_lock l(run_lock_);
507 Status status;
508 if (logger_) {
509 status = logger_->RecordClose(session_.get());
510 } else {
511 status = session_->Close();
512 }
513 session_.reset();
514 while (run_count_ > 0) {
515 run_finished_.wait(l);
516 }
517 return status;
518 }
519
520 } // namespace tensorflow
521