1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/direct_session.h"
17 
18 #include <atomic>
19 #include <string>
20 #include <vector>
21 
22 #include "tensorflow/core/common_runtime/collective_executor_mgr.h"
23 #include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
24 #include "tensorflow/core/common_runtime/constant_folding.h"
25 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
26 #include "tensorflow/core/common_runtime/device_factory.h"
27 #include "tensorflow/core/common_runtime/device_resolver_local.h"
28 #include "tensorflow/core/common_runtime/executor.h"
29 #include "tensorflow/core/common_runtime/executor_factory.h"
30 #include "tensorflow/core/common_runtime/function.h"
31 #include "tensorflow/core/common_runtime/graph_optimizer.h"
32 #include "tensorflow/core/common_runtime/memory_types.h"
33 #include "tensorflow/core/common_runtime/metrics.h"
34 #include "tensorflow/core/common_runtime/optimization_registry.h"
35 #include "tensorflow/core/common_runtime/process_util.h"
36 #include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
37 #include "tensorflow/core/common_runtime/step_stats_collector.h"
38 #include "tensorflow/core/framework/function.h"
39 #include "tensorflow/core/framework/graph.pb_text.h"
40 #include "tensorflow/core/framework/graph.pb.h"
41 #include "tensorflow/core/framework/graph_def_util.h"
42 #include "tensorflow/core/framework/log_memory.h"
43 #include "tensorflow/core/framework/node_def.pb.h"
44 #include "tensorflow/core/framework/run_handler.h"
45 #include "tensorflow/core/framework/tensor.h"
46 #include "tensorflow/core/framework/versions.pb.h"
47 #include "tensorflow/core/graph/algorithm.h"
48 #include "tensorflow/core/graph/graph.h"
49 #include "tensorflow/core/graph/graph_constructor.h"
50 #include "tensorflow/core/graph/graph_partition.h"
51 #include "tensorflow/core/graph/subgraph.h"
52 #include "tensorflow/core/graph/tensor_id.h"
53 #include "tensorflow/core/lib/core/errors.h"
54 #include "tensorflow/core/lib/core/notification.h"
55 #include "tensorflow/core/lib/core/refcount.h"
56 #include "tensorflow/core/lib/core/status.h"
57 #include "tensorflow/core/lib/core/threadpool.h"
58 #include "tensorflow/core/lib/gtl/array_slice.h"
59 #include "tensorflow/core/lib/gtl/stl_util.h"
60 #include "tensorflow/core/lib/monitoring/counter.h"
61 #include "tensorflow/core/lib/random/random.h"
62 #include "tensorflow/core/lib/strings/numbers.h"
63 #include "tensorflow/core/lib/strings/str_util.h"
64 #include "tensorflow/core/lib/strings/strcat.h"
65 #include "tensorflow/core/platform/byte_order.h"
66 #include "tensorflow/core/platform/device_tracer.h"
67 #include "tensorflow/core/platform/logging.h"
68 #include "tensorflow/core/platform/mutex.h"
69 #include "tensorflow/core/platform/tracing.h"
70 #include "tensorflow/core/platform/types.h"
71 #include "tensorflow/core/util/device_name_utils.h"
72 #include "tensorflow/core/util/env_var.h"
73 
74 namespace tensorflow {
75 
76 namespace {
77 
78 auto* direct_session_runs = monitoring::Counter<0>::New(
79     "/tensorflow/core/direct_session_runs",
80     "The number of times DirectSession::Run() has been called.");
81 
82 Status NewThreadPoolFromThreadPoolOptions(
83     const SessionOptions& options,
84     const ThreadPoolOptionProto& thread_pool_options, int pool_number,
85     thread::ThreadPool** pool, bool* owned) {
86   int32 num_threads = thread_pool_options.num_threads();
87   if (num_threads == 0) {
88     num_threads = NumInterOpThreadsFromSessionOptions(options);
89   }
90   const string& name = thread_pool_options.global_name();
91   if (name.empty()) {
92     // Session-local threadpool.
93     VLOG(1) << "Direct session inter op parallelism threads for pool "
94             << pool_number << ": " << num_threads;
95     *pool = new thread::ThreadPool(
96         options.env, strings::StrCat("Compute", pool_number), num_threads);
97     *owned = true;
98     return Status::OK();
99   }
100 
101   // Global, named threadpool.
102   typedef std::pair<int32, thread::ThreadPool*> MapValue;
103   static std::map<string, MapValue>* global_pool_map =
104       new std::map<string, MapValue>;
105   static mutex* mu = new mutex();
106   mutex_lock l(*mu);
107   MapValue* mvalue = &(*global_pool_map)[name];
108   if (mvalue->second == nullptr) {
109     mvalue->first = thread_pool_options.num_threads();
110     mvalue->second = new thread::ThreadPool(
111         options.env, strings::StrCat("Compute", pool_number), num_threads);
112   } else {
113     if (mvalue->first != thread_pool_options.num_threads()) {
114       return errors::InvalidArgument(
115           "Pool ", name,
116           " configured previously with num_threads=", mvalue->first,
117           "; cannot re-configure with num_threads=",
118           thread_pool_options.num_threads());
119     }
120   }
121   *owned = false;
122   *pool = mvalue->second;
123   return Status::OK();
124 }
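// Configuration sketch (editor's note, not part of the original source): the
// ThreadPoolOptionProto handled above is populated through ConfigProto, for
// example:
//
//   ConfigProto config;
//   auto* tp = config.add_session_inter_op_thread_pool();
//   tp->set_num_threads(8);                 // 0 means "derive from options"
//   tp->set_global_name("shared_pool");     // omit for a session-local pool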
125 
126 thread::ThreadPool* GlobalThreadPool(const SessionOptions& options) {
127   static thread::ThreadPool* const thread_pool =
128       NewThreadPoolFromSessionOptions(options);
129   return thread_pool;
130 }
131 
132 // TODO(vrv): Figure out how to unify the many different functions
133 // that generate RendezvousKey, since many of them have to be
134 // consistent with each other.
135 string GetRendezvousKey(const string& tensor_name,
136                         const DeviceAttributes& device_info,
137                         const FrameAndIter& frame_iter) {
138   return strings::StrCat(device_info.name(), ";",
139                          strings::FpToString(device_info.incarnation()), ";",
140                          device_info.name(), ";", tensor_name, ";",
141                          frame_iter.frame_id, ":", frame_iter.iter_id);
142 }
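// Example (editor's note): for a tensor named "x:0" on the client CPU device
// in the root frame, the key produced above looks roughly like
//   /job:localhost/replica:0/task:0/device:CPU:0;<hex incarnation>;
//   /job:localhost/replica:0/task:0/device:CPU:0;x:0;0:0
// (wrapped here for readability; the actual key is a single string).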
143 
144 }  // namespace
145 
146 class DirectSessionFactory : public SessionFactory {
147  public:
148   DirectSessionFactory() {}
149 
150   bool AcceptsOptions(const SessionOptions& options) override {
151     return options.target.empty();
152   }
153 
154   Status NewSession(const SessionOptions& options,
155                     Session** out_session) override {
156     // Must do this before the CPU allocator is created.
157     if (options.config.graph_options().build_cost_model() > 0) {
158       EnableCPUAllocatorFullStats(true);
159     }
160     std::vector<std::unique_ptr<Device>> devices;
161     TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
162         options, "/job:localhost/replica:0/task:0", &devices));
163 
164     DirectSession* session =
165         new DirectSession(options, new DeviceMgr(std::move(devices)), this);
166     {
167       mutex_lock l(sessions_lock_);
168       sessions_.push_back(session);
169     }
170     *out_session = session;
171     return Status::OK();
172   }
173 
174   Status Reset(const SessionOptions& options,
175                const std::vector<string>& containers) override {
176     std::vector<DirectSession*> sessions_to_reset;
177     {
178       mutex_lock l(sessions_lock_);
179       // We create a copy to ensure that we don't have a deadlock when
180       // session->Close calls the DirectSessionFactory.Deregister, which
181       // acquires sessions_lock_.
182       std::swap(sessions_to_reset, sessions_);
183     }
184     Status s;
185     for (auto session : sessions_to_reset) {
186       s.Update(session->Reset(containers));
187     }
188     // TODO(suharshs): Change the Reset behavior of all SessionFactories so that
189     // it doesn't close the sessions?
190     for (auto session : sessions_to_reset) {
191       s.Update(session->Close());
192     }
193     return s;
194   }
195 
196   void Deregister(const DirectSession* session) {
197     mutex_lock l(sessions_lock_);
198     sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session),
199                     sessions_.end());
200   }
201 
202  private:
203   mutex sessions_lock_;
204   std::vector<DirectSession*> sessions_ GUARDED_BY(sessions_lock_);
205 };
206 
207 class DirectSessionRegistrar {
208  public:
209   DirectSessionRegistrar() {
210     SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
211   }
212 };
213 static DirectSessionRegistrar registrar;
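// Usage sketch (editor's note, not part of the original source): a client gets
// a DirectSession through the public Session API whenever SessionOptions::target
// is empty, which is exactly what AcceptsOptions() above checks:
//
//   SessionOptions options;                        // options.target is empty
//   std::unique_ptr<Session> session(NewSession(options));
//   // `session` is backed by the factory registered as "DIRECT_SESSION".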
214 
215 std::atomic_int_fast64_t DirectSession::step_id_counter_(1);
216 
217 // NOTE: On Android with a single device, there is never
218 // a risk of an OpKernel blocking indefinitely:
219 //
220 // 1) No operations do I/O that depends on other simultaneous kernels,
221 //
222 // 2) Recv nodes always complete immediately: The inputs are sent into
223 //    the local rendezvous before we start the executor, so the
224 //    corresponding recvs will not block.
225 //
226 // Based on these assumptions, we can use the same thread pool for
227 // both "non-blocking" and "blocking" OpKernels on Android.
228 //
229 // This may change down the road when we add support for multiple
230 // devices that run concurrently, in which case we will need to
231 // revisit this decision.
232 void DirectSession::SchedClosure(thread::ThreadPool* pool,
233                                  std::function<void()> c) {
234 // TODO(sanjay): Get rid of __ANDROID__ path
235 #ifdef __ANDROID__
236   // On Android, there is no implementation of ThreadPool that takes
237   // std::function, only Closure, which we cannot easily convert.
238   //
239   // Instead, we just run the function in-line, which is currently
240   // safe given the reasoning above.
241   c();
242 #else
243   if (pool != nullptr) {
244     pool->Schedule(std::move(c));
245   } else {
246     c();
247   }
248 #endif  // __ANDROID__
249 }
250 
251 static RunHandlerPool* GetOrCreateRunHandlerPool(
252     const SessionOptions& options) {
253   static RunHandlerPool* pool =
254       new RunHandlerPool(NumInterOpThreadsFromSessionOptions(options));
255   return pool;
256 }
257 
258 bool DirectSession::ShouldUseRunHandlerPool(
259     const RunOptions& run_options) const {
260   if (options_.config.use_per_session_threads()) return false;
261   if (options_.config.session_inter_op_thread_pool_size() > 0 &&
262       run_options.inter_op_thread_pool() > 0)
263     return false;
264   // Only use RunHandlerPool when:
265   // a. Single global thread pool is used for inter-op parallelism.
266   // b. When multiple inter_op_thread_pool(s) are created, use it only while
267   // running sessions on the default inter_op_thread_pool=0. Typically,
268   // servo-team uses inter_op_thread_pool > 0 for model loading.
269   // TODO(crk): Revisit whether we'd want to create one (static) RunHandlerPool
270   // per entry in session_inter_op_thread_pool() in the future.
271   return true;
272 }
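// Illustrative example (editor's note): even when this predicate returns true,
// the RunHandlerPool is only used if the caller also opts in per step via
// RunOptions, as checked in RunInternal() below:
//
//   RunOptions run_options;
//   run_options.mutable_experimental()->set_use_run_handler_pool(true);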
273 
274 DirectSession::DirectSession(const SessionOptions& options,
275                              const DeviceMgr* device_mgr,
276                              DirectSessionFactory* const factory)
277     : options_(options),
278       device_mgr_(device_mgr),
279       factory_(factory),
280       cancellation_manager_(new CancellationManager()),
281       operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
282   const int thread_pool_size =
283       options_.config.session_inter_op_thread_pool_size();
284   if (thread_pool_size > 0) {
285     for (int i = 0; i < thread_pool_size; ++i) {
286       thread::ThreadPool* pool = nullptr;
287       bool owned = false;
288       init_error_.Update(NewThreadPoolFromThreadPoolOptions(
289           options_, options_.config.session_inter_op_thread_pool(i), i, &pool,
290           &owned));
291       thread_pools_.emplace_back(pool, owned);
292     }
293   } else if (options_.config.use_per_session_threads()) {
294     thread_pools_.emplace_back(NewThreadPoolFromSessionOptions(options_),
295                                true /* owned */);
296   } else {
297     thread_pools_.emplace_back(GlobalThreadPool(options), false /* owned */);
298   }
299   // The default value of sync_on_finish will be flipped soon and this
300   // environment variable will be removed as well.
301   const Status status =
302       ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
303   if (!status.ok()) {
304     LOG(ERROR) << status.error_message();
305   }
306   session_handle_ =
307       strings::StrCat("direct", strings::FpToString(random::New64()));
308   int devices_added = 0;
309   if (options.config.log_device_placement()) {
310     const string mapping_str = device_mgr_->DeviceMappingString();
311     if (mapping_str.empty()) {
312       printf("Device mapping: no known devices.\n");
313     } else {
314       printf("Device mapping:\n%s", mapping_str.c_str());
315     }
316     LOG(INFO) << "Device mapping:\n" << mapping_str;
317   }
318   for (auto d : device_mgr_->ListDevices()) {
319     devices_.push_back(d);
320     device_set_.AddDevice(d);
321     d->op_segment()->AddHold(session_handle_);
322 
323     // The first device added is special: it is the 'client device' (a
324     // CPU device) from which we feed and fetch Tensors.
325     if (devices_added == 0) {
326       device_set_.set_client_device(d);
327     }
328     ++devices_added;
329   }
330 }
331 
332 DirectSession::~DirectSession() {
333   if (!closed_) Close().IgnoreError();
334   for (auto& it : partial_runs_) {
335     it.second.reset(nullptr);
336   }
337   for (auto& it : executors_) {
338     it.second.reset();
339   }
340   callables_.clear();
341   for (auto d : device_mgr_->ListDevices()) {
342     d->op_segment()->RemoveHold(session_handle_);
343   }
344   for (auto d : device_mgr_->ListDevices()) {
345     d->ClearResourceMgr();
346   }
347   functions_.clear();
348   delete cancellation_manager_;
349   for (const auto& p_and_owned : thread_pools_) {
350     if (p_and_owned.second) delete p_and_owned.first;
351   }
352 
353   execution_state_.reset(nullptr);
354   flib_def_.reset(nullptr);
355 }
356 
357 Status DirectSession::MaybeInitializeExecutionState(
358     const GraphDef& graph, bool* out_already_initialized) {
359   // If already initialized, do nothing.
360   if (flib_def_ && execution_state_) {
361     *out_already_initialized = true;
362     return Status::OK();
363   }
364   // Set up the per-session execution state.
365   // NOTE(mrry): The function library created here will be used for
366   // all subsequent extensions of the graph.
367   flib_def_.reset(
368       new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
369   GraphExecutionStateOptions options;
370   options.device_set = &device_set_;
371   options.session_options = &options_;
372   options.session_handle = session_handle_;
373   // TODO(mrry,suharshs): We explicitly copy `graph` so that
374   // `MakeForBaseGraph()` can take ownership of its
375   // contents. Previously this happened implicitly in calls to the
376   // `GraphExecutionState`. Other sessions call
377   // `MakeForBaseGraph` in such a way that we can destructively read
378   // the passed-in `GraphDef`. In principle we could do the same here,
379   // with a wider refactoring; we might revise the direct session so
380   // that it copies the graph fewer times.
381   GraphDef temp(graph);
382   TF_RETURN_IF_ERROR(
383       GraphExecutionState::MakeForBaseGraph(&temp, options, &execution_state_));
384   graph_created_ = true;
385   *out_already_initialized = false;
386   return Status::OK();
387 }
388 
389 Status DirectSession::Create(const GraphDef& graph) {
390   TF_RETURN_IF_ERROR(init_error_);
391   if (graph.node_size() > 0) {
392     mutex_lock l(graph_state_lock_);
393     if (graph_created_) {
394       return errors::AlreadyExists(
395           "A Graph has already been created for this session.");
396     }
397     return ExtendLocked(graph);
398   }
399   return Status::OK();
400 }
401 
402 Status DirectSession::Extend(const GraphDef& graph) {
403   TF_RETURN_IF_ERROR(CheckNotClosed());
404   mutex_lock l(graph_state_lock_);
405   return ExtendLocked(graph);
406 }
407 
408 Status DirectSession::ExtendLocked(const GraphDef& graph) {
409   bool already_initialized;
410   // If this is the first call, we can initialize the execution state
411   // with `graph` and do not need to call `Extend()`.
412   TF_RETURN_IF_ERROR(
413       MaybeInitializeExecutionState(graph, &already_initialized));
414   if (already_initialized) {
415     TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));
416     std::unique_ptr<GraphExecutionState> state;
417     TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
418     execution_state_.swap(state);
419   }
420   return Status::OK();
421 }
422 
423 Status DirectSession::Run(const NamedTensorList& inputs,
424                           const std::vector<string>& output_names,
425                           const std::vector<string>& target_nodes,
426                           std::vector<Tensor>* outputs) {
427   RunMetadata run_metadata;
428   return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
429              &run_metadata);
430 }
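// Usage sketch (editor's note, not part of the original source): this
// convenience overload forwards to the RunOptions-taking overload with default
// options. A typical call through the Session interface looks like:
//
//   std::vector<Tensor> outputs;
//   Status s = session->Run({{"input:0", input_tensor}},  // feeds
//                           {"output:0"},                 // fetches
//                           {},                           // target nodes
//                           &outputs);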
431 
432 Status DirectSession::CreateDebuggerState(
433     const CallableOptions& callable_options, int64 global_step,
434     int64 session_run_index, int64 executor_step_index,
435     std::unique_ptr<DebuggerStateInterface>* debugger_state) {
436   TF_RETURN_IF_ERROR(DebuggerStateRegistry::CreateState(
437       callable_options.run_options().debug_options(), debugger_state));
438   std::vector<string> input_names(callable_options.feed().begin(),
439                                   callable_options.feed().end());
440   std::vector<string> output_names(callable_options.fetch().begin(),
441                                    callable_options.fetch().end());
442   std::vector<string> target_names(callable_options.target().begin(),
443                                    callable_options.target().end());
444 
445   TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
446       global_step, session_run_index, executor_step_index, input_names,
447       output_names, target_names));
448   return Status::OK();
449 }
450 
451 Status DirectSession::DecorateAndPublishGraphForDebug(
452     const DebugOptions& debug_options, Graph* graph, Device* device) {
453   std::unique_ptr<DebugGraphDecoratorInterface> decorator;
454   TF_RETURN_IF_ERROR(
455       DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
456 
457   TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
458   TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
459   return Status::OK();
460 }
461 
462 Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
463                                   CallFrameInterface* call_frame,
464                                   ExecutorsAndKeys* executors_and_keys,
465                                   RunMetadata* run_metadata) {
466   const uint64 start_time_usecs = Env::Default()->NowMicros();
467   string session_id_meta = strings::StrCat("SessionRun #id=", step_id, "#");
468   tracing::ScopedActivity activity(session_id_meta);
469 
470   const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1);
471 
472   std::unique_ptr<DebuggerStateInterface> debugger_state;
473   if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
474     TF_RETURN_IF_ERROR(
475         CreateDebuggerState(executors_and_keys->callable_options,
476                             run_options.debug_options().global_step(), step_id,
477                             executor_step_count, &debugger_state));
478   }
479 
480   // Create a run state and start execution.
481   RunState run_state(step_id, &devices_);
482   run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
483 #ifndef __ANDROID__
484   // Set up for collectives if ExecutorsAndKeys declares a key.
485   if (executors_and_keys->collective_graph_key !=
486       BuildGraphOptions::kNoCollectiveGraphKey) {
487     if (run_options.experimental().collective_graph_key() !=
488         BuildGraphOptions::kNoCollectiveGraphKey) {
489       // If a collective_graph_key was specified in run_options, ensure that it
490       // matches what came out of GraphExecutionState::BuildGraph().
491       if (run_options.experimental().collective_graph_key() !=
492           executors_and_keys->collective_graph_key) {
493         return errors::Internal(
494             "collective_graph_key in RunOptions ",
495             run_options.experimental().collective_graph_key(),
496             " should match collective_graph_key from optimized graph ",
497             executors_and_keys->collective_graph_key);
498       }
499     }
500     if (!collective_executor_mgr_) {
501       std::unique_ptr<DeviceResolverInterface> drl(
502           new DeviceResolverLocal(device_mgr_.get()));
503       std::unique_ptr<ParamResolverInterface> cprl(
504           new CollectiveParamResolverLocal(options_.config, device_mgr_.get(),
505                                            drl.get(),
506                                            "/job:localhost/replica:0/task:0"));
507       collective_executor_mgr_.reset(new CollectiveExecutorMgr(
508           options_.config, device_mgr_.get(), std::move(drl), std::move(cprl)));
509     }
510     run_state.collective_executor.reset(new CollectiveExecutor::Handle(
511         collective_executor_mgr_->FindOrCreate(step_id), true /*inherit_ref*/));
512   }
513 #endif
514 
515   // Start parallel Executors.
516   const size_t num_executors = executors_and_keys->items.size();
517   ExecutorBarrier* barrier = new ExecutorBarrier(
518       num_executors, run_state.rendez, [&run_state](const Status& ret) {
519         {
520           mutex_lock l(run_state.mu_);
521           run_state.status.Update(ret);
522         }
523         run_state.executors_done.Notify();
524       });
525 
526   Executor::Args args;
527   args.step_id = step_id;
528   args.call_frame = call_frame;
529   args.rendezvous = run_state.rendez;
530   args.collective_executor =
531       (run_state.collective_executor ? run_state.collective_executor->get()
532                                      : nullptr);
533   CancellationManager step_cancellation_manager;
534   args.cancellation_manager = &step_cancellation_manager;
535   args.session_state = &session_state_;
536   args.session_handle = session_handle_;
537   args.tensor_store = &run_state.tensor_store;
538   args.step_container = &run_state.step_container;
539   args.sync_on_finish = sync_on_finish_;
540 
541   const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);
542 
543   bool update_cost_model = false;
544   if (options_.config.graph_options().build_cost_model() > 0) {
545     const int64 build_cost_model_every =
546         options_.config.graph_options().build_cost_model();
547     const int64 build_cost_model_after =
548         options_.config.graph_options().build_cost_model_after();
549     int64 measure_step_count = executor_step_count - build_cost_model_after;
550     if (measure_step_count >= 0) {
551       update_cost_model =
552           ((measure_step_count + 1) % build_cost_model_every == 0);
553     }
554   }
555   if (do_trace || update_cost_model ||
556       run_options.report_tensor_allocations_upon_oom()) {
557     run_state.collector.reset(
558         new StepStatsCollector(run_metadata->mutable_step_stats()));
559     args.stats_collector = run_state.collector.get();
560   }
561 
562   std::unique_ptr<DeviceTracer> tracer;
563   if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
564     tracer = CreateDeviceTracer();
565     // tracer may be NULL on platforms without accelerators.
566     if (tracer) {
567       Status s = tracer->Start();
568       if (!s.ok()) {
569         run_state.executors_done.Notify();
570         delete barrier;
571         return s;
572       }
573     }
574   }
575 
576   if (run_options.inter_op_thread_pool() < -1 ||
577       run_options.inter_op_thread_pool() >=
578           static_cast<int32>(thread_pools_.size())) {
579     run_state.executors_done.Notify();
580     delete barrier;
581     return errors::InvalidArgument("Invalid inter_op_thread_pool: ",
582                                    run_options.inter_op_thread_pool());
583   }
584 
585   // Register this step with session's cancellation manager, so that
586   // `Session::Close()` will cancel the step.
587   const CancellationToken cancellation_token =
588       cancellation_manager_->get_cancellation_token();
589   const bool already_cancelled = !cancellation_manager_->RegisterCallback(
590       cancellation_token, [&step_cancellation_manager]() {
591         step_cancellation_manager.StartCancel();
592       });
593   if (already_cancelled) {
594     // NOTE(mrry): If we don't explicitly notify
595     // `run_state.executors_done`, the RunState destructor would
596     // block on this notification.
597     run_state.executors_done.Notify();
598     delete barrier;
599     return errors::Cancelled("Run call was cancelled");
600   }
601 
602   thread::ThreadPool* pool =
603       run_options.inter_op_thread_pool() >= 0
604           ? thread_pools_[run_options.inter_op_thread_pool()].first
605           : nullptr;
606 
607   if (pool == nullptr) {
608     // We only allow using the caller thread when a single executor is
609     // specified.
610     if (executors_and_keys->items.size() > 1) {
611       pool = thread_pools_[0].first;
612     } else {
613       VLOG(1) << "Executing Session::Run() synchronously!";
614     }
615   }
616 
617   std::unique_ptr<RunHandler> handler;
618   if (ShouldUseRunHandlerPool(run_options) &&
619       run_options.experimental().use_run_handler_pool()) {
620     VLOG(1) << "Using RunHandler to schedule inter-op closures.";
621     handler = GetOrCreateRunHandlerPool(options_)->Get();
622   }
623   auto* handler_ptr = handler.get();
624 
625   Executor::Args::Runner default_runner = nullptr;
626 
627   if (pool == nullptr) {
628     default_runner = [](Executor::Args::Closure c) { c(); };
629   } else if (handler_ptr != nullptr) {
630     default_runner = [handler_ptr](Executor::Args::Closure c) {
631       handler_ptr->ScheduleInterOpClosure(std::move(c));
632     };
633   } else {
634     default_runner = [this, pool](Executor::Args::Closure c) {
635       SchedClosure(pool, std::move(c));
636     };
637   }
638 
639   for (const auto& item : executors_and_keys->items) {
640     // TODO(azaks): support partial run.
641     // TODO(azaks): if the device picks its own threadpool, we need to assign
642     //     fewer threads to the main compute pool by default.
643     thread::ThreadPool* device_thread_pool =
644         item.device->tensorflow_device_thread_pool();
645     // TODO(crk): Investigate usage of RunHandlerPool when using device specific
646     // thread pool(s).
647     if (!device_thread_pool) {
648       args.runner = default_runner;
649     } else {
650       args.runner = [this, device_thread_pool](Executor::Args::Closure c) {
651         SchedClosure(device_thread_pool, std::move(c));
652       };
653     }
654     item.executor->RunAsync(args, barrier->Get());
655   }
656 
657   WaitForNotification(&run_state, &step_cancellation_manager,
658                       run_options.timeout_in_ms() > 0
659                           ? run_options.timeout_in_ms()
660                           : operation_timeout_in_ms_);
661 
662   if (!cancellation_manager_->DeregisterCallback(cancellation_token)) {
663     // The step has been cancelled: make sure we don't attempt to receive the
664     // outputs as this would make it block forever.
665     mutex_lock l(run_state.mu_);
666     run_state.status.Update(errors::Cancelled("Run call was cancelled"));
667   }
668 
669   if (tracer) {
670     TF_RETURN_IF_ERROR(tracer->Stop());
671     TF_RETURN_IF_ERROR(tracer->Collect(run_state.collector.get()));
672   }
673 
674   {
675     mutex_lock l(run_state.mu_);
676     TF_RETURN_IF_ERROR(run_state.status);
677   }
678 
679   // Save the output tensors of this run we choose to keep.
680   if (!run_state.tensor_store.empty()) {
681     TF_RETURN_IF_ERROR(run_state.tensor_store.SaveTensors(
682         {executors_and_keys->callable_options.fetch().begin(),
683          executors_and_keys->callable_options.fetch().end()},
684         &session_state_));
685   }
686 
687   if (run_state.collector) {
688     run_state.collector->Finalize();
689   }
690 
691   // Build and return the cost model as instructed.
692   if (update_cost_model) {
693     // Build the cost model
694     std::unordered_map<string, const Graph*> device_to_graph;
695     for (const PerPartitionExecutorsAndLib& partition :
696          executors_and_keys->items) {
697       const Graph* graph = partition.graph;
698       const string device = partition.flib->device()->name();
699       device_to_graph[device] = graph;
700     }
701 
702     mutex_lock l(executor_lock_);
703     run_state.collector->BuildCostModel(&cost_model_manager_, device_to_graph);
704 
705     // annotate stats onto cost graph.
706     CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
707     for (const auto& item : executors_and_keys->items) {
708       TF_RETURN_IF_ERROR(
709           cost_model_manager_.AddToCostGraphDef(item.graph, cost_graph));
710     }
711   }
712 
713   // If requested via RunOptions, output the partition graphs.
714   if (run_options.output_partition_graphs()) {
715     protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
716         run_metadata->mutable_partition_graphs();
717     for (const PerPartitionExecutorsAndLib& exec_and_lib :
718          executors_and_keys->items) {
719       GraphDef* partition_graph_def = partition_graph_defs->Add();
720       exec_and_lib.graph->ToGraphDef(partition_graph_def);
721     }
722   }
723   metrics::UpdateGraphExecTime(Env::Default()->NowMicros() - start_time_usecs);
724 
725   return Status::OK();
726 }
727 
728 Status DirectSession::Run(const RunOptions& run_options,
729                           const NamedTensorList& inputs,
730                           const std::vector<string>& output_names,
731                           const std::vector<string>& target_nodes,
732                           std::vector<Tensor>* outputs,
733                           RunMetadata* run_metadata) {
734   TF_RETURN_IF_ERROR(CheckNotClosed());
735   TF_RETURN_IF_ERROR(CheckGraphCreated("Run()"));
736   direct_session_runs->GetCell()->IncrementBy(1);
737 
738   // Extract the inputs names for this run of the session.
739   std::vector<string> input_tensor_names;
740   input_tensor_names.reserve(inputs.size());
741   for (const auto& it : inputs) {
742     input_tensor_names.push_back(it.first);
743   }
744 
745   // Check if we already have an executor for these arguments.
746   ExecutorsAndKeys* executors_and_keys;
747   RunStateArgs run_state_args(run_options.debug_options());
748   run_state_args.collective_graph_key =
749       run_options.experimental().collective_graph_key();
750 
751   TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
752                                           target_nodes, &executors_and_keys,
753                                           &run_state_args));
754   {
755     mutex_lock l(collective_graph_key_lock_);
756     collective_graph_key_ = executors_and_keys->collective_graph_key;
757   }
758 
759   // Configure a call frame for the step, which we use to feed and
760   // fetch values to and from the executors.
761   FunctionCallFrame call_frame(executors_and_keys->input_types,
762                                executors_and_keys->output_types);
763   gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
764   for (const auto& it : inputs) {
765     if (it.second.dtype() == DT_RESOURCE) {
766       Tensor tensor_from_handle;
767       TF_RETURN_IF_ERROR(
768           ResourceHandleToInputTensor(it.second, &tensor_from_handle));
769       feed_args[executors_and_keys->input_name_to_index[it.first]] =
770           tensor_from_handle;
771     } else {
772       feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
773     }
774   }
775   const Status s = call_frame.SetArgs(feed_args);
776   if (errors::IsInternal(s)) {
777     return errors::InvalidArgument(s.error_message());
778   } else if (!s.ok()) {
779     return s;
780   }
781 
782   const int64 step_id = step_id_counter_.fetch_add(1);
783 
784   if (LogMemory::IsEnabled()) {
785     LogMemory::RecordStep(step_id, run_state_args.handle);
786   }
787 
788   TF_RETURN_IF_ERROR(RunInternal(step_id, run_options, &call_frame,
789                                  executors_and_keys, run_metadata));
790 
791   // Receive outputs.
792   if (outputs) {
793     std::vector<Tensor> sorted_outputs;
794     const Status s = call_frame.ConsumeRetvals(
795         &sorted_outputs, /* allow_dead_tensors = */ false);
796     if (errors::IsInternal(s)) {
797       return errors::InvalidArgument(s.error_message());
798     } else if (!s.ok()) {
799       return s;
800     }
801     const bool unique_outputs =
802         output_names.size() == executors_and_keys->output_name_to_index.size();
803     // first_indices[i] = j implies that j is the smallest value for which
804     // output_names[i] == output_names[j].
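    // (Editor's note) For example, output_names = {"a:0", "b:0", "a:0"} yields
    // first_indices = {0, 1, 0}, so the third output aliases the first.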
805     std::vector<int> first_indices;
806     if (!unique_outputs) {
807       first_indices.resize(output_names.size());
808       for (int i = 0; i < output_names.size(); ++i) {
809         for (int j = 0; j <= i; ++j) {
810           if (output_names[i] == output_names[j]) {
811             first_indices[i] = j;
812             break;
813           }
814         }
815       }
816     }
817     outputs->clear();
818     outputs->reserve(sorted_outputs.size());
819     for (int i = 0; i < output_names.size(); ++i) {
820       const string& output_name = output_names[i];
821       if (first_indices.empty() || first_indices[i] == i) {
822         outputs->emplace_back(
823             std::move(sorted_outputs[executors_and_keys
824                                          ->output_name_to_index[output_name]]));
825       } else {
826         outputs->push_back((*outputs)[first_indices[i]]);
827       }
828     }
829   }
830 
831   return Status::OK();
832 }
833 
834 Status DirectSession::PRunSetup(const std::vector<string>& input_names,
835                                 const std::vector<string>& output_names,
836                                 const std::vector<string>& target_nodes,
837                                 string* handle) {
838   TF_RETURN_IF_ERROR(CheckNotClosed());
839   TF_RETURN_IF_ERROR(CheckGraphCreated("PRunSetup()"));
840 
841   // RunOptions is not available in PRunSetup, so use thread pool 0.
842   thread::ThreadPool* pool = thread_pools_[0].first;
843 
844   // Check if we already have an executor for these arguments.
845   ExecutorsAndKeys* executors_and_keys;
846   // TODO(cais): TFDBG support for partial runs.
847   DebugOptions debug_options;
848   RunStateArgs run_state_args(debug_options);
849   run_state_args.is_partial_run = true;
850   TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_names, output_names,
851                                           target_nodes, &executors_and_keys,
852                                           &run_state_args));
853 
854   // Create the run state and save it for future PRun calls.
855   Executor::Args args;
856   args.step_id = step_id_counter_.fetch_add(1);
857   RunState* run_state =
858       new RunState(input_names, output_names, args.step_id, &devices_);
859   run_state->rendez = new IntraProcessRendezvous(device_mgr_.get());
860   {
861     mutex_lock l(executor_lock_);
862     if (!partial_runs_
863              .emplace(run_state_args.handle,
864                       std::unique_ptr<RunState>(run_state))
865              .second) {
866       return errors::Internal("The handle '", run_state_args.handle,
867                               "' created for this partial run is not unique.");
868     }
869   }
870 
871   // Start parallel Executors.
872   const size_t num_executors = executors_and_keys->items.size();
873   ExecutorBarrier* barrier = new ExecutorBarrier(
874       num_executors, run_state->rendez, [run_state](const Status& ret) {
875         if (!ret.ok()) {
876           mutex_lock l(run_state->mu_);
877           run_state->status.Update(ret);
878         }
879         run_state->executors_done.Notify();
880       });
881 
882   args.rendezvous = run_state->rendez;
883   args.cancellation_manager = cancellation_manager_;
884   // Note that Collectives are not supported in partial runs
885   // because RunOptions is not passed in so we can't know whether
886   // their use is intended.
887   args.collective_executor = nullptr;
888   args.runner = [this, pool](Executor::Args::Closure c) {
889     SchedClosure(pool, std::move(c));
890   };
891   args.session_state = &session_state_;
892   args.session_handle = session_handle_;
893   args.tensor_store = &run_state->tensor_store;
894   args.step_container = &run_state->step_container;
895   if (LogMemory::IsEnabled()) {
896     LogMemory::RecordStep(args.step_id, run_state_args.handle);
897   }
898   args.sync_on_finish = sync_on_finish_;
899 
900   if (options_.config.graph_options().build_cost_model()) {
901     run_state->collector.reset(new StepStatsCollector(nullptr));
902     args.stats_collector = run_state->collector.get();
903   }
904 
905   for (auto& item : executors_and_keys->items) {
906     item.executor->RunAsync(args, barrier->Get());
907   }
908 
909   *handle = run_state_args.handle;
910   return Status::OK();
911 }
912 
913 Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
914                            const std::vector<string>& output_names,
915                            std::vector<Tensor>* outputs) {
916   TF_RETURN_IF_ERROR(CheckNotClosed());
917   std::vector<string> parts = str_util::Split(handle, ';');
918   const string& key = parts[0];
919   // Get the executors for this partial run.
920   ExecutorsAndKeys* executors_and_keys;
921   RunState* run_state;
922   {
923     mutex_lock l(executor_lock_);  // could use reader lock
924     auto exc_it = executors_.find(key);
925     if (exc_it == executors_.end()) {
926       return errors::InvalidArgument(
927           "Must run 'setup' before performing partial runs!");
928     }
929     executors_and_keys = exc_it->second.get();
930 
931     auto prun_it = partial_runs_.find(handle);
932     if (prun_it == partial_runs_.end()) {
933       return errors::InvalidArgument(
934           "Must run 'setup' before performing partial runs!");
935     }
936     run_state = prun_it->second.get();
937 
938     // Make sure that this is a new set of feeds that are still pending.
939     for (const auto& input : inputs) {
940       auto it = run_state->pending_inputs.find(input.first);
941       if (it == run_state->pending_inputs.end()) {
942         return errors::InvalidArgument(
943             "The feed ", input.first,
944             " was not specified in partial_run_setup.");
945       } else if (it->second) {
946         return errors::InvalidArgument("The feed ", input.first,
947                                        " has already been fed.");
948       }
949     }
950     // Check that this is a new set of fetches that are still pending.
951     for (const auto& output : output_names) {
952       auto it = run_state->pending_outputs.find(output);
953       if (it == run_state->pending_outputs.end()) {
954         return errors::InvalidArgument(
955             "The fetch ", output, " was not specified in partial_run_setup.");
956       } else if (it->second) {
957         return errors::InvalidArgument("The fetch ", output,
958                                        " has already been fetched.");
959       }
960     }
961   }
962 
963   // Check that this new set of fetches can be computed from all the
964   // feeds we have supplied.
965   TF_RETURN_IF_ERROR(
966       CheckFetch(inputs, output_names, executors_and_keys, run_state));
967 
968   // Send inputs.
969   Status s = SendPRunInputs(inputs, executors_and_keys, run_state->rendez);
970 
971   // Receive outputs.
972   if (s.ok()) {
973     s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs);
974   }
975 
976   // Save the output tensors of this run we choose to keep.
977   if (s.ok()) {
978     s = run_state->tensor_store.SaveTensors(output_names, &session_state_);
979   }
980 
981   {
982     mutex_lock l(executor_lock_);
983     // Delete the run state if there is an error or all fetches are done.
984     bool done = true;
985     if (s.ok()) {
986       {
987         mutex_lock l(run_state->mu_);
988         if (!run_state->status.ok()) {
989           LOG(WARNING) << "An error unrelated to this prun has been detected. "
990                        << run_state->status;
991         }
992       }
993       for (const auto& input : inputs) {
994         auto it = run_state->pending_inputs.find(input.first);
995         it->second = true;
996       }
997       for (const auto& name : output_names) {
998         auto it = run_state->pending_outputs.find(name);
999         it->second = true;
1000       }
1001       done = run_state->PendingDone();
1002     }
1003     if (done) {
1004       WaitForNotification(run_state, cancellation_manager_,
1005                           operation_timeout_in_ms_);
1006       partial_runs_.erase(handle);
1007     }
1008   }
1009 
1010   return s;
1011 }
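// Usage sketch (editor's note, not part of the original source): a partial run
// pairs one PRunSetup() with one or more PRun() calls, each feeding and
// fetching a subset of the tensors declared during setup:
//
//   string handle;
//   TF_CHECK_OK(session->PRunSetup({"a:0", "b:0"}, {"c:0", "d:0"}, {}, &handle));
//   std::vector<Tensor> outs;
//   TF_CHECK_OK(session->PRun(handle, {{"a:0", a}}, {"c:0"}, &outs));
//   TF_CHECK_OK(session->PRun(handle, {{"b:0", b}}, {"d:0"}, &outs));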
1012 
1013 Status DirectSession::ResourceHandleToInputTensor(const Tensor& resource_tensor,
1014                                                   Tensor* retrieved_tensor) {
1015   if (resource_tensor.dtype() != DT_RESOURCE) {
1016     return errors::InvalidArgument(strings::StrCat(
1017         "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ",
1018         resource_tensor.dtype()));
1019   }
1020 
1021   const ResourceHandle& resource_handle =
1022       resource_tensor.scalar<ResourceHandle>()();
1023 
1024   if (resource_handle.container() ==
1025       SessionState::kTensorHandleResourceTypeName) {
1026     return session_state_.GetTensor(resource_handle.name(), retrieved_tensor);
1027   } else {
1028     return errors::InvalidArgument(strings::StrCat(
1029         "Invalid resource type hash code: ", resource_handle.hash_code(),
1030         "(name: ", resource_handle.name(),
1031         " type: ", resource_handle.maybe_type_name(),
1032         "). Perhaps a resource tensor was being provided as a feed? That is "
1033         "not currently allowed. Please file an issue at "
1034         "https://github.com/tensorflow/tensorflow/issues/new, ideally with a "
1035         "short code snippet that leads to this error message."));
1036   }
1037 }
1038 
1039 Status DirectSession::SendPRunInputs(const NamedTensorList& inputs,
1040                                      const ExecutorsAndKeys* executors_and_keys,
1041                                      IntraProcessRendezvous* rendez) {
1042   Status s;
1043   Rendezvous::ParsedKey parsed;
1044   // Insert the input tensors into the local rendezvous by their
1045   // rendezvous key.
1046   for (const auto& input : inputs) {
1047     auto it =
1048         executors_and_keys->input_name_to_rendezvous_key.find(input.first);
1049     if (it == executors_and_keys->input_name_to_rendezvous_key.end()) {
1050       return errors::Internal("'", input.first, "' is not a pre-defined feed.");
1051     }
1052     const string& input_key = it->second;
1053 
1054     s = Rendezvous::ParseKey(input_key, &parsed);
1055     if (!s.ok()) {
1056       rendez->StartAbort(s);
1057       return s;
1058     }
1059 
1060     if (input.second.dtype() == DT_RESOURCE) {
1061       Tensor tensor_from_handle;
1062       s = ResourceHandleToInputTensor(input.second, &tensor_from_handle);
1063       if (s.ok()) {
1064         s = rendez->Send(parsed, Rendezvous::Args(), tensor_from_handle, false);
1065       }
1066     } else {
1067       s = rendez->Send(parsed, Rendezvous::Args(), input.second, false);
1068     }
1069 
1070     if (!s.ok()) {
1071       rendez->StartAbort(s);
1072       return s;
1073     }
1074   }
1075   return Status::OK();
1076 }
1077 
1078 Status DirectSession::RecvPRunOutputs(
1079     const std::vector<string>& output_names,
1080     const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
1081     std::vector<Tensor>* outputs) {
1082   Status s;
1083   if (!output_names.empty()) {
1084     outputs->resize(output_names.size());
1085   }
1086 
1087   Rendezvous::ParsedKey parsed;
1088   // Get the outputs from the rendezvous
1089   for (size_t output_offset = 0; output_offset < output_names.size();
1090        ++output_offset) {
1091     const string& output_name = output_names[output_offset];
1092     auto it =
1093         executors_and_keys->output_name_to_rendezvous_key.find(output_name);
1094     if (it == executors_and_keys->output_name_to_rendezvous_key.end()) {
1095       return errors::Internal("'", output_name,
1096                               "' is not a pre-defined fetch.");
1097     }
1098     const string& output_key = it->second;
1099     Tensor output_tensor;
1100     bool is_dead;
1101     IntraProcessRendezvous* rendez = run_state->rendez;
1102 
1103     s = Rendezvous::ParseKey(output_key, &parsed);
1104     if (s.ok()) {
1105       // Fetch data from the Rendezvous.
1106       s = rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead,
1107                        operation_timeout_in_ms_);
1108       if (is_dead && s.ok()) {
1109         s = errors::InvalidArgument("The tensor returned for ", output_name,
1110                                     " was not valid.");
1111       }
1112     }
1113     if (!s.ok()) {
1114       rendez->StartAbort(s);
1115       outputs->clear();
1116       return s;
1117     }
1118 
1119     (*outputs)[output_offset] = output_tensor;
1120   }
1121   return Status::OK();
1122 }
1123 
1124 Status DirectSession::CheckFetch(const NamedTensorList& feeds,
1125                                  const std::vector<string>& fetches,
1126                                  const ExecutorsAndKeys* executors_and_keys,
1127                                  const RunState* run_state) {
1128   const Graph* graph = executors_and_keys->graph.get();
1129   const NameNodeMap* name_to_node = &executors_and_keys->name_to_node;
1130 
1131   // Build the set of pending feeds that we haven't seen.
1132   std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
1133   {
1134     mutex_lock l(executor_lock_);
1135     for (const auto& input : run_state->pending_inputs) {
1136       // Skip if the feed has already been fed.
1137       if (input.second) continue;
1138       TensorId id(ParseTensorName(input.first));
1139       auto it = name_to_node->find(id.first);
1140       if (it == name_to_node->end()) {
1141         return errors::NotFound("Feed ", input.first, ": not found");
1142       }
1143       pending_feeds.insert(id);
1144     }
1145   }
1146   for (const auto& it : feeds) {
1147     TensorId id(ParseTensorName(it.first));
1148     pending_feeds.erase(id);
1149   }
1150 
1151   // Initialize the stack with the fetch nodes.
1152   std::vector<const Node*> stack;
1153   for (const string& fetch : fetches) {
1154     TensorId id(ParseTensorName(fetch));
1155     auto it = name_to_node->find(id.first);
1156     if (it == name_to_node->end()) {
1157       return errors::NotFound("Fetch ", fetch, ": not found");
1158     }
1159     stack.push_back(it->second);
1160   }
1161 
1162   // Any tensor needed for fetches can't be in pending_feeds.
1163   std::vector<bool> visited(graph->num_node_ids(), false);
1164   while (!stack.empty()) {
1165     const Node* n = stack.back();
1166     stack.pop_back();
1167 
1168     for (const Edge* in_edge : n->in_edges()) {
1169       const Node* in_node = in_edge->src();
1170       if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
1171         return errors::InvalidArgument("Fetch ", in_node->name(), ":",
1172                                        in_edge->src_output(),
1173                                        " can't be computed from the feeds"
1174                                        " that have been fed so far.");
1175       }
1176       if (!visited[in_node->id()]) {
1177         visited[in_node->id()] = true;
1178         stack.push_back(in_node);
1179       }
1180     }
1181   }
1182   return Status::OK();
1183 }
1184 
1185 Status DirectSession::CreateExecutors(
1186     const CallableOptions& callable_options,
1187     std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys,
1188     std::unique_ptr<FunctionInfo>* out_func_info,
1189     RunStateArgs* run_state_args) {
1190   BuildGraphOptions options;
1191   options.callable_options = callable_options;
1192   options.use_function_convention = !run_state_args->is_partial_run;
1193   options.collective_graph_key =
1194       callable_options.run_options().experimental().collective_graph_key();
1195   if (options_.config.experimental()
1196           .collective_deterministic_sequential_execution()) {
1197     options.collective_order = GraphCollectiveOrder::kEdges;
1198   } else if (options_.config.experimental().collective_nccl()) {
1199     options.collective_order = GraphCollectiveOrder::kAttrs;
1200   }
1201 
1202   std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
1203   std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
1204 
1205   ek->callable_options = callable_options;
1206 
1207   std::unordered_map<string, std::unique_ptr<Graph>> graphs;
1208   TF_RETURN_IF_ERROR(CreateGraphs(
1209       options, &graphs, &func_info->flib_def, run_state_args, &ek->input_types,
1210       &ek->output_types, &ek->collective_graph_key));
1211 
1212   if (run_state_args->is_partial_run) {
1213     ek->graph = std::move(run_state_args->graph);
1214     std::unordered_set<StringPiece, StringPieceHasher> names;
1215     for (const string& input : callable_options.feed()) {
1216       TensorId id(ParseTensorName(input));
1217       names.emplace(id.first);
1218     }
1219     for (const string& output : callable_options.fetch()) {
1220       TensorId id(ParseTensorName(output));
1221       names.emplace(id.first);
1222     }
1223     for (Node* n : ek->graph->nodes()) {
1224       if (names.count(n->name()) > 0) {
1225         ek->name_to_node.insert({n->name(), n});
1226       }
1227     }
1228   }
1229   ek->items.reserve(graphs.size());
1230   const auto& optimizer_opts =
1231       options_.config.graph_options().optimizer_options();
1232 
1233   int graph_def_version;
1234   {
1235     mutex_lock l(graph_state_lock_);
1236     graph_def_version =
1237         execution_state_->original_graph_def().versions().producer();
1238   }
1239   func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
1240       device_mgr_.get(), options_.env, graph_def_version,
1241       func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first));
1242 
1243   GraphOptimizer optimizer(optimizer_opts);
1244   for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
1245     const string& partition_name = iter->first;
1246     std::unique_ptr<Graph>& partition_graph = iter->second;
1247 
1248     Device* device;
1249     TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device));
1250 
1251     ek->items.resize(ek->items.size() + 1);
1252     auto* item = &(ek->items.back());
1253     auto lib = func_info->proc_flr->GetFLR(partition_name);
1254     if (lib == nullptr) {
1255       return errors::Internal("Could not find device: ", partition_name);
1256     }
1257     item->flib = lib;
1258 
1259     LocalExecutorParams params;
1260     params.device = device;
1261     params.function_library = lib;
1262     auto opseg = device->op_segment();
1263     params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
1264                                               OpKernel** kernel) {
1265       // NOTE(mrry): We must not share function kernels (implemented
1266       // using `CallOp`) between subgraphs, because `CallOp::handle_`
1267       // is tied to a particular subgraph. Even if the function itself
1268       // is stateful, the `CallOp` that invokes it is not.
1269       if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
1270         return lib->CreateKernel(ndef, kernel);
1271       }
1272       auto create_fn = [lib, &ndef](OpKernel** kernel) {
1273         return lib->CreateKernel(ndef, kernel);
1274       };
1275       // Kernels created for subgraph nodes need to be cached.  On
1276       // cache miss, create_fn() is invoked to create a kernel based
1277       // on the function library here + global op registry.
1278       return opseg->FindOrCreate(session_handle_, ndef.name(), kernel,
1279                                  create_fn);
1280     };
1281     params.delete_kernel = [lib](OpKernel* kernel) {
1282       if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string()))
1283         delete kernel;
1284     };
1285 
1286     optimizer.Optimize(lib, options_.env, device, &partition_graph,
1287                        /*shape_map=*/nullptr);
1288 
1289     // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
1290     const DebugOptions& debug_options =
1291         options.callable_options.run_options().debug_options();
1292     if (!debug_options.debug_tensor_watch_opts().empty()) {
1293       TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
1294           debug_options, partition_graph.get(), params.device));
1295     }
1296 
1297     TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
1298                                          device->name(),
1299                                          partition_graph.get()));
1300     // NewExecutor takes ownership of partition_graph.
1301     item->graph = partition_graph.get();
1302     item->executor = nullptr;
1303     item->device = device;
1304     auto executor_type = options_.config.experimental().executor_type();
1305     TF_RETURN_IF_ERROR(NewExecutor(
1306         executor_type, params, std::move(partition_graph), &item->executor));
1307   }
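  // At this point `ek->items` holds one executor (and the matching per-device
  // function library runtime) for each partition produced by CreateGraphs().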
1308 
1309   // Cache the mapping from input/output names to graph elements to
1310   // avoid recomputing it every time.
1311   if (!run_state_args->is_partial_run) {
1312     // For regular `Run()`, we use the function calling convention, and so
1313     // maintain a mapping from input/output names to
1314     // argument/return-value ordinal index.
1315     for (int i = 0; i < callable_options.feed().size(); ++i) {
1316       const string& input = callable_options.feed(i);
1317       ek->input_name_to_index[input] = i;
1318     }
1319     for (int i = 0; i < callable_options.fetch().size(); ++i) {
1320       const string& output = callable_options.fetch(i);
1321       ek->output_name_to_index[output] = i;
1322     }
1323   } else {
1324     // For `PRun()`, we use the rendezvous calling convention, and so
1325     // maintain a mapping from input/output names to rendezvous keys.
1326     //
1327     // We always use the first device as the device name portion of the
1328     // key, even if we're feeding another graph.
1329     for (int i = 0; i < callable_options.feed().size(); ++i) {
1330       const string& input = callable_options.feed(i);
1331       ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
1332           input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
1333     }
1334     for (int i = 0; i < callable_options.fetch().size(); ++i) {
1335       const string& output = callable_options.fetch(i);
1336       ek->output_name_to_rendezvous_key[output] =
1337           GetRendezvousKey(output, device_set_.client_device()->attributes(),
1338                            FrameAndIter(0, 0));
1339     }
1340   }
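  // The keys produced in the partial-run branch above follow the rendezvous
  // naming convention used by send/recv (roughly
  // "<device>;<incarnation>;<device>;<tensor name>;<frame>:<iter>"), with the
  // client device standing in for both endpoints.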
1341 
1342   *out_executors_and_keys = std::move(ek);
1343   *out_func_info = std::move(func_info);
1344   return Status::OK();
1345 }
1346 
1347 Status DirectSession::GetOrCreateExecutors(
1348     gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
1349     gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys,
1350     RunStateArgs* run_state_args) {
1351   int64 handle_name_counter_value = -1;
1352   if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
1353     handle_name_counter_value = handle_name_counter_.fetch_add(1);
1354   }
1355 
1356   string debug_tensor_watches_summary;
1357   if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
1358     debug_tensor_watches_summary = SummarizeDebugTensorWatches(
1359         run_state_args->debug_options.debug_tensor_watch_opts());
1360   }
1361 
1362   // Fast lookup path, no sorting.
1363   const string key = strings::StrCat(
1364       str_util::Join(inputs, ","), "->", str_util::Join(outputs, ","), "/",
1365       str_util::Join(target_nodes, ","), "/", run_state_args->is_partial_run,
1366       "/", debug_tensor_watches_summary);
1367   // Set the handle, if it's needed to log memory or for partial run.
1368   if (handle_name_counter_value >= 0) {
1369     run_state_args->handle =
1370         strings::StrCat(key, ";", handle_name_counter_value);
1371   }
1372 
1373   // See if we already have the executors for this run.
1374   {
1375     mutex_lock l(executor_lock_);  // could use reader lock
1376     auto it = executors_.find(key);
1377     if (it != executors_.end()) {
1378       *executors_and_keys = it->second.get();
1379       return Status::OK();
1380     }
1381   }
1382 
1383   // Slow lookup path, the unsorted key missed the cache.
1384   // Sort the inputs and outputs, and look up with the sorted key in case an
1385   // earlier call used a different order of inputs and outputs.
1386   //
1387   // We could consider some other canonical signature that preserves the
1388   // same property, to avoid the sort in the future.
1389   std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
1390   std::sort(inputs_sorted.begin(), inputs_sorted.end());
1391   std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
1392   std::sort(outputs_sorted.begin(), outputs_sorted.end());
1393   std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
1394   std::sort(tn_sorted.begin(), tn_sorted.end());
1395 
1396   const string sorted_key = strings::StrCat(
1397       str_util::Join(inputs_sorted, ","), "->",
1398       str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","),
1399       "/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
1400   // Set the handle, if it's needed to log memory or for partial run.
1401   if (handle_name_counter_value >= 0) {
1402     run_state_args->handle =
1403         strings::StrCat(sorted_key, ";", handle_name_counter_value);
1404   }
1405 
1406   // See if we already have the executors for this run.
1407   {
1408     mutex_lock l(executor_lock_);
1409     auto it = executors_.find(sorted_key);
1410     if (it != executors_.end()) {
1411       *executors_and_keys = it->second.get();
1412       // Insert this under the original key.
1413       executors_.emplace(key, it->second);
1414       return Status::OK();
1415     }
1416   }
1417 
1418   // Nothing found, so create the executors and store in the cache.
1419   // The executor_lock_ is intentionally released while executors are
1420   // being created.
1421   CallableOptions callable_options;
1422   for (const string& input : inputs_sorted) {
1423     callable_options.add_feed(input);
1424   }
1425   for (const string& output : outputs_sorted) {
1426     callable_options.add_fetch(output);
1427   }
1428   for (const string& target : tn_sorted) {
1429     callable_options.add_target(target);
1430   }
1431   *callable_options.mutable_run_options()->mutable_debug_options() =
1432       run_state_args->debug_options;
1433   callable_options.mutable_run_options()
1434       ->mutable_experimental()
1435       ->set_collective_graph_key(run_state_args->collective_graph_key);
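  // Because the callable is built from the sorted feeds/fetches/targets, the
  // argument ordering seen by CreateExecutors matches the sorted key above.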
1436   std::unique_ptr<ExecutorsAndKeys> ek;
1437   std::unique_ptr<FunctionInfo> func_info;
1438   TF_RETURN_IF_ERROR(
1439       CreateExecutors(callable_options, &ek, &func_info, run_state_args));
1440 
1441   // Reacquire the lock, try to insert into the map.
1442   mutex_lock l(executor_lock_);
1443   functions_.push_back(std::move(func_info));
1444 
1445   // Another thread may have created the entry before us, in which case we will
1446   // reuse the already created one.
1447   auto insert_result = executors_.emplace(
1448       sorted_key, std::shared_ptr<ExecutorsAndKeys>(std::move(ek)));
1449   // Insert the value under the original key, so the fast path lookup will work
1450   // if the user uses the same order of inputs, outputs, and targets again.
1451   executors_.emplace(key, insert_result.first->second);
1452   *executors_and_keys = insert_result.first->second.get();
1453 
1454   return Status::OK();
1455 }
1456 
1457 Status DirectSession::CreateGraphs(
1458     const BuildGraphOptions& subgraph_options,
1459     std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
1460     std::unique_ptr<FunctionLibraryDefinition>* flib_def,
1461     RunStateArgs* run_state_args, DataTypeVector* input_types,
1462     DataTypeVector* output_types, int64* collective_graph_key) {
1463   mutex_lock l(graph_state_lock_);
1464   std::unique_ptr<ClientGraph> client_graph;
1465 
1466   std::unique_ptr<GraphExecutionState> temp_exec_state_holder;
1467   GraphExecutionState* execution_state = nullptr;
1468   if (options_.config.graph_options().place_pruned_graph()) {
1469     // Because we are placing pruned graphs, we need to create a
1470     // new GraphExecutionState for every new unseen graph,
1471     // and then place it.
1472     GraphExecutionStateOptions prune_options;
1473     prune_options.device_set = &device_set_;
1474     prune_options.session_options = &options_;
1475     prune_options.stateful_placements = stateful_placements_;
1476     prune_options.session_handle = session_handle_;
1477     TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph(
1478         execution_state_->original_graph_def().library(), prune_options,
1479         execution_state_->original_graph_def(), subgraph_options,
1480         &temp_exec_state_holder, &client_graph));
1481     execution_state = temp_exec_state_holder.get();
1482   } else {
1483     execution_state = execution_state_.get();
1484     TF_RETURN_IF_ERROR(
1485         execution_state->BuildGraph(subgraph_options, &client_graph));
1486   }
1487   *collective_graph_key = client_graph->collective_graph_key;
1488 
1489   if (subgraph_options.callable_options.feed_size() !=
1490       client_graph->feed_types.size()) {
1491     return errors::Internal(
1492         "Graph pruning failed: requested number of feed endpoints = ",
1493         subgraph_options.callable_options.feed_size(),
1494         " versus number of pruned feed endpoints = ",
1495         client_graph->feed_types.size());
1496   }
1497   if (subgraph_options.callable_options.fetch_size() !=
1498       client_graph->fetch_types.size()) {
1499     return errors::Internal(
1500         "Graph pruning failed: requested number of fetch endpoints = ",
1501         subgraph_options.callable_options.fetch_size(),
1502         " versus number of pruned fetch endpoints = ",
1503         client_graph->fetch_types.size());
1504   }
1505 
1506   auto current_stateful_placements = execution_state->GetStatefulPlacements();
1507   // Update our current state based on the execution_state's
1508   // placements.  If there are any mismatches for a node,
1509   // we should fail, as this should never happen.
1510   for (auto placement_pair : current_stateful_placements) {
1511     const string& node_name = placement_pair.first;
1512     const string& placement = placement_pair.second;
1513     auto iter = stateful_placements_.find(node_name);
1514     if (iter == stateful_placements_.end()) {
1515       stateful_placements_.insert(std::make_pair(node_name, placement));
1516     } else if (iter->second != placement) {
1517       return errors::Internal(
1518           "Stateful placement mismatch. "
1519           "Current assignment of ",
1520           node_name, " to ", iter->second, " does not match ", placement);
1521     }
1522   }
1523 
1524   stateful_placements_ = execution_state->GetStatefulPlacements();
1525 
1526   // Remember the graph in run state if this is a partial run.
1527   if (run_state_args->is_partial_run) {
1528     run_state_args->graph.reset(new Graph(flib_def_.get()));
1529     CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
1530   }
1531 
1532   // Partition the graph across devices.
1533   PartitionOptions popts;
1534   popts.node_to_loc = [](const Node* node) {
1535     return node->assigned_device_name();
1536   };
1537   popts.new_name = [this](const string& prefix) {
1538     return strings::StrCat(prefix, "/_", edge_name_counter_.fetch_add(1));
1539   };
1540   popts.get_incarnation = [](const string& name) {
1541     // The direct session does not have changing incarnation numbers.
1542     // Just return '1'.
1543     return 1;
1544   };
1545   popts.flib_def = &client_graph->graph.flib_def();
1546   popts.control_flow_added = false;
1547 
1548   std::unordered_map<string, GraphDef> partitions;
1549   TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions));
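  // `partitions` now maps each assigned device name (e.g.
  // "/job:localhost/replica:0/task:0/device:CPU:0") to the GraphDef destined
  // for that device.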
1550 
1551   std::vector<string> device_names;
1552   for (auto device : devices_) {
1553     // Extract the LocalName from the device.
1554     device_names.push_back(DeviceNameUtils::LocalName(device->name()));
1555   }
1556 
1557   // Check for valid partitions.
1558   for (const auto& partition : partitions) {
1559     const string local_partition_name =
1560         DeviceNameUtils::LocalName(partition.first);
1561     if (std::count(device_names.begin(), device_names.end(),
1562                    local_partition_name) == 0) {
1563       return errors::InvalidArgument(
1564           "Creating a partition for ", local_partition_name,
1565           " which doesn't exist in the list of available devices. Available "
1566           "devices: ",
1567           str_util::Join(device_names, ","));
1568     }
1569   }
1570 
1571   for (const auto& partition : partitions) {
1572     std::unique_ptr<Graph> device_graph(
1573         new Graph(client_graph->flib_def.get()));
1574     GraphConstructorOptions device_opts;
1575     // There are internal operations (e.g., send/recv) that we now allow.
1576     device_opts.allow_internal_ops = true;
1577     device_opts.expect_device_spec = true;
1578     TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
1579                                               device_graph.get()));
1580     outputs->emplace(partition.first, std::move(device_graph));
1581   }
1582 
1583   GraphOptimizationPassOptions optimization_options;
1584   optimization_options.session_options = &options_;
1585   optimization_options.flib_def = client_graph->flib_def.get();
1586   optimization_options.partition_graphs = outputs;
1587   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
1588       OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
1589 
1590   Status s;
1591   for (auto& partition : *outputs) {
1592     const string& partition_name = partition.first;
1593     std::unique_ptr<Graph>* graph = &partition.second;
1594 
1595     VLOG(2) << "Created " << DebugString(graph->get()) << " for "
1596             << partition_name;
1597 
1598     // Give the device an opportunity to rewrite its subgraph.
1599     Device* d;
1600     s = device_mgr_->LookupDevice(partition_name, &d);
1601     if (!s.ok()) break;
1602     s = d->MaybeRewriteGraph(graph);
1603     if (!s.ok()) {
1604       break;
1605     }
1606   }
1607   *flib_def = std::move(client_graph->flib_def);
1608   std::swap(*input_types, client_graph->feed_types);
1609   std::swap(*output_types, client_graph->fetch_types);
1610   return s;
1611 }
1612 
1613 ::tensorflow::Status DirectSession::ListDevices(
1614     std::vector<DeviceAttributes>* response) {
1615   response->clear();
1616   response->reserve(devices_.size());
1617   for (Device* d : devices_) {
1618     const DeviceAttributes& attrs = d->attributes();
1619     response->emplace_back(attrs);
1620   }
1621   return ::tensorflow::Status::OK();
1622 }
1623 
1624 ::tensorflow::Status DirectSession::Reset(
1625     const std::vector<string>& containers) {
1626   device_mgr_->ClearContainers(containers);
1627   return ::tensorflow::Status::OK();
1628 }
1629 
1630 ::tensorflow::Status DirectSession::Close() {
1631   cancellation_manager_->StartCancel();
1632   {
1633     mutex_lock l(closed_lock_);
1634     if (closed_) return ::tensorflow::Status::OK();
1635     closed_ = true;
1636   }
1637   if (factory_ != nullptr) factory_->Deregister(this);
1638   return ::tensorflow::Status::OK();
1639 }
1640 
1641 DirectSession::RunState::RunState(
1642     const std::vector<string>& pending_input_names,
1643     const std::vector<string>& pending_output_names, int64 step_id,
1644     const std::vector<Device*>* devices)
1645     : step_container(step_id, [devices, step_id](const string& name) {
1646         for (auto d : *devices) {
1647           if (!d->resource_manager()->Cleanup(name).ok()) {
1648             // Do nothing...
1649           }
1650           ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
1651           if (sam) sam->Cleanup(step_id);
1652         }
1653       }) {
1654   // Initially all the feeds and fetches are pending.
1655   for (auto& name : pending_input_names) {
1656     pending_inputs[name] = false;
1657   }
1658   for (auto& name : pending_output_names) {
1659     pending_outputs[name] = false;
1660   }
1661 }
1662 
1663 DirectSession::RunState::RunState(int64 step_id,
1664                                   const std::vector<Device*>* devices)
1665     : RunState({}, {}, step_id, devices) {}
1666 
1667 DirectSession::RunState::~RunState() {
1668   if (rendez != nullptr) {
1669     if (!executors_done.HasBeenNotified()) {
1670       rendez->StartAbort(errors::Cancelled("PRun cancellation"));
1671       executors_done.WaitForNotification();
1672     }
1673     rendez->Unref();
1674   }
1675 }
1676 
1677 bool DirectSession::RunState::PendingDone() const {
1678   for (const auto& it : pending_inputs) {
1679     if (!it.second) return false;
1680   }
1681   for (const auto& it : pending_outputs) {
1682     if (!it.second) return false;
1683   }
1684   return true;
1685 }
1686 
1687 void DirectSession::WaitForNotification(RunState* run_state,
1688                                         CancellationManager* cm,
1689                                         int64 timeout_in_ms) {
1690   const Status status =
1691       WaitForNotification(&run_state->executors_done, timeout_in_ms);
1692   if (!status.ok()) {
1693     {
1694       mutex_lock l(run_state->mu_);
1695       run_state->status.Update(status);
1696     }
1697     cm->StartCancel();
1698     // We must wait for the executors to complete, because they have borrowed
1699     // references to `cm` and other per-step state. After this notification, it
1700     // is safe to clean up the step.
1701     run_state->executors_done.WaitForNotification();
1702   }
1703 }
1704 
1705 ::tensorflow::Status DirectSession::WaitForNotification(
1706     Notification* notification, int64 timeout_in_ms) {
1707   if (timeout_in_ms > 0) {
1708     const int64 timeout_in_us = timeout_in_ms * 1000;
1709     const bool notified =
1710         WaitForNotificationWithTimeout(notification, timeout_in_us);
1711     if (!notified) {
1712       return Status(error::DEADLINE_EXCEEDED,
1713                     "Timed out waiting for notification");
1714     }
1715   } else {
1716     notification->WaitForNotification();
1717   }
1718   return Status::OK();
1719 }
1720 
1721 Status DirectSession::MakeCallable(const CallableOptions& callable_options,
1722                                    CallableHandle* out_handle) {
1723   TF_RETURN_IF_ERROR(CheckNotClosed());
1724   TF_RETURN_IF_ERROR(CheckGraphCreated("MakeCallable()"));
1725 
1726   std::unique_ptr<ExecutorsAndKeys> ek;
1727   std::unique_ptr<FunctionInfo> func_info;
1728   RunStateArgs run_state_args(callable_options.run_options().debug_options());
1729   TF_RETURN_IF_ERROR(
1730       CreateExecutors(callable_options, &ek, &func_info, &run_state_args));
1731   {
1732     mutex_lock l(callables_lock_);
1733     *out_handle = next_callable_handle_++;
1734     callables_[*out_handle] = {std::move(ek), std::move(func_info)};
1735   }
1736   return Status::OK();
1737 }
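// A minimal sketch of how a client typically drives the callable API
// implemented in MakeCallable() above and RunCallable()/ReleaseCallable()
// below. Tensor and node names are illustrative, and `session` stands for any
// Session pointer backed by this class:
//
//   CallableOptions opts;
//   opts.add_feed("x:0");
//   opts.add_fetch("y:0");
//   Session::CallableHandle handle;
//   TF_RETURN_IF_ERROR(session->MakeCallable(opts, &handle));
//   std::vector<Tensor> fetches;
//   TF_RETURN_IF_ERROR(session->RunCallable(handle, {x_tensor}, &fetches,
//                                           /*run_metadata=*/nullptr));
//   TF_RETURN_IF_ERROR(session->ReleaseCallable(handle));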
1738 
1739 class DirectSession::RunCallableCallFrame : public CallFrameInterface {
1740  public:
1741   RunCallableCallFrame(DirectSession* session,
1742                        ExecutorsAndKeys* executors_and_keys,
1743                        const std::vector<Tensor>* feed_tensors,
1744                        std::vector<Tensor>* fetch_tensors)
1745       : session_(session),
1746         executors_and_keys_(executors_and_keys),
1747         feed_tensors_(feed_tensors),
1748         fetch_tensors_(fetch_tensors) {}
1749 
1750   size_t num_args() const override {
1751     return executors_and_keys_->input_types.size();
1752   }
1753   size_t num_retvals() const override {
1754     return executors_and_keys_->output_types.size();
1755   }
1756 
1757   Status GetArg(int index, Tensor* val) const override {
1758     if (index >= feed_tensors_->size()) {
1759       return errors::Internal("Args index out of bounds: ", index);
1760     } else if (executors_and_keys_->input_types[index] == DT_RESOURCE) {
1761       TF_RETURN_IF_ERROR(
1762           session_->ResourceHandleToInputTensor((*feed_tensors_)[index], val));
1763     } else {
1764       *val = (*feed_tensors_)[index];
1765     }
1766     return Status::OK();
1767   }
1768 
1769   Status SetRetval(int index, const Tensor& val) override {
1770     if (index >= fetch_tensors_->size()) {
1771       return errors::Internal("RetVal index out of bounds: ", index);
1772     }
1773     (*fetch_tensors_)[index] = val;
1774     return Status::OK();
1775   }
1776 
1777  private:
1778   DirectSession* const session_;                   // Not owned.
1779   ExecutorsAndKeys* const executors_and_keys_;     // Not owned.
1780   const std::vector<Tensor>* const feed_tensors_;  // Not owned.
1781   std::vector<Tensor>* const fetch_tensors_;       // Not owned.
1782 };
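// RunInternal() hands this frame to the executors, whose `_Arg` and `_Retval`
// kernels call GetArg()/SetRetval() by positional index, so feeds and fetches
// move between the caller's vectors and the step without any name-based
// lookup at run time.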
1783 
1784 ::tensorflow::Status DirectSession::RunCallable(
1785     CallableHandle handle, const std::vector<Tensor>& feed_tensors,
1786     std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata) {
1787   TF_RETURN_IF_ERROR(CheckNotClosed());
1788   TF_RETURN_IF_ERROR(CheckGraphCreated("RunCallable()"));
1789   direct_session_runs->GetCell()->IncrementBy(1);
1790 
1791   // Look up the executors that were created for this callable handle.
1792   std::shared_ptr<ExecutorsAndKeys> executors_and_keys;
1793   const int64 step_id = step_id_counter_.fetch_add(1);
1794 
1795   {
1796     tf_shared_lock l(callables_lock_);
1797     if (handle >= next_callable_handle_) {
1798       return errors::InvalidArgument("No such callable handle: ", handle);
1799     }
1800     executors_and_keys = callables_[handle].executors_and_keys;
1801   }
1802 
1803   if (!executors_and_keys) {
1804     return errors::InvalidArgument(
1805         "Attempted to run callable after handle was released: ", handle);
1806   }
1807 
1808   // NOTE(mrry): Debug options are not currently supported in the
1809   // callable interface.
1810   DebugOptions debug_options;
1811   RunStateArgs run_state_args(debug_options);
1812 
1813   // Configure a call frame for the step, which we use to feed and
1814   // fetch values to and from the executors.
1815   if (feed_tensors.size() != executors_and_keys->input_types.size()) {
1816     return errors::InvalidArgument(
1817         "Expected ", executors_and_keys->input_types.size(),
1818         " feed tensors, but got ", feed_tensors.size());
1819   }
1820   if (fetch_tensors != nullptr) {
1821     fetch_tensors->resize(executors_and_keys->output_types.size());
1822   } else if (!executors_and_keys->output_types.empty()) {
1823     return errors::InvalidArgument(
1824         "`fetch_tensors` must be provided when the callable has one or more "
1825         "outputs.");
1826   }
1827 
1828   // A specialized CallFrame implementation that takes advantage of the
1829   // optimized RunCallable interface.
1830 
1831   RunCallableCallFrame call_frame(this, executors_and_keys.get(), &feed_tensors,
1832                                   fetch_tensors);
1833 
1834   if (LogMemory::IsEnabled()) {
1835     LogMemory::RecordStep(step_id, run_state_args.handle);
1836   }
1837 
1838   TF_RETURN_IF_ERROR(
1839       RunInternal(step_id, executors_and_keys->callable_options.run_options(),
1840                   &call_frame, executors_and_keys.get(), run_metadata));
1841 
1842   return Status::OK();
1843 }
1844 
1845 ::tensorflow::Status DirectSession::ReleaseCallable(CallableHandle handle) {
1846   mutex_lock l(callables_lock_);
1847   if (handle >= next_callable_handle_) {
1848     return errors::InvalidArgument("No such callable handle: ", handle);
1849   }
1850   callables_.erase(handle);
1851   return Status::OK();
1852 }
1853 
1854 DirectSession::Callable::~Callable() {
1855   // We must delete the fields in this order, because the destructor
1856   // of `executors_and_keys` will call into an object owned by
1857   // `function_info` (in particular, when deleting a kernel, it relies
1858   // on the `FunctionLibraryRuntime` to know if the kernel is stateful
1859   // or not).
1860   executors_and_keys.reset();
1861   function_info.reset();
1862 }
1863 
1864 }  // namespace tensorflow
1865