/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/client/session_ref.h"

#include <stdlib.h>
#include <memory>
#include <utility>

#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/named_tensor.pb.h"
#include "tensorflow/core/protobuf/replay_log.pb.h"

namespace tensorflow {

namespace {

// Scope helper to track active calls and manage session lifetime.
// SessionRef blocks closing until all active calls complete or are cancelled.
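// Each active call constructs a RunCounter on the stack: the constructor
// increments *value under *m, and the destructor decrements it and notifies
// *cv when the count drops to zero, which is what SessionRef::Close() waits
// on before returning.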
struct RunCounter {
  std::shared_ptr<Session> session;
  uint64* value;
  mutex* m;
  condition_variable* cv;

  explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m,
                      condition_variable* cv)
      : session(std::move(s)), value(v), m(m), cv(cv) {
    mutex_lock l(*m);
    ++*value;
  }

  ~RunCounter() {
    mutex_lock l(*m);
    if (--*value == 0) {
      cv->notify_all();
    }
  }
};

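// A session "handle" in the replay log is simply the Session pointer value
// formatted as a decimal string, so it is only meaningful within the process
// that produced the log.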
std::string SessionToHandle(Session* session) {
  return strings::Printf("%llu", static_cast<unsigned long long>(
                                     reinterpret_cast<uintptr_t>(session)));
}

// The Session interface has many methods of the form:
//
// X(a, b);
// X(RunOptions, a, b);
//
// Not all sessions support the second case (with an empty RunOptions()).
// We use this variable as a sentinel to dispatch to the correct call.
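// Callers detect the no-options overloads by comparing addresses, i.e.
// `&run_options == kEmptyRunOptions()`.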
RunOptions* kEmptyRunOptions() {
  static RunOptions* options = new RunOptions();
  return options;
}

}  // namespace

// Run the given session operation, recording start and end timestamps.
// If the operation returns a bad status, return after flushing the current
// log request.  This should be run _after_ all request information has been
// added to the current op.
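// N.B. the macro declares a local `Status status` and expects a `ReplayOp`
// named `op` to be in scope, so it can be expanded at most once per block.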
#define RUN_WITH_TIMESTAMP(OpName, ...)              \
  op.set_start_time_us(Env::Default()->NowMicros()); \
  Status status = session->OpName(__VA_ARGS__);      \
  op.set_end_time_us(Env::Default()->NowMicros());   \
  if (!status.ok()) {                                \
    Flush(op).IgnoreError();                         \
    return status;                                   \
  }

// Records requests (and optionally responses) performed against a session.
// The resulting replay log can be used with the `tf_replay` tool to replicate
// the operations against a simulated environment, without requiring the
// original code or cluster setup.
//
// Session logging is enabled by setting the TF_REPLAY_LOG_FILE environment
// variable.
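//
// For example (an illustrative invocation; the script name is hypothetical):
//
//   TF_REPLAY_LOG_FILE=/tmp/replay.log python train.py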
class SessionLogger {
 public:
  SessionLogger() {
    std::string log_name = getenv("TF_REPLAY_LOG_FILE");
    LOG(INFO) << "Constructing new session logger for " << log_name;
    TF_CHECK_OK(
        Env::Default()->RecursivelyCreateDir(string(io::Dirname(log_name))));
    Env::Default()->DeleteFile(log_name).IgnoreError();

    TF_CHECK_OK(Env::Default()->NewWritableFile(log_name, &log_file_));
    log_writer_ = absl::make_unique<io::RecordWriter>(log_file_.get());
  }

  ~SessionLogger() {
    log_writer_->Close().IgnoreError();
    log_writer_.release();
    log_file_->Close().IgnoreError();
  }

  Status RecordNewSession(Session* session) {
    ReplayOp op;
    NewReplaySession* req = op.mutable_new_replay_session();
    req->set_session_handle(SessionToHandle(session));
    return Flush(op);
  }

  Status RecordRun(Session* session,
                   const std::vector<std::pair<string, Tensor> >& inputs,
                   const std::vector<string>& output_tensor_names,
                   const std::vector<string>& target_node_names,
                   std::vector<Tensor>* outputs) {
    return RecordRun(session, *kEmptyRunOptions(), inputs, output_tensor_names,
                     target_node_names, outputs, nullptr);
  }

  Status RecordRun(Session* session, const RunOptions& run_options,
                   const std::vector<std::pair<string, Tensor> >& inputs,
                   const std::vector<string>& output_tensor_names,
                   const std::vector<string>& target_node_names,
                   std::vector<Tensor>* outputs, RunMetadata* run_metadata) {
    ReplayOp op;
    RunStepRequest* req = op.mutable_run_step();
    RunStepResponse* resp = op.mutable_run_step_response();

    req->set_session_handle(SessionToHandle(session));
    *req->mutable_options() = run_options;

    for (const auto& it : inputs) {
      NamedTensorProto* feed = req->add_feed();
      feed->set_name(it.first);
      it.second.AsProtoField(feed->mutable_tensor());
    }

    // Build an index from fetch tensor name to first index in
    // output_tensor_names.
    std::unordered_map<string, int> output_name_to_offset;
    for (int i = 0, end = output_tensor_names.size(); i < end; ++i) {
      const string& name = output_tensor_names[i];
      if (output_name_to_offset.insert(std::make_pair(name, i)).second) {
        req->add_fetch(name);
      }
    }
    for (const string& target : target_node_names) {
      req->add_target(target);
    }

    if (&run_options == kEmptyRunOptions()) {
      RUN_WITH_TIMESTAMP(Run, inputs, output_tensor_names, target_node_names,
                         outputs);
    } else {
      RUN_WITH_TIMESTAMP(Run, run_options, inputs, output_tensor_names,
                         target_node_names, outputs, run_metadata);
    }

    for (size_t i = 0; i < outputs->size(); ++i) {
      const Tensor& tensor = (*outputs)[i];
      NamedTensorProto* tproto = resp->add_tensor();
      tensor.AsProtoField(tproto->mutable_tensor());
      tproto->set_name(output_tensor_names[i]);
    }

    if (run_metadata) {
      *resp->mutable_metadata() = *run_metadata;
    }

    return Flush(op);
  }

  Status RecordCreate(Session* session, const GraphDef& graph) {
    return RecordCreate(session, *kEmptyRunOptions(), graph);
  }

  // N.B. RunOptions is not stored (it has no entry in CreateRequest)
  Status RecordCreate(Session* session, const RunOptions& run_options,
                      const GraphDef& graph) {
    ReplayOp op;
    CreateSessionRequest* req = op.mutable_create_session();
    *req->mutable_graph_def() = graph;

    CreateSessionResponse* resp = op.mutable_create_session_response();
    if (&run_options == kEmptyRunOptions()) {
      RUN_WITH_TIMESTAMP(Create, graph);
    } else {
      RUN_WITH_TIMESTAMP(Create, run_options, graph);
    }
    resp->set_session_handle(SessionToHandle(session));
    return Flush(op);
  }

  Status RecordExtend(Session* session, const GraphDef& graph) {
    return RecordExtend(session, *kEmptyRunOptions(), graph);
  }

  // N.B. RunOptions is not stored (it has no entry in ExtendRequest)
  Status RecordExtend(Session* session, const RunOptions& run_options,
                      const GraphDef& graph) {
    ReplayOp op;
    ExtendSessionRequest* req = op.mutable_extend_session();
    op.mutable_extend_session_response();
    req->set_session_handle(SessionToHandle(session));
    *req->mutable_graph_def() = graph;
    if (&run_options == kEmptyRunOptions()) {
      RUN_WITH_TIMESTAMP(Extend, graph);
    } else {
      RUN_WITH_TIMESTAMP(Extend, run_options, graph);
    }

    return Flush(op);
  }

  Status RecordClose(Session* session) {
    return RecordClose(session, *kEmptyRunOptions());
  }

  // N.B. RunOptions is not stored (it has no entry in CloseRequest)
  Status RecordClose(Session* session, const RunOptions& run_options) {
    ReplayOp op;
    CloseSessionRequest* req = op.mutable_close_session();
    req->set_session_handle(SessionToHandle(session));
    op.mutable_close_session_response();
    if (&run_options == kEmptyRunOptions()) {
      RUN_WITH_TIMESTAMP(Close);
    } else {
      RUN_WITH_TIMESTAMP(Close, run_options);
    }
    return Flush(op);
  }

  Status RecordListDevices(Session* session,
                           std::vector<DeviceAttributes>* response) {
    ReplayOp op;
    ListDevicesRequest* req = op.mutable_list_devices();
    ListDevicesResponse* resp = op.mutable_list_devices_response();
    req->set_session_handle(SessionToHandle(session));
    RUN_WITH_TIMESTAMP(ListDevices, response);

    // TODO(power) -- local vs remote device distinction is lost here!
    *resp->mutable_local_device() = {response->begin(), response->end()};
    return Flush(op);
  }

  Status RecordPRunSetup(Session* session,
                         const std::vector<string>& input_names,
                         const std::vector<string>& output_names,
                         const std::vector<string>& target_nodes,
                         string* handle) {
    ReplayOp op;
    PartialRunSetupRequest* req = op.mutable_partial_run_setup();
    req->set_session_handle(SessionToHandle(session));
    for (auto& input : input_names) {
      req->add_feed(input);
    }
    for (auto& output : output_names) {
      req->add_fetch(output);
    }
    for (auto& target : target_nodes) {
      req->add_target(target);
    }
    RUN_WITH_TIMESTAMP(PRunSetup, input_names, output_names, target_nodes,
                       handle);
    op.mutable_partial_run_setup_response()->set_partial_run_handle(*handle);
    return Flush(op);
  }

  Status RecordPRun(Session* session, const string& handle,
                    const std::vector<std::pair<string, Tensor> >& inputs,
                    const std::vector<string>& output_names,
                    std::vector<Tensor>* outputs) {
    ReplayOp op;
    RunStepRequest* req = op.mutable_run_step();
    RunStepResponse* resp = op.mutable_run_step_response();
    req->set_session_handle(SessionToHandle(session));

    // Mark this step as a partial run for replay.
    req->set_partial_run_handle(handle);
    for (auto& input : inputs) {
      auto* feed = req->add_feed();
      feed->set_name(input.first);
      input.second.AsProtoField(feed->mutable_tensor());
    }

    for (auto& output : output_names) {
      req->add_fetch(output);
    }

    RUN_WITH_TIMESTAMP(PRun, handle, inputs, output_names, outputs);

    for (size_t i = 0; i < outputs->size(); ++i) {
      const Tensor& tensor = (*outputs)[i];
      NamedTensorProto* tproto = resp->add_tensor();
      tensor.AsProtoField(tproto->mutable_tensor());
      tproto->set_name(output_names[i]);
    }

    return Flush(op);
  }

  Status RecordMakeCallable(Session* session,
                            const CallableOptions& callable_options,
                            Session::CallableHandle* handle) {
    ReplayOp op;
    MakeCallableRequest* req = op.mutable_make_callable();
    req->set_session_handle(SessionToHandle(session));
    *req->mutable_options() = callable_options;

    RUN_WITH_TIMESTAMP(MakeCallable, callable_options, handle);

    MakeCallableResponse* resp = op.mutable_make_callable_response();
    resp->set_handle(*handle);

    return Flush(op);
  }

  Status RecordRunCallable(Session* session, Session::CallableHandle handle,
                           const std::vector<Tensor>& feed_tensors,
                           std::vector<Tensor>* fetch_tensors,
                           RunMetadata* run_metadata) {
    ReplayOp op;
    RunCallableRequest* req = op.mutable_run_callable();
    req->set_session_handle(SessionToHandle(session));
    req->set_handle(handle);
    for (auto& tensor : feed_tensors) {
      tensor.AsProtoField(req->add_feed());
    }
    RUN_WITH_TIMESTAMP(RunCallable, handle, feed_tensors, fetch_tensors,
                       run_metadata);

    RunCallableResponse* resp = op.mutable_run_callable_response();
    if (run_metadata) {
      *resp->mutable_metadata() = *run_metadata;
    }
    for (const Tensor& tensor : *fetch_tensors) {
      tensor.AsProtoTensorContent(resp->add_fetch());
    }
    return Flush(op);
  }

  Status RecordReleaseCallable(Session* session,
                               Session::CallableHandle handle) {
    ReplayOp op;
    ReleaseCallableRequest* req = op.mutable_release_callable();
    req->set_session_handle(SessionToHandle(session));
    req->set_handle(handle);
    RUN_WITH_TIMESTAMP(ReleaseCallable, handle);
    return Flush(op);
  }

 private:
  Status Flush(const ReplayOp& op) {
    mutex_lock l(log_mutex_);

    string buf;
    op.SerializeToString(&buf);
    TF_RETURN_IF_ERROR(log_writer_->WriteRecord(buf));

    // TODO(b/116624106): Not all file-systems respect calls to `Sync()`
    return log_file_->Sync();
  }

  std::unique_ptr<WritableFile> log_file_;
  std::unique_ptr<io::RecordWriter> log_writer_;
  mutex log_mutex_;
};

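// Process-wide logger shared by every logging-enabled SessionRef; constructed
// lazily on first use and never deleted.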
static SessionLogger* global_session_logger() {
  static SessionLogger* logger = new SessionLogger();
  return logger;
}

SessionRef::SessionRef(Session* session) : session_(session) {
  if (getenv("TF_REPLAY_LOG_FILE") != nullptr) {
    logger_ = global_session_logger();
    logger_->RecordNewSession(this->session_.get()).IgnoreError();
  } else {
    logger_ = nullptr;
  }
}

SessionRef::~SessionRef() = default;

Status SessionRef::CheckNotClosed() {
  mutex_lock l(run_lock_);
  if (session_ == nullptr) return errors::Cancelled("Session has been closed.");
  return ::tensorflow::Status::OK();
}

// If logging is active, log the start and end time of the operation along with
// the request and response.
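// N.B. the macro expands to a `return` statement, so it must be the final
// statement of each wrapper below.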
#define LOG_AND_RUN_OPERATION(OpName, ...)                          \
  TF_RETURN_IF_ERROR(CheckNotClosed());                             \
  RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); \
  if (!logger_) {                                                   \
    return rc.session->OpName(__VA_ARGS__);                         \
  }                                                                 \
  return logger_->Record##OpName(rc.session.get(), __VA_ARGS__);

Status SessionRef::Run(const RunOptions& run_options,
                       const std::vector<std::pair<string, Tensor> >& inputs,
                       const std::vector<string>& output_tensor_names,
                       const std::vector<string>& target_node_names,
                       std::vector<Tensor>* outputs,
                       RunMetadata* run_metadata) {
  LOG_AND_RUN_OPERATION(Run, run_options, inputs, output_tensor_names,
                        target_node_names, outputs, run_metadata);
}

Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs,
                       const std::vector<string>& output_tensor_names,
                       const std::vector<string>& target_node_names,
                       std::vector<Tensor>* outputs) {
  LOG_AND_RUN_OPERATION(Run, inputs, output_tensor_names, target_node_names,
                        outputs);
}

Status SessionRef::Create(const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Create, graph);
}

Status SessionRef::Create(const RunOptions& run_options,
                          const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Create, run_options, graph);
}

Status SessionRef::Extend(const RunOptions& run_options,
                          const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Extend, run_options, graph);
}

Status SessionRef::Extend(const GraphDef& graph) {
  LOG_AND_RUN_OPERATION(Extend, graph);
}

Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) {
  LOG_AND_RUN_OPERATION(ListDevices, response);
}

Status SessionRef::PRunSetup(const std::vector<string>& input_names,
                             const std::vector<string>& output_names,
                             const std::vector<string>& target_nodes,
                             string* handle) {
  LOG_AND_RUN_OPERATION(PRunSetup, input_names, output_names, target_nodes,
                        handle);
}

Status SessionRef::PRun(const string& handle,
                        const std::vector<std::pair<string, Tensor> >& inputs,
                        const std::vector<string>& output_names,
                        std::vector<Tensor>* outputs) {
  LOG_AND_RUN_OPERATION(PRun, handle, inputs, output_names, outputs);
}

Status SessionRef::MakeCallable(const CallableOptions& callable_options,
                                CallableHandle* out_handle) {
  LOG_AND_RUN_OPERATION(MakeCallable, callable_options, out_handle);
}

Status SessionRef::RunCallable(CallableHandle handle,
                               const std::vector<Tensor>& feed_tensors,
                               std::vector<Tensor>* fetch_tensors,
                               RunMetadata* run_metadata) {
  LOG_AND_RUN_OPERATION(RunCallable, handle, feed_tensors, fetch_tensors,
                        run_metadata);
}

Status SessionRef::ReleaseCallable(CallableHandle handle) {
  {
    mutex_lock l(run_lock_);
    if (session_ == nullptr) {
      // Session already closed. Do nothing.
      return Status::OK();
    }
  }
  LOG_AND_RUN_OPERATION(ReleaseCallable, handle);
}

Status SessionRef::Close(const RunOptions& run_options) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  mutex_lock l(run_lock_);
  Status status;
  if (logger_) {
    status = logger_->RecordClose(session_.get(), run_options);
  } else {
    status = session_->Close(run_options);
  }
  session_.reset();
  while (run_count_ > 0) {
    run_finished_.wait(l);
  }
  return status;
}

Status SessionRef::Close() {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  mutex_lock l(run_lock_);
  Status status;
  if (logger_) {
    status = logger_->RecordClose(session_.get());
  } else {
    status = session_->Close();
  }
  session_.reset();
  while (run_count_ > 0) {
    run_finished_.wait(l);
  }
  return status;
}

}  // namespace tensorflow