1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/direct_session.h"
17 
18 #include <algorithm>
19 #include <atomic>
20 #include <string>
21 #include <vector>
22 
23 #include "absl/container/flat_hash_set.h"
24 #include "tensorflow/core/common_runtime/collective_executor_mgr.h"
25 #include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
26 #include "tensorflow/core/common_runtime/constant_folding.h"
27 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
28 #include "tensorflow/core/common_runtime/device_factory.h"
29 #include "tensorflow/core/common_runtime/device_resolver_local.h"
30 #include "tensorflow/core/common_runtime/executor.h"
31 #include "tensorflow/core/common_runtime/executor_factory.h"
32 #include "tensorflow/core/common_runtime/function.h"
33 #include "tensorflow/core/common_runtime/graph_constructor.h"
34 #include "tensorflow/core/common_runtime/graph_optimizer.h"
35 #include "tensorflow/core/common_runtime/memory_types.h"
36 #include "tensorflow/core/common_runtime/metrics.h"
37 #include "tensorflow/core/common_runtime/optimization_registry.h"
38 #include "tensorflow/core/common_runtime/process_util.h"
39 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
40 #include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
41 #include "tensorflow/core/common_runtime/step_stats_collector.h"
42 #include "tensorflow/core/framework/function.h"
43 #include "tensorflow/core/framework/graph.pb.h"
44 #include "tensorflow/core/framework/graph_def_util.h"
45 #include "tensorflow/core/framework/log_memory.h"
46 #include "tensorflow/core/framework/logging.h"
47 #include "tensorflow/core/framework/node_def.pb.h"
48 #include "tensorflow/core/framework/run_handler.h"
49 #include "tensorflow/core/framework/tensor.h"
50 #include "tensorflow/core/framework/versions.pb.h"
51 #include "tensorflow/core/graph/algorithm.h"
52 #include "tensorflow/core/graph/graph.h"
53 #include "tensorflow/core/graph/graph_partition.h"
54 #include "tensorflow/core/graph/subgraph.h"
55 #include "tensorflow/core/graph/tensor_id.h"
56 #include "tensorflow/core/lib/core/errors.h"
57 #include "tensorflow/core/lib/core/notification.h"
58 #include "tensorflow/core/lib/core/refcount.h"
59 #include "tensorflow/core/lib/core/status.h"
60 #include "tensorflow/core/lib/core/threadpool.h"
61 #include "tensorflow/core/lib/core/threadpool_options.h"
62 #include "tensorflow/core/lib/gtl/array_slice.h"
63 #include "tensorflow/core/lib/monitoring/counter.h"
64 #include "tensorflow/core/lib/random/random.h"
65 #include "tensorflow/core/lib/strings/numbers.h"
66 #include "tensorflow/core/lib/strings/str_util.h"
67 #include "tensorflow/core/lib/strings/strcat.h"
68 #include "tensorflow/core/nccl/collective_communicator.h"
69 #include "tensorflow/core/platform/byte_order.h"
70 #include "tensorflow/core/platform/cpu_info.h"
71 #include "tensorflow/core/platform/logging.h"
72 #include "tensorflow/core/platform/mutex.h"
73 #include "tensorflow/core/platform/tracing.h"
74 #include "tensorflow/core/platform/types.h"
75 #include "tensorflow/core/profiler/lib/connected_traceme.h"
76 #include "tensorflow/core/profiler/lib/device_profiler_session.h"
77 #include "tensorflow/core/profiler/lib/traceme_encode.h"
78 #include "tensorflow/core/protobuf/config.pb.h"
79 #include "tensorflow/core/util/device_name_utils.h"
80 #include "tensorflow/core/util/env_var.h"
81 
82 namespace tensorflow {
83 
84 namespace {
85 
86 auto* direct_session_runs = monitoring::Counter<0>::New(
87     "/tensorflow/core/direct_session_runs",
88     "The number of times DirectSession::Run() has been called.");
89 
90 Status NewThreadPoolFromThreadPoolOptions(
91     const SessionOptions& options,
92     const ThreadPoolOptionProto& thread_pool_options, int pool_number,
93     thread::ThreadPool** pool, bool* owned) {
94   int32_t num_threads = thread_pool_options.num_threads();
95   if (num_threads == 0) {
96     num_threads = NumInterOpThreadsFromSessionOptions(options);
97   }
98   const string& name = thread_pool_options.global_name();
99   if (name.empty()) {
100     // Session-local threadpool.
101     VLOG(1) << "Direct session inter op parallelism threads for pool "
102             << pool_number << ": " << num_threads;
103     *pool = new thread::ThreadPool(
104         options.env, ThreadOptions(), strings::StrCat("Compute", pool_number),
105         num_threads, !options.config.experimental().disable_thread_spinning(),
106         /*allocator=*/nullptr);
107     *owned = true;
108     return Status::OK();
109   }
110 
111   // Global, named threadpool.
112   typedef std::pair<int32, thread::ThreadPool*> MapValue;
113   static std::map<string, MapValue>* global_pool_map =
114       new std::map<string, MapValue>;
115   static mutex* mu = new mutex();
116   mutex_lock l(*mu);
117   MapValue* mvalue = &(*global_pool_map)[name];
118   if (mvalue->second == nullptr) {
119     mvalue->first = thread_pool_options.num_threads();
120     mvalue->second = new thread::ThreadPool(
121         options.env, ThreadOptions(), strings::StrCat("Compute", pool_number),
122         num_threads, !options.config.experimental().disable_thread_spinning(),
123         /*allocator=*/nullptr);
124   } else {
125     if (mvalue->first != thread_pool_options.num_threads()) {
126       return errors::InvalidArgument(
127           "Pool ", name,
128           " configured previously with num_threads=", mvalue->first,
129           "; cannot re-configure with num_threads=",
130           thread_pool_options.num_threads());
131     }
132   }
133   *owned = false;
134   *pool = mvalue->second;
135   return Status::OK();
136 }
137 
138 thread::ThreadPool* GlobalThreadPool(const SessionOptions& options) {
139   static thread::ThreadPool* const thread_pool =
140       NewThreadPoolFromSessionOptions(options);
141   return thread_pool;
142 }
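// Illustrative sketch (not part of the original source): how a client could
// request the named, process-global pool handled above via ConfigProto. The
// pool name "shared_inter_op" and the thread count are made-up examples.
//
//   SessionOptions options;
//   ThreadPoolOptionProto* pool_proto =
//       options.config.add_session_inter_op_thread_pool();
//   pool_proto->set_num_threads(8);
//   pool_proto->set_global_name("shared_inter_op");
//   // A later session reusing the same global_name must request the same
//   // num_threads, or NewThreadPoolFromThreadPoolOptions() above returns
//   // InvalidArgument.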
143 
144 // TODO(vrv): Figure out how to unify the many different functions
145 // that generate RendezvousKey, since many of them have to be
146 // consistent with each other.
147 string GetRendezvousKey(const string& tensor_name,
148                         const DeviceAttributes& device_info,
149                         const FrameAndIter& frame_iter) {
150   return strings::StrCat(device_info.name(), ";",
151                          strings::FpToString(device_info.incarnation()), ";",
152                          device_info.name(), ";", tensor_name, ";",
153                          frame_iter.frame_id, ":", frame_iter.iter_id);
154 }
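// For reference, a key produced by GetRendezvousKey() has the form
// "<device>;<incarnation-fingerprint>;<device>;<tensor_name>;<frame>:<iter>",
// e.g. (hypothetical values, emitted as a single line in practice):
//
//   /job:localhost/replica:0/task:0/device:CPU:0;0000000000000001;
//       /job:localhost/replica:0/task:0/device:CPU:0;feed_a:0;0:0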
155 
156 }  // namespace
157 
158 class DirectSessionFactory : public SessionFactory {
159  public:
160   DirectSessionFactory() {}
161 
162   bool AcceptsOptions(const SessionOptions& options) override {
163     return options.target.empty();
164   }
165 
166   Status NewSession(const SessionOptions& options,
167                     Session** out_session) override {
168     const auto& experimental_config = options.config.experimental();
169     if (experimental_config.has_session_metadata()) {
170       if (experimental_config.session_metadata().version() < 0) {
171         return errors::InvalidArgument(
172             "Session version shouldn't be negative: ",
173             experimental_config.session_metadata().DebugString());
174       }
175       const string key = GetMetadataKey(experimental_config.session_metadata());
176       mutex_lock l(sessions_lock_);
177       if (!session_metadata_keys_.insert(key).second) {
178         return errors::InvalidArgument(
179             "A session with the same name and version has already been "
180             "created: ",
181             experimental_config.session_metadata().DebugString());
182       }
183     }
184 
185     // Must do this before the CPU allocator is created.
186     if (options.config.graph_options().build_cost_model() > 0) {
187       EnableCPUAllocatorFullStats();
188     }
189     std::vector<std::unique_ptr<Device>> devices;
190     TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
191         options, "/job:localhost/replica:0/task:0", &devices));
192 
193     DirectSession* session = new DirectSession(
194         options, new StaticDeviceMgr(std::move(devices)), this);
195     {
196       mutex_lock l(sessions_lock_);
197       sessions_.push_back(session);
198     }
199     *out_session = session;
200     return Status::OK();
201   }
202 
203   Status Reset(const SessionOptions& options,
204                const std::vector<string>& containers) override {
205     std::vector<DirectSession*> sessions_to_reset;
206     {
207       mutex_lock l(sessions_lock_);
208       // We create a copy to ensure that we don't have a deadlock when
209       // session->Close calls the DirectSessionFactory.Deregister, which
210       // acquires sessions_lock_.
211       std::swap(sessions_to_reset, sessions_);
212     }
213     Status s;
214     for (auto session : sessions_to_reset) {
215       s.Update(session->Reset(containers));
216     }
217     // TODO(suharshs): Change the Reset behavior of all SessionFactories so that
218     // it doesn't close the sessions?
219     for (auto session : sessions_to_reset) {
220       s.Update(session->Close());
221     }
222     return s;
223   }
224 
225   void Deregister(const DirectSession* session) {
226     mutex_lock l(sessions_lock_);
227     sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session),
228                     sessions_.end());
229     if (session->options().config.experimental().has_session_metadata()) {
230       session_metadata_keys_.erase(GetMetadataKey(
231           session->options().config.experimental().session_metadata()));
232     }
233   }
234 
235  private:
236   static string GetMetadataKey(const SessionMetadata& metadata) {
237     return absl::StrCat(metadata.name(), "/", metadata.version());
238   }
239 
240   mutex sessions_lock_;
241   std::vector<DirectSession*> sessions_ TF_GUARDED_BY(sessions_lock_);
242   absl::flat_hash_set<string> session_metadata_keys_
243       TF_GUARDED_BY(sessions_lock_);
244 };
245 
246 class DirectSessionRegistrar {
247  public:
248   DirectSessionRegistrar() {
249     SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
250   }
251 };
252 static DirectSessionRegistrar registrar;
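// The static registrar above runs at program start, so linking this file makes
// DirectSession available. Minimal client-side sketch, relying only on the
// rule in AcceptsOptions() that an empty target selects the direct session:
//
//   SessionOptions options;                       // options.target is empty.
//   std::unique_ptr<Session> session(NewSession(options));
//   // `session` is a DirectSession created through DirectSessionFactory.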
253 
254 std::atomic_int_fast64_t DirectSession::step_id_counter_(1);
255 
256 static RunHandlerPool* GetOrCreateRunHandlerPool(
257     const SessionOptions& options) {
258   int num_inter_threads = 0;
259   int num_intra_threads = 0;
260   static const int env_num_inter_threads = NumInterOpThreadsFromEnvironment();
261   static const int env_num_intra_threads = NumIntraOpThreadsFromEnvironment();
262   if (env_num_inter_threads > 0) {
263     num_inter_threads = env_num_inter_threads;
264   }
265   if (env_num_intra_threads > 0) {
266     num_intra_threads = env_num_intra_threads;
267   }
268 
269   if (num_inter_threads == 0) {
270     if (options.config.session_inter_op_thread_pool_size() > 0) {
271       // Note: because of ShouldUseRunHandlerPool() we are guaranteed that
272       // run_options.inter_op_thread_pool() == 0.
273       num_inter_threads =
274           options.config.session_inter_op_thread_pool(0).num_threads();
275     }
276     if (num_inter_threads == 0) {
277       num_inter_threads = NumInterOpThreadsFromSessionOptions(options);
278     }
279   }
280 
281   if (num_intra_threads == 0) {
282     num_intra_threads = options.config.intra_op_parallelism_threads();
283     if (num_intra_threads == 0) {
284       num_intra_threads = port::MaxParallelism();
285     }
286   }
287 
288   static RunHandlerPool* pool =
289       new RunHandlerPool(num_inter_threads, num_intra_threads);
290   return pool;
291 }
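// Resolution order for the pool sizes above: the TF_NUM_INTEROP_THREADS and
// TF_NUM_INTRAOP_THREADS environment variables win when set to a positive
// value; otherwise the first session_inter_op_thread_pool entry (inter-op) and
// intra_op_parallelism_threads (intra-op) from ConfigProto are used; otherwise
// the session-options default and port::MaxParallelism() apply. The pool is
// created once per process and reused by every session that opts in.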
292 
293 bool DirectSession::ShouldUseRunHandlerPool(
294     const RunOptions& run_options) const {
295   if (options_.config.use_per_session_threads()) return false;
296   if (options_.config.session_inter_op_thread_pool_size() > 0 &&
297       run_options.inter_op_thread_pool() > 0)
298     return false;
299   // Only use RunHandlerPool when:
300   // a. Single global thread pool is used for inter-op parallelism.
301   // b. When multiple inter_op_thread_pool(s) are created, use it only while
302   // running sessions on the default inter_op_thread_pool=0. Typically,
303   // servo-team uses inter_op_thread_pool > 0 for model loading.
304   // TODO(crk): Revisit whether we'd want to create one (static) RunHandlerPool
305   // per entry in session_inter_op_thread_pool() in the future.
306   return true;
307 }
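// Sketch of how a caller might opt into this path (the remaining conditions in
// ShouldUseRunHandlerPool() above must also hold):
//
//   RunOptions run_options;
//   run_options.mutable_experimental()->set_use_run_handler_pool(true);
//   // When passed to Session::Run(), inter-op closures for the step are then
//   // scheduled through GetOrCreateRunHandlerPool() instead of the session's
//   // own thread pool.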
308 
309 DirectSession::DirectSession(const SessionOptions& options,
310                              const DeviceMgr* device_mgr,
311                              DirectSessionFactory* const factory)
312     : options_(options),
313       device_mgr_(device_mgr),
314       factory_(factory),
315       cancellation_manager_(new CancellationManager()),
316       operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
317   const int thread_pool_size =
318       options_.config.session_inter_op_thread_pool_size();
319   if (thread_pool_size > 0) {
320     for (int i = 0; i < thread_pool_size; ++i) {
321       thread::ThreadPool* pool = nullptr;
322       bool owned = false;
323       init_error_.Update(NewThreadPoolFromThreadPoolOptions(
324           options_, options_.config.session_inter_op_thread_pool(i), i, &pool,
325           &owned));
326       thread_pools_.emplace_back(pool, owned);
327     }
328   } else if (options_.config.use_per_session_threads()) {
329     thread_pools_.emplace_back(NewThreadPoolFromSessionOptions(options_),
330                                true /* owned */);
331   } else {
332     thread_pools_.emplace_back(GlobalThreadPool(options), false /* owned */);
333     // Run locally if environment value of TF_NUM_INTEROP_THREADS is negative
334     // and config.inter_op_parallelism_threads is unspecified or negative.
335     static const int env_num_threads = NumInterOpThreadsFromEnvironment();
336     if (options_.config.inter_op_parallelism_threads() < 0 ||
337         (options_.config.inter_op_parallelism_threads() == 0 &&
338          env_num_threads < 0)) {
339       run_in_caller_thread_ = true;
340     }
341   }
342   // The default value of sync_on_finish will be flipped soon and this
343   // environment variable will be removed as well.
344   const Status status =
345       ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
346   if (!status.ok()) {
347     LOG(ERROR) << status.error_message();
348   }
349   session_handle_ =
350       strings::StrCat("direct", strings::FpToString(random::New64()));
351   int devices_added = 0;
352   if (options.config.log_device_placement()) {
353     const string mapping_str = device_mgr_->DeviceMappingString();
354     string msg;
355     if (mapping_str.empty()) {
356       msg = "Device mapping: no known devices.";
357     } else {
358       msg = strings::StrCat("Device mapping:\n", mapping_str);
359     }
360     if (!logging::LogToListeners(msg)) {
361       LOG(INFO) << msg;
362     }
363   }
364   for (auto d : device_mgr_->ListDevices()) {
365     devices_.push_back(d);
366     device_set_.AddDevice(d);
367     d->op_segment()->AddHold(session_handle_);
368 
369     // The first device added is special: it is the 'client device' (a
370     // CPU device) from which we feed and fetch Tensors.
371     if (devices_added == 0) {
372       device_set_.set_client_device(d);
373     }
374     ++devices_added;
375   }
376 }
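// The constructor above picks one of three inter-op threading modes. An
// illustrative ConfigProto setting for each branch (alternatives, not meant to
// be combined; thread counts are arbitrary):
//
//   // 1. Explicit pool list: one ThreadPool per ThreadPoolOptionProto entry.
//   options.config.add_session_inter_op_thread_pool()->set_num_threads(4);
//   // 2. A private pool owned by this session only.
//   options.config.set_use_per_session_threads(true);
//   // 3. Default: the shared GlobalThreadPool(); a negative value here (or
//   //    TF_NUM_INTEROP_THREADS < 0) additionally enables run_in_caller_thread_.
//   options.config.set_inter_op_parallelism_threads(-1);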
377 
378 DirectSession::~DirectSession() {
379   if (!closed_) Close().IgnoreError();
380   for (auto& it : partial_runs_) {
381     it.second.reset(nullptr);
382   }
383   for (auto& it : executors_) {
384     it.second.reset();
385   }
386   callables_.clear();
387   for (auto d : device_mgr_->ListDevices()) {
388     d->op_segment()->RemoveHold(session_handle_);
389   }
390   functions_.clear();
391   delete cancellation_manager_;
392   for (const auto& p_and_owned : thread_pools_) {
393     if (p_and_owned.second) delete p_and_owned.first;
394   }
395 
396   execution_state_.reset(nullptr);
397   flib_def_.reset(nullptr);
398 }
399 
400 Status DirectSession::Create(const GraphDef& graph) {
401   return Create(GraphDef(graph));
402 }
403 
404 Status DirectSession::Create(GraphDef&& graph) {
405   TF_RETURN_IF_ERROR(init_error_);
406   if (graph.node_size() > 0) {
407     mutex_lock l(graph_state_lock_);
408     if (graph_created_) {
409       return errors::AlreadyExists(
410           "A Graph has already been created for this session.");
411     }
412     return ExtendLocked(std::move(graph));
413   }
414   return Status::OK();
415 }
416 
417 Status DirectSession::Extend(const GraphDef& graph) {
418   return Extend(GraphDef(graph));
419 }
420 
421 Status DirectSession::Extend(GraphDef&& graph) {
422   TF_RETURN_IF_ERROR(CheckNotClosed());
423   mutex_lock l(graph_state_lock_);
424   return ExtendLocked(std::move(graph));
425 }
426 
427 Status DirectSession::ExtendLocked(GraphDef graph) {
428   if (finalized_) {
429     return errors::FailedPrecondition("Session has been finalized.");
430   }
431   if (!(flib_def_ && execution_state_)) {
432     // If this is the first call, we can initialize the execution state
433     // with `graph` and do not need to call `Extend()`.
434     // NOTE(mrry): The function library created here will be used for
435     // all subsequent extensions of the graph.
436     flib_def_.reset(
437         new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
438     GraphExecutionStateOptions options;
439     options.device_set = &device_set_;
440     options.session_options = &options_;
441     options.session_handle = session_handle_;
442     TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
443         std::move(graph), options, &execution_state_));
444     graph_created_ = true;
445   } else {
446     TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));
447     std::unique_ptr<GraphExecutionState> state;
448     // TODO(mrry): Rewrite GraphExecutionState::Extend() to take `graph` by
449     // value and move `graph` in here.
450     TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
451     execution_state_.swap(state);
452   }
453   return Status::OK();
454 }
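// Usage sketch for the Create()/Extend() pair above (the GraphDef contents are
// hypothetical). Create() may only succeed once per session; Extend() then
// appends nodes to the existing graph under graph_state_lock_:
//
//   GraphDef base_graph, extra_nodes;   // assume these are populated
//   TF_RETURN_IF_ERROR(session->Create(base_graph));
//   TF_RETURN_IF_ERROR(session->Extend(extra_nodes));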
455 
456 Status DirectSession::Run(const NamedTensorList& inputs,
457                           const std::vector<string>& output_names,
458                           const std::vector<string>& target_nodes,
459                           std::vector<Tensor>* outputs) {
460   RunMetadata run_metadata;
461   return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
462              &run_metadata);
463 }
464 
465 Status DirectSession::CreateDebuggerState(
466     const CallableOptions& callable_options, int64_t global_step,
467     int64_t session_run_index, int64_t executor_step_index,
468     std::unique_ptr<DebuggerStateInterface>* debugger_state) {
469   TF_RETURN_IF_ERROR(DebuggerStateRegistry::CreateState(
470       callable_options.run_options().debug_options(), debugger_state));
471   std::vector<string> input_names(callable_options.feed().begin(),
472                                   callable_options.feed().end());
473   std::vector<string> output_names(callable_options.fetch().begin(),
474                                    callable_options.fetch().end());
475   std::vector<string> target_names(callable_options.target().begin(),
476                                    callable_options.target().end());
477 
478   TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
479       global_step, session_run_index, executor_step_index, input_names,
480       output_names, target_names));
481   return Status::OK();
482 }
483 
484 Status DirectSession::DecorateAndPublishGraphForDebug(
485     const DebugOptions& debug_options, Graph* graph, Device* device) {
486   std::unique_ptr<DebugGraphDecoratorInterface> decorator;
487   TF_RETURN_IF_ERROR(
488       DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
489 
490   TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
491   TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
492   return Status::OK();
493 }
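// Sketch of RunOptions that route a step through the debugger hooks above
// (the node name and dump URL are placeholders):
//
//   RunOptions run_options;
//   DebugTensorWatch* watch =
//       run_options.mutable_debug_options()->add_debug_tensor_watch_opts();
//   watch->set_node_name("dense/BiasAdd");
//   watch->add_debug_ops("DebugIdentity");
//   watch->add_debug_urls("file:///tmp/tfdbg_dump");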
494 
495 Status DirectSession::RunInternal(
496     int64_t step_id, const RunOptions& run_options,
497     CallFrameInterface* call_frame, ExecutorsAndKeys* executors_and_keys,
498     RunMetadata* run_metadata,
499     const thread::ThreadPoolOptions& threadpool_options) {
500   const uint64 start_time_usecs = options_.env->NowMicros();
501   const int64_t executor_step_count =
502       executors_and_keys->step_count.fetch_add(1);
503   RunState run_state(step_id, &devices_);
504   const size_t num_executors = executors_and_keys->items.size();
505 
506   profiler::TraceMeProducer activity(
507       // To TraceMeConsumers in ExecutorState::Process/Finish.
508       [&] {
509         if (options_.config.experimental().has_session_metadata()) {
510           const auto& model_metadata =
511               options_.config.experimental().session_metadata();
512           string model_id = strings::StrCat(model_metadata.name(), ":",
513                                             model_metadata.version());
514           return profiler::TraceMeEncode("SessionRun",
515                                          {{"id", step_id},
516                                           {"_r", 1} /*root_event*/,
517                                           {"model_id", model_id}});
518         } else {
519           return profiler::TraceMeEncode(
520               "SessionRun", {{"id", step_id}, {"_r", 1} /*root_event*/});
521         }
522       },
523       profiler::ContextType::kTfExecutor, step_id,
524       profiler::TraceMeLevel::kInfo);
525 
526   std::unique_ptr<DebuggerStateInterface> debugger_state;
527   if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
528     TF_RETURN_IF_ERROR(
529         CreateDebuggerState(executors_and_keys->callable_options,
530                             run_options.debug_options().global_step(), step_id,
531                             executor_step_count, &debugger_state));
532   }
533 
534 #ifndef __ANDROID__
535   // Set up for collectives if ExecutorsAndKeys declares a key.
536   if (executors_and_keys->collective_graph_key !=
537       BuildGraphOptions::kNoCollectiveGraphKey) {
538     if (run_options.experimental().collective_graph_key() !=
539         BuildGraphOptions::kNoCollectiveGraphKey) {
540       // If a collective_graph_key was specified in run_options, ensure that it
541       // matches what came out of GraphExecutionState::BuildGraph().
542       if (run_options.experimental().collective_graph_key() !=
543           executors_and_keys->collective_graph_key) {
544         return errors::Internal(
545             "collective_graph_key in RunOptions ",
546             run_options.experimental().collective_graph_key(),
547             " should match collective_graph_key from optimized graph ",
548             executors_and_keys->collective_graph_key);
549       }
550     }
551     if (!collective_executor_mgr_) {
552       collective_executor_mgr_ = CreateProdLocalCollectiveExecutorMgr(
553           options_.config, device_mgr_.get(),
554           MaybeCreateNcclCommunicator(options_.config));
555     }
556     run_state.collective_executor.reset(new CollectiveExecutor::Handle(
557         collective_executor_mgr_->FindOrCreate(step_id), true /*inherit_ref*/));
558   }
559 #endif
560 
561   thread::ThreadPool* pool;
562   // Use std::unique_ptr so the wrapper is automatically deleted when it goes out of scope.
563   std::unique_ptr<thread::ThreadPool> threadpool_wrapper;
564 
565   const bool inline_execution_requested =
566       run_in_caller_thread_ || run_options.inter_op_thread_pool() == -1;
567 
568   if (inline_execution_requested) {
569     // We allow using the caller thread only when a single executor is
570     // specified.
571     if (executors_and_keys->items.size() > 1) {
572       pool = thread_pools_[0].first;
573     } else {
574       VLOG(1) << "Executing Session::Run() synchronously!";
575       pool = nullptr;
576     }
577   } else if (threadpool_options.inter_op_threadpool != nullptr) {
578     threadpool_wrapper = absl::make_unique<thread::ThreadPool>(
579         threadpool_options.inter_op_threadpool);
580     pool = threadpool_wrapper.get();
581   } else {
582     if (run_options.inter_op_thread_pool() < -1 ||
583         run_options.inter_op_thread_pool() >=
584             static_cast<int32>(thread_pools_.size())) {
585       return errors::InvalidArgument("Invalid inter_op_thread_pool: ",
586                                      run_options.inter_op_thread_pool());
587     }
588 
589     pool = thread_pools_[run_options.inter_op_thread_pool()].first;
590   }
591 
592   const int64_t call_timeout = run_options.timeout_in_ms() > 0
593                                    ? run_options.timeout_in_ms()
594                                    : operation_timeout_in_ms_;
595 
596   std::unique_ptr<RunHandler> handler;
597   if (ShouldUseRunHandlerPool(run_options) &&
598       run_options.experimental().use_run_handler_pool()) {
599     VLOG(1) << "Using RunHandler to scheduler inter-op closures.";
600     handler = GetOrCreateRunHandlerPool(options_)->Get(
601         step_id, call_timeout,
602         run_options.experimental().run_handler_pool_options());
603     if (!handler) {
604       return errors::DeadlineExceeded(
605           "Could not obtain RunHandler for request after waiting for ",
606           call_timeout, "ms.");
607     }
608   }
609   auto* handler_ptr = handler.get();
610 
611   Executor::Args::Runner default_runner = nullptr;
612 
613   if (pool == nullptr) {
614     default_runner = [](const Executor::Args::Closure& c) { c(); };
615   } else if (handler_ptr != nullptr) {
616     default_runner = [handler_ptr](Executor::Args::Closure c) {
617       handler_ptr->ScheduleInterOpClosure(std::move(c));
618     };
619   } else {
620     default_runner = [pool](Executor::Args::Closure c) {
621       pool->Schedule(std::move(c));
622     };
623   }
624 
625   // Start parallel Executors.
626 
627   // We can execute this step synchronously on the calling thread whenever
628   // there is a single device and the timeout mechanism is not used.
629   //
630   // When timeouts are used, we must execute the graph(s) asynchronously, in
631   // order to invoke the cancellation manager on the calling thread if the
632   // timeout expires.
633   const bool can_execute_synchronously =
634       executors_and_keys->items.size() == 1 && call_timeout == 0;
635 
636   Executor::Args args;
637   args.step_id = step_id;
638   args.call_frame = call_frame;
639   args.collective_executor =
640       (run_state.collective_executor ? run_state.collective_executor->get()
641                                      : nullptr);
642   args.session_state = &session_state_;
643   args.session_handle = session_handle_;
644   args.tensor_store = &run_state.tensor_store;
645   args.step_container = &run_state.step_container;
646   args.sync_on_finish = sync_on_finish_;
647   args.user_intra_op_threadpool = threadpool_options.intra_op_threadpool;
648   args.run_all_kernels_inline = pool == nullptr;
649   args.start_time_usecs = start_time_usecs;
650 
651   const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);
652 
653   bool update_cost_model = false;
654   if (options_.config.graph_options().build_cost_model() > 0) {
655     const int64_t build_cost_model_every =
656         options_.config.graph_options().build_cost_model();
657     const int64_t build_cost_model_after =
658         options_.config.graph_options().build_cost_model_after();
659     int64_t measure_step_count = executor_step_count - build_cost_model_after;
660     if (measure_step_count >= 0) {
661       update_cost_model =
662           ((measure_step_count + 1) % build_cost_model_every == 0);
663     }
664   }
665   if (do_trace || update_cost_model ||
666       run_options.report_tensor_allocations_upon_oom()) {
667     run_state.collector.reset(
668         new StepStatsCollector(run_metadata->mutable_step_stats()));
669     args.stats_collector = run_state.collector.get();
670   }
671 
672   std::unique_ptr<DeviceProfilerSession> device_profiler_session;
673   if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
674     device_profiler_session = DeviceProfilerSession::Create();
675   }
676 
677   // Register this step with session's cancellation manager, so that
678   // `Session::Close()` will cancel the step.
679   CancellationManager step_cancellation_manager(cancellation_manager_);
680   if (step_cancellation_manager.IsCancelled()) {
681     return errors::Cancelled("Run call was cancelled");
682   }
683   args.cancellation_manager = &step_cancellation_manager;
684 
685   Status run_status;
686 
687   auto set_threadpool_args_for_item =
688       [&default_runner, &handler](const PerPartitionExecutorsAndLib& item,
689                                   Executor::Args* args) {
690         // TODO(azaks): support partial run.
691         // TODO(azaks): if the device picks its own threadpool, we need to
692         // assign fewer threads to the main compute pool
693         // by default.
694         thread::ThreadPool* device_thread_pool =
695             item.device->tensorflow_device_thread_pool();
696         // TODO(crk): Investigate usage of RunHandlerPool when using device
697         // specific thread pool(s).
698         if (!device_thread_pool) {
699           args->runner = default_runner;
700         } else {
701           args->runner = [device_thread_pool](Executor::Args::Closure c) {
702             device_thread_pool->Schedule(std::move(c));
703           };
704         }
705         if (handler != nullptr) {
706           args->user_intra_op_threadpool =
707               handler->AsIntraThreadPoolInterface();
708         }
709       };
710 
711   if (can_execute_synchronously) {
712     PrivateIntraProcessRendezvous rendezvous(device_mgr_.get());
713     args.rendezvous = &rendezvous;
714 
715     const auto& item = executors_and_keys->items[0];
716     set_threadpool_args_for_item(item, &args);
717     run_status = item.executor->Run(args);
718   } else {
719     core::RefCountPtr<RefCountedIntraProcessRendezvous> rendezvous(
720         new RefCountedIntraProcessRendezvous(device_mgr_.get()));
721     args.rendezvous = rendezvous.get();
722 
723     // `barrier` will delete itself after the final executor finishes.
724     Notification executors_done;
725     ExecutorBarrier* barrier =
726         new ExecutorBarrier(num_executors, rendezvous.get(),
727                             [&run_state, &executors_done](const Status& ret) {
728                               {
729                                 mutex_lock l(run_state.mu);
730                                 run_state.status.Update(ret);
731                               }
732                               executors_done.Notify();
733                             });
734 
735     for (const auto& item : executors_and_keys->items) {
736       set_threadpool_args_for_item(item, &args);
737       item.executor->RunAsync(args, barrier->Get());
738     }
739 
740     WaitForNotification(&executors_done, &run_state, &step_cancellation_manager,
741                         call_timeout);
742     {
743       tf_shared_lock l(run_state.mu);
744       run_status = run_state.status;
745     }
746   }
747 
748   if (step_cancellation_manager.IsCancelled()) {
749     run_status.Update(errors::Cancelled("Run call was cancelled"));
750   }
751 
752   if (device_profiler_session) {
753     TF_RETURN_IF_ERROR(device_profiler_session->CollectData(
754         run_metadata->mutable_step_stats()));
755   }
756 
757   TF_RETURN_IF_ERROR(run_status);
758 
759   // Save the output tensors of this run we choose to keep.
760   if (!run_state.tensor_store.empty()) {
761     TF_RETURN_IF_ERROR(run_state.tensor_store.SaveTensors(
762         {executors_and_keys->callable_options.fetch().begin(),
763          executors_and_keys->callable_options.fetch().end()},
764         &session_state_));
765   }
766 
767   if (run_state.collector) {
768     run_state.collector->Finalize();
769   }
770 
771   // Build and return the cost model as instructed.
772   if (update_cost_model) {
773     // Build the cost model
774     std::unordered_map<string, const Graph*> device_to_graph;
775     for (const PerPartitionExecutorsAndLib& partition :
776          executors_and_keys->items) {
777       const Graph* graph = partition.graph.get();
778       const string& device = partition.flib->device()->name();
779       device_to_graph[device] = graph;
780     }
781 
782     mutex_lock l(executor_lock_);
783     run_state.collector->BuildCostModel(&cost_model_manager_, device_to_graph);
784 
785     // Annotate stats onto the cost graph.
786     CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
787     for (const auto& item : executors_and_keys->items) {
788       TF_RETURN_IF_ERROR(
789           cost_model_manager_.AddToCostGraphDef(item.graph.get(), cost_graph));
790     }
791   }
792 
793   // If requested via RunOptions, output the partition graphs.
794   if (run_options.output_partition_graphs()) {
795     if (options_.config.experimental().disable_output_partition_graphs()) {
796       return errors::InvalidArgument(
797           "RunOptions.output_partition_graphs() is not supported when "
798           "disable_output_partition_graphs is true.");
799     } else {
800       protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
801           run_metadata->mutable_partition_graphs();
802       for (const PerPartitionExecutorsAndLib& exec_and_lib :
803            executors_and_keys->items) {
804         GraphDef* partition_graph_def = partition_graph_defs->Add();
805         exec_and_lib.graph->ToGraphDef(partition_graph_def);
806       }
807     }
808   }
809   metrics::UpdateGraphExecTime(options_.env->NowMicros() - start_time_usecs);
810 
811   return Status::OK();
812 }
813 
814 Status DirectSession::Run(const RunOptions& run_options,
815                           const NamedTensorList& inputs,
816                           const std::vector<string>& output_names,
817                           const std::vector<string>& target_nodes,
818                           std::vector<Tensor>* outputs,
819                           RunMetadata* run_metadata) {
820   return Run(run_options, inputs, output_names, target_nodes, outputs,
821              run_metadata, thread::ThreadPoolOptions());
822 }
823 
824 Status DirectSession::Run(const RunOptions& run_options,
825                           const NamedTensorList& inputs,
826                           const std::vector<string>& output_names,
827                           const std::vector<string>& target_nodes,
828                           std::vector<Tensor>* outputs,
829                           RunMetadata* run_metadata,
830                           const thread::ThreadPoolOptions& threadpool_options) {
831   TF_RETURN_IF_ERROR(CheckNotClosed());
832   TF_RETURN_IF_ERROR(CheckGraphCreated("Run()"));
833   direct_session_runs->GetCell()->IncrementBy(1);
834 
835   // Extract the inputs names for this run of the session.
836   std::vector<string> input_tensor_names;
837   input_tensor_names.reserve(inputs.size());
838   size_t input_size = 0;
839   for (const auto& it : inputs) {
840     input_tensor_names.push_back(it.first);
841     input_size += it.second.AllocatedBytes();
842   }
843   metrics::RecordGraphInputTensors(input_size);
844 
845   // Check if we already have an executor for these arguments.
846   ExecutorsAndKeys* executors_and_keys;
847   RunStateArgs run_state_args(run_options.debug_options());
848   run_state_args.collective_graph_key =
849       run_options.experimental().collective_graph_key();
850 
851   TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
852                                           target_nodes, &executors_and_keys,
853                                           &run_state_args));
854   {
855     mutex_lock l(collective_graph_key_lock_);
856     collective_graph_key_ = executors_and_keys->collective_graph_key;
857   }
858 
859   // Configure a call frame for the step, which we use to feed and
860   // fetch values to and from the executors.
861   FunctionCallFrame call_frame(executors_and_keys->input_types,
862                                executors_and_keys->output_types);
863   gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
864   for (const auto& it : inputs) {
865     if (it.second.dtype() == DT_RESOURCE) {
866       Tensor tensor_from_handle;
867       TF_RETURN_IF_ERROR(
868           ResourceHandleToInputTensor(it.second, &tensor_from_handle));
869       feed_args[executors_and_keys->input_name_to_index[it.first]] =
870           tensor_from_handle;
871     } else {
872       feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
873     }
874   }
875   const Status s = call_frame.SetArgs(feed_args);
876   if (errors::IsInternal(s)) {
877     return errors::InvalidArgument(s.error_message());
878   } else if (!s.ok()) {
879     return s;
880   }
881 
882   const int64_t step_id = step_id_counter_.fetch_add(1);
883 
884   if (LogMemory::IsEnabled()) {
885     LogMemory::RecordStep(step_id, run_state_args.handle);
886   }
887 
888   TF_RETURN_IF_ERROR(RunInternal(step_id, run_options, &call_frame,
889                                  executors_and_keys, run_metadata,
890                                  threadpool_options));
891 
892   // Receive outputs.
893   if (outputs) {
894     std::vector<Tensor> sorted_outputs;
895     const Status s = call_frame.ConsumeRetvals(
896         &sorted_outputs, /* allow_dead_tensors = */ false);
897     if (errors::IsInternal(s)) {
898       return errors::InvalidArgument(s.error_message());
899     } else if (!s.ok()) {
900       return s;
901     }
902     const bool unique_outputs =
903         output_names.size() == executors_and_keys->output_name_to_index.size();
904     // first_indices[i] = j implies that j is the smallest value for which
905     // output_names[i] == output_names[j].
906     std::vector<int> first_indices;
907     if (!unique_outputs) {
908       first_indices.reserve(output_names.size());
909       for (const auto& name : output_names) {
910         first_indices.push_back(
911             std::find(output_names.begin(), output_names.end(), name) -
912             output_names.begin());
913       }
914     }
915     outputs->clear();
916     size_t output_size = 0;
917     outputs->reserve(sorted_outputs.size());
918     for (int i = 0; i < output_names.size(); ++i) {
919       const string& output_name = output_names[i];
920       if (first_indices.empty() || first_indices[i] == i) {
921         outputs->emplace_back(
922             std::move(sorted_outputs[executors_and_keys
923                                          ->output_name_to_index[output_name]]));
924       } else {
925         outputs->push_back((*outputs)[first_indices[i]]);
926       }
927       output_size += outputs->back().AllocatedBytes();
928     }
929     metrics::RecordGraphOutputTensors(output_size);
930   }
931 
932   return Status::OK();
933 }
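// End-to-end sketch of the Run() overloads above, with a timeout and tracing
// enabled (feed/fetch names and the timeout value are placeholders). A
// non-zero timeout forces the asynchronous execution path in RunInternal():
//
//   RunOptions run_options;
//   run_options.set_timeout_in_ms(10000);
//   run_options.set_trace_level(RunOptions::FULL_TRACE);
//   RunMetadata run_metadata;
//   std::vector<Tensor> outputs;
//   Status s = session->Run(run_options, {{"input:0", input_tensor}},
//                           {"output:0"}, /*target_nodes=*/{}, &outputs,
//                           &run_metadata);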
934 
935 Status DirectSession::PRunSetup(const std::vector<string>& input_names,
936                                 const std::vector<string>& output_names,
937                                 const std::vector<string>& target_nodes,
938                                 string* handle) {
939   TF_RETURN_IF_ERROR(CheckNotClosed());
940   TF_RETURN_IF_ERROR(CheckGraphCreated("PRunSetup()"));
941 
942   // RunOptions is not available in PRunSetup, so use thread pool 0.
943   thread::ThreadPool* pool = thread_pools_[0].first;
944 
945   // Check if we already have an executor for these arguments.
946   ExecutorsAndKeys* executors_and_keys;
947   // TODO(cais): TFDBG support for partial runs.
948   DebugOptions debug_options;
949   RunStateArgs run_state_args(debug_options);
950   run_state_args.is_partial_run = true;
951   TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_names, output_names,
952                                           target_nodes, &executors_and_keys,
953                                           &run_state_args));
954 
955   // Create the run state and save it for future PRun calls.
956   Executor::Args args;
957   args.step_id = step_id_counter_.fetch_add(1);
958   PartialRunState* run_state =
959       new PartialRunState(input_names, output_names, args.step_id, &devices_);
960   run_state->rendez.reset(new IntraProcessRendezvous(device_mgr_.get()));
961   {
962     mutex_lock l(executor_lock_);
963     if (!partial_runs_
964              .emplace(run_state_args.handle,
965                       std::unique_ptr<PartialRunState>(run_state))
966              .second) {
967       return errors::Internal("The handle '", run_state_args.handle,
968                               "' created for this partial run is not unique.");
969     }
970   }
971 
972   // Start parallel Executors.
973   const size_t num_executors = executors_and_keys->items.size();
974   ExecutorBarrier* barrier = new ExecutorBarrier(
975       num_executors, run_state->rendez.get(), [run_state](const Status& ret) {
976         if (!ret.ok()) {
977           mutex_lock l(run_state->mu);
978           run_state->status.Update(ret);
979         }
980         run_state->executors_done.Notify();
981       });
982 
983   args.rendezvous = run_state->rendez.get();
984   args.cancellation_manager = cancellation_manager_;
985   // Note that Collectives are not supported in partial runs
986   // because RunOptions is not passed in so we can't know whether
987   // their use is intended.
988   args.collective_executor = nullptr;
989   args.runner = [this, pool](Executor::Args::Closure c) {
990     pool->Schedule(std::move(c));
991   };
992   args.session_state = &session_state_;
993   args.session_handle = session_handle_;
994   args.tensor_store = &run_state->tensor_store;
995   args.step_container = &run_state->step_container;
996   if (LogMemory::IsEnabled()) {
997     LogMemory::RecordStep(args.step_id, run_state_args.handle);
998   }
999   args.sync_on_finish = sync_on_finish_;
1000 
1001   if (options_.config.graph_options().build_cost_model()) {
1002     run_state->collector.reset(new StepStatsCollector(nullptr));
1003     args.stats_collector = run_state->collector.get();
1004   }
1005 
1006   for (auto& item : executors_and_keys->items) {
1007     item.executor->RunAsync(args, barrier->Get());
1008   }
1009 
1010   *handle = run_state_args.handle;
1011   return Status::OK();
1012 }
1013 
1014 Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
1015                            const std::vector<string>& output_names,
1016                            std::vector<Tensor>* outputs) {
1017   TF_RETURN_IF_ERROR(CheckNotClosed());
1018   std::vector<string> parts = str_util::Split(handle, ';');
1019   const string& key = parts[0];
1020   // Get the executors for this partial run.
1021   ExecutorsAndKeys* executors_and_keys;
1022   PartialRunState* run_state;
1023   {
1024     mutex_lock l(executor_lock_);  // could use reader lock
1025     auto exc_it = executors_.find(key);
1026     if (exc_it == executors_.end()) {
1027       return errors::InvalidArgument(
1028           "Must run 'setup' before performing partial runs!");
1029     }
1030     executors_and_keys = exc_it->second.get();
1031 
1032     auto prun_it = partial_runs_.find(handle);
1033     if (prun_it == partial_runs_.end()) {
1034       return errors::InvalidArgument(
1035           "Must run 'setup' before performing partial runs!");
1036     }
1037     run_state = prun_it->second.get();
1038 
1039     // Make sure that this is a new set of feeds that are still pending.
1040     for (const auto& input : inputs) {
1041       auto it = run_state->pending_inputs.find(input.first);
1042       if (it == run_state->pending_inputs.end()) {
1043         return errors::InvalidArgument(
1044             "The feed ", input.first,
1045             " was not specified in partial_run_setup.");
1046       } else if (it->second) {
1047         return errors::InvalidArgument("The feed ", input.first,
1048                                        " has already been fed.");
1049       }
1050     }
1051     // Check that this is a new set of fetches that are still pending.
1052     for (const auto& output : output_names) {
1053       auto it = run_state->pending_outputs.find(output);
1054       if (it == run_state->pending_outputs.end()) {
1055         return errors::InvalidArgument(
1056             "The fetch ", output, " was not specified in partial_run_setup.");
1057       } else if (it->second) {
1058         return errors::InvalidArgument("The fetch ", output,
1059                                        " has already been fetched.");
1060       }
1061     }
1062   }
1063 
1064   // Check that this new set of fetches can be computed from all the
1065   // feeds we have supplied.
1066   TF_RETURN_IF_ERROR(
1067       CheckFetch(inputs, output_names, executors_and_keys, run_state));
1068 
1069   // Send inputs.
1070   Status s =
1071       SendPRunInputs(inputs, executors_and_keys, run_state->rendez.get());
1072 
1073   // Receive outputs.
1074   if (s.ok()) {
1075     s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs);
1076   }
1077 
1078   // Save the output tensors of this run we choose to keep.
1079   if (s.ok()) {
1080     s = run_state->tensor_store.SaveTensors(output_names, &session_state_);
1081   }
1082 
1083   {
1084     mutex_lock l(executor_lock_);
1085     // Delete the run state if there is an error or all fetches are done.
1086     bool done = true;
1087     if (s.ok()) {
1088       {
1089         mutex_lock l(run_state->mu);
1090         if (!run_state->status.ok()) {
1091           LOG(WARNING) << "An error unrelated to this prun has been detected. "
1092                        << run_state->status;
1093         }
1094       }
1095       for (const auto& input : inputs) {
1096         auto it = run_state->pending_inputs.find(input.first);
1097         it->second = true;
1098       }
1099       for (const auto& name : output_names) {
1100         auto it = run_state->pending_outputs.find(name);
1101         it->second = true;
1102       }
1103       done = run_state->PendingDone();
1104     }
1105     if (done) {
1106       WaitForNotification(&run_state->executors_done, run_state,
1107                           cancellation_manager_, operation_timeout_in_ms_);
1108       partial_runs_.erase(handle);
1109     }
1110   }
1111 
1112   return s;
1113 }
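// Partial-run usage sketch for PRunSetup()/PRun() above. All feeds and fetches
// must be declared up front, and each may be fed or fetched exactly once
// across the subsequent PRun() calls (names are placeholders; this assumes
// "c:0" is computable from "a:0" alone):
//
//   string handle;
//   TF_RETURN_IF_ERROR(session->PRunSetup({"a:0", "b:0"}, {"c:0", "d:0"},
//                                         /*target_nodes=*/{}, &handle));
//   std::vector<Tensor> out_c, out_d;
//   TF_RETURN_IF_ERROR(session->PRun(handle, {{"a:0", tensor_a}}, {"c:0"}, &out_c));
//   TF_RETURN_IF_ERROR(session->PRun(handle, {{"b:0", tensor_b}}, {"d:0"}, &out_d));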
1114 
1115 Status DirectSession::ResourceHandleToInputTensor(const Tensor& resource_tensor,
1116                                                   Tensor* retrieved_tensor) {
1117   if (resource_tensor.dtype() != DT_RESOURCE) {
1118     return errors::InvalidArgument(strings::StrCat(
1119         "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ",
1120         resource_tensor.dtype()));
1121   }
1122 
1123   const ResourceHandle& resource_handle =
1124       resource_tensor.scalar<ResourceHandle>()();
1125 
1126   if (resource_handle.container() ==
1127       SessionState::kTensorHandleResourceTypeName) {
1128     return session_state_.GetTensor(resource_handle.name(), retrieved_tensor);
1129   } else {
1130     return errors::InvalidArgument(strings::StrCat(
1131         "Invalid resource type hash code: ", resource_handle.hash_code(),
1132         "(name: ", resource_handle.name(),
1133         " type: ", resource_handle.maybe_type_name(),
1134         "). Perhaps a resource tensor was being provided as a feed? That is "
1135         "not currently allowed. Please file an issue at "
1136         "https://github.com/tensorflow/tensorflow/issues/new, ideally with a "
1137         "short code snippet that leads to this error message."));
1138   }
1139 }
1140 
1141 Status DirectSession::SendPRunInputs(const NamedTensorList& inputs,
1142                                      const ExecutorsAndKeys* executors_and_keys,
1143                                      IntraProcessRendezvous* rendez) {
1144   Status s;
1145   Rendezvous::ParsedKey parsed;
1146   // Insert the input tensors into the local rendezvous by their
1147   // rendezvous key.
1148   for (const auto& input : inputs) {
1149     auto it =
1150         executors_and_keys->input_name_to_rendezvous_key.find(input.first);
1151     if (it == executors_and_keys->input_name_to_rendezvous_key.end()) {
1152       return errors::Internal("'", input.first, "' is not a pre-defined feed.");
1153     }
1154     const string& input_key = it->second;
1155 
1156     s = Rendezvous::ParseKey(input_key, &parsed);
1157     if (!s.ok()) {
1158       rendez->StartAbort(s);
1159       return s;
1160     }
1161 
1162     if (input.second.dtype() == DT_RESOURCE) {
1163       Tensor tensor_from_handle;
1164       s = ResourceHandleToInputTensor(input.second, &tensor_from_handle);
1165       if (s.ok()) {
1166         s = rendez->Send(parsed, Rendezvous::Args(), tensor_from_handle, false);
1167       }
1168     } else {
1169       s = rendez->Send(parsed, Rendezvous::Args(), input.second, false);
1170     }
1171 
1172     if (!s.ok()) {
1173       rendez->StartAbort(s);
1174       return s;
1175     }
1176   }
1177   return Status::OK();
1178 }
1179 
1180 Status DirectSession::RecvPRunOutputs(
1181     const std::vector<string>& output_names,
1182     const ExecutorsAndKeys* executors_and_keys, PartialRunState* run_state,
1183     std::vector<Tensor>* outputs) {
1184   Status s;
1185   if (!output_names.empty()) {
1186     outputs->resize(output_names.size());
1187   }
1188 
1189   Rendezvous::ParsedKey parsed;
1190   // Get the outputs from the rendezvous
1191   for (size_t output_offset = 0; output_offset < output_names.size();
1192        ++output_offset) {
1193     const string& output_name = output_names[output_offset];
1194     auto it =
1195         executors_and_keys->output_name_to_rendezvous_key.find(output_name);
1196     if (it == executors_and_keys->output_name_to_rendezvous_key.end()) {
1197       return errors::Internal("'", output_name,
1198                               "' is not a pre-defined fetch.");
1199     }
1200     const string& output_key = it->second;
1201     Tensor output_tensor;
1202     bool is_dead;
1203 
1204     s = Rendezvous::ParseKey(output_key, &parsed);
1205     if (s.ok()) {
1206       // Fetch data from the Rendezvous.
1207       s = run_state->rendez->Recv(parsed, Rendezvous::Args(), &output_tensor,
1208                                   &is_dead, operation_timeout_in_ms_);
1209       if (is_dead && s.ok()) {
1210         s = errors::InvalidArgument("The tensor returned for ", output_name,
1211                                     " was not valid.");
1212       }
1213     }
1214     if (!s.ok()) {
1215       run_state->rendez->StartAbort(s);
1216       outputs->clear();
1217       return s;
1218     }
1219 
1220     (*outputs)[output_offset] = output_tensor;
1221   }
1222   return Status::OK();
1223 }
1224 
1225 Status DirectSession::CheckFetch(const NamedTensorList& feeds,
1226                                  const std::vector<string>& fetches,
1227                                  const ExecutorsAndKeys* executors_and_keys,
1228                                  const PartialRunState* run_state) {
1229   const Graph* graph = executors_and_keys->graph.get();
1230   const NameNodeMap* name_to_node = &executors_and_keys->name_to_node;
1231 
1232   // Build the set of pending feeds that we haven't seen.
1233   std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
1234   {
1235     mutex_lock l(executor_lock_);
1236     for (const auto& input : run_state->pending_inputs) {
1237       // Skip if the feed has already been fed.
1238       if (input.second) continue;
1239       TensorId id(ParseTensorName(input.first));
1240       auto it = name_to_node->find(id.first);
1241       if (it == name_to_node->end()) {
1242         return errors::NotFound("Feed ", input.first, ": not found");
1243       }
1244       pending_feeds.insert(id);
1245     }
1246   }
1247   for (const auto& it : feeds) {
1248     TensorId id(ParseTensorName(it.first));
1249     pending_feeds.erase(id);
1250   }
1251 
1252   // Initialize the stack with the fetch nodes.
1253   std::vector<const Node*> stack;
1254   for (const string& fetch : fetches) {
1255     TensorId id(ParseTensorName(fetch));
1256     auto it = name_to_node->find(id.first);
1257     if (it == name_to_node->end()) {
1258       return errors::NotFound("Fetch ", fetch, ": not found");
1259     }
1260     stack.push_back(it->second);
1261   }
1262 
1263   // Any tensor needed for fetches can't be in pending_feeds.
1264   std::vector<bool> visited(graph->num_node_ids(), false);
1265   while (!stack.empty()) {
1266     const Node* n = stack.back();
1267     stack.pop_back();
1268 
1269     for (const Edge* in_edge : n->in_edges()) {
1270       const Node* in_node = in_edge->src();
1271       if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
1272         return errors::InvalidArgument("Fetch ", in_node->name(), ":",
1273                                        in_edge->src_output(),
1274                                        " can't be computed from the feeds"
1275                                        " that have been fed so far.");
1276       }
1277       if (!visited[in_node->id()]) {
1278         visited[in_node->id()] = true;
1279         stack.push_back(in_node);
1280       }
1281     }
1282   }
1283   return Status::OK();
1284 }
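// Illustrative scenario for the reachability check above (hypothetical
// names): PRunSetup was called with feeds {"a:0", "b:0"} and fetch "c:0",
// where c consumes both a and b. If only "a:0" has been fed so far, a PRun
// that requests "c:0" fails with InvalidArgument, because "b:0" is still in
// pending_feeds and lies on the input subgraph of the fetch.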
1285 
1286 Status DirectSession::CreateExecutors(
1287     const CallableOptions& callable_options,
1288     std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys,
1289     std::unique_ptr<FunctionInfo>* out_func_info,
1290     RunStateArgs* run_state_args) {
1291   BuildGraphOptions options;
1292   options.callable_options = callable_options;
1293   options.use_function_convention = !run_state_args->is_partial_run;
1294   options.collective_graph_key =
1295       callable_options.run_options().experimental().collective_graph_key();
1296   if (options_.config.experimental()
1297           .collective_deterministic_sequential_execution()) {
1298     options.collective_order = GraphCollectiveOrder::kEdges;
1299   } else if (options_.config.experimental().collective_nccl()) {
1300     options.collective_order = GraphCollectiveOrder::kAttrs;
1301   }
1302 
1303   std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
1304   std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
1305 
1306   ek->callable_options = callable_options;
1307 
1308   std::unordered_map<string, std::unique_ptr<Graph>> graphs;
1309   TF_RETURN_IF_ERROR(CreateGraphs(
1310       options, &graphs, &func_info->flib_def, run_state_args, &ek->input_types,
1311       &ek->output_types, &ek->collective_graph_key));
1312 
1313   if (run_state_args->is_partial_run) {
1314     ek->graph = std::move(run_state_args->graph);
1315     std::unordered_set<StringPiece, StringPieceHasher> names;
1316     for (const string& input : callable_options.feed()) {
1317       TensorId id(ParseTensorName(input));
1318       names.emplace(id.first);
1319     }
1320     for (const string& output : callable_options.fetch()) {
1321       TensorId id(ParseTensorName(output));
1322       names.emplace(id.first);
1323     }
1324     for (Node* n : ek->graph->nodes()) {
1325       if (names.count(n->name()) > 0) {
1326         ek->name_to_node.insert({n->name(), n});
1327       }
1328     }
1329   }
1330   ek->items.reserve(graphs.size());
1331   const auto& optimizer_opts =
1332       options_.config.graph_options().optimizer_options();
1333 
1334   int graph_def_version = graphs.begin()->second->versions().producer();
1335 
1336   const auto* session_metadata =
1337       options_.config.experimental().has_session_metadata()
1338           ? &options_.config.experimental().session_metadata()
1339           : nullptr;
1340   func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
1341       device_mgr_.get(), options_.env, &options_.config, graph_def_version,
1342       func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first,
1343       /*parent=*/nullptr, session_metadata,
1344       Rendezvous::Factory{
1345           [](const int64_t, const DeviceMgr* device_mgr, Rendezvous** r) {
1346             *r = new IntraProcessRendezvous(device_mgr);
1347             return Status::OK();
1348           }}));
1349 
1350   GraphOptimizer optimizer(optimizer_opts);
1351   for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
1352     const string& partition_name = iter->first;
1353     std::unique_ptr<Graph>& partition_graph = iter->second;
1354 
1355     Device* device;
1356     TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device));
1357 
1358     ek->items.resize(ek->items.size() + 1);
1359     auto* item = &(ek->items.back());
1360     auto lib = func_info->proc_flr->GetFLR(partition_name);
1361     if (lib == nullptr) {
1362       return errors::Internal("Could not find device: ", partition_name);
1363     }
1364     item->flib = lib;
1365 
1366     LocalExecutorParams params;
1367     params.device = device;
1368     params.session_metadata = session_metadata;
1369     params.function_library = lib;
1370     auto opseg = device->op_segment();
1371     params.create_kernel =
1372         [this, lib, opseg](const std::shared_ptr<const NodeProperties>& props,
1373                            OpKernel** kernel) {
1374           // NOTE(mrry): We must not share function kernels (implemented
1375           // using `CallOp`) between subgraphs, because `CallOp::handle_`
1376           // is tied to a particular subgraph. Even if the function itself
1377           // is stateful, the `CallOp` that invokes it is not.
1378           if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) {
1379             return lib->CreateKernel(props, kernel);
1380           }
1381           auto create_fn = [lib, &props](OpKernel** kernel) {
1382             return lib->CreateKernel(props, kernel);
1383           };
1384           // Kernels created for subgraph nodes need to be cached.  On
1385           // cache miss, create_fn() is invoked to create a kernel based
1386           // on the function library here + global op registry.
1387           return opseg->FindOrCreate(session_handle_, props->node_def.name(),
1388                                      kernel, create_fn);
1389         };
1390     params.delete_kernel = [lib](OpKernel* kernel) {
1391       if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string()))
1392         delete kernel;
1393     };
1394 
1395     optimizer.Optimize(lib, options_.env, device, &partition_graph,
1396                        /*shape_map=*/nullptr);
1397 
1398     // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
1399     const DebugOptions& debug_options =
1400         options.callable_options.run_options().debug_options();
1401     if (!debug_options.debug_tensor_watch_opts().empty()) {
1402       TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
1403           debug_options, partition_graph.get(), params.device));
1404     }
1405 
1406     TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
1407                                          device->name(),
1408                                          partition_graph.get()));
1409 
1410     item->executor = nullptr;
1411     item->device = device;
1412     auto executor_type = options_.config.experimental().executor_type();
1413     TF_RETURN_IF_ERROR(
1414         NewExecutor(executor_type, params, *partition_graph, &item->executor));
1415     if (!options_.config.experimental().disable_output_partition_graphs() ||
1416         options_.config.graph_options().build_cost_model() > 0) {
1417       item->graph = std::move(partition_graph);
1418     }
1419   }
1420 
1421   // Cache the mapping from input/output names to graph elements to
1422   // avoid recomputing it every time.
1423   if (!run_state_args->is_partial_run) {
1424     // For regular `Run()`, we use the function calling convention, and so
1425     // maintain a mapping from input/output names to
1426     // argument/return-value ordinal index.
1427     for (int i = 0; i < callable_options.feed().size(); ++i) {
1428       const string& input = callable_options.feed(i);
1429       ek->input_name_to_index[input] = i;
1430     }
1431     for (int i = 0; i < callable_options.fetch().size(); ++i) {
1432       const string& output = callable_options.fetch(i);
1433       ek->output_name_to_index[output] = i;
1434     }
1435   } else {
1436     // For `PRun()`, we use the rendezvous calling convention, and so
1437     // maintain a mapping from input/output names to rendezvous keys.
1438     //
1439     // We always use the first device as the device name portion of the
1440     // key, even if we're feeding another graph.
1441     for (int i = 0; i < callable_options.feed().size(); ++i) {
1442       const string& input = callable_options.feed(i);
1443       ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
1444           input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
1445     }
1446     for (int i = 0; i < callable_options.fetch().size(); ++i) {
1447       const string& output = callable_options.fetch(i);
1448       ek->output_name_to_rendezvous_key[output] =
1449           GetRendezvousKey(output, device_set_.client_device()->attributes(),
1450                            FrameAndIter(0, 0));
1451     }
1452   }
1453 
1454   *out_executors_and_keys = std::move(ek);
1455   *out_func_info = std::move(func_info);
1456   return Status::OK();
1457 }
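// For reference, a sketch of the CallableOptions that drive CreateExecutors()
// (the field accessors are from the CallableOptions proto; the endpoint names
// fed in are hypothetical):
//
//   CallableOptions opts;
//   opts.add_feed("x:0");
//   opts.add_fetch("softmax:0");
//   opts.add_target("init_all_vars");
//   opts.mutable_run_options()->mutable_experimental()
//       ->set_collective_graph_key(0);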
1458 
1459 Status DirectSession::GetOrCreateExecutors(
1460     gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
1461     gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys,
1462     RunStateArgs* run_state_args) {
1463   int64_t handle_name_counter_value = -1;
1464   if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
1465     handle_name_counter_value = handle_name_counter_.fetch_add(1);
1466   }
1467 
1468   string debug_tensor_watches_summary;
1469   if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
1470     debug_tensor_watches_summary = SummarizeDebugTensorWatches(
1471         run_state_args->debug_options.debug_tensor_watch_opts());
1472   }
1473 
1474   // Fast lookup path, no sorting.
1475   const string key = strings::StrCat(
1476       absl::StrJoin(inputs, ","), "->", absl::StrJoin(outputs, ","), "/",
1477       absl::StrJoin(target_nodes, ","), "/", run_state_args->is_partial_run,
1478       "/", debug_tensor_watches_summary);
1479   // Set the handle, if it's needed to log memory or for partial run.
1480   if (handle_name_counter_value >= 0) {
1481     run_state_args->handle =
1482         strings::StrCat(key, ";", handle_name_counter_value);
1483   }
1484 
1485   // See if we already have the executors for this run.
1486   {
1487     mutex_lock l(executor_lock_);  // could use reader lock
1488     auto it = executors_.find(key);
1489     if (it != executors_.end()) {
1490       *executors_and_keys = it->second.get();
1491       return Status::OK();
1492     }
1493   }
1494 
1495   // Slow lookup path, the unsorted key missed the cache.
1496   // Sort the inputs and outputs, and look up with the sorted key in case an
1497   // earlier call used a different order of inputs and outputs.
1498   //
1499   // We could consider some other signature instead of sorting that
1500   // preserves the same property to avoid the sort in the future.
1501   std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
1502   std::sort(inputs_sorted.begin(), inputs_sorted.end());
1503   std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
1504   std::sort(outputs_sorted.begin(), outputs_sorted.end());
1505   std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
1506   std::sort(tn_sorted.begin(), tn_sorted.end());
1507 
1508   const string sorted_key = strings::StrCat(
1509       absl::StrJoin(inputs_sorted, ","), "->",
1510       absl::StrJoin(outputs_sorted, ","), "/", absl::StrJoin(tn_sorted, ","),
1511       "/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
1512   // Set the handle, if it's needed to log memory or for partial run.
1513   if (handle_name_counter_value >= 0) {
1514     run_state_args->handle =
1515         strings::StrCat(sorted_key, ";", handle_name_counter_value);
1516   }
1517 
1518   // See if we already have the executors for this run.
1519   {
1520     mutex_lock l(executor_lock_);
1521     auto it = executors_.find(sorted_key);
1522     if (it != executors_.end()) {
1523       *executors_and_keys = it->second.get();
1524       return Status::OK();
1525     }
1526   }
1527 
1528   // Nothing found, so create the executors and store in the cache.
1529   // The executor_lock_ is intentionally released while executors are
1530   // being created.
1531   CallableOptions callable_options;
1532   callable_options.mutable_feed()->Reserve(inputs_sorted.size());
1533   for (const string& input : inputs_sorted) {
1534     callable_options.add_feed(input);
1535   }
1536   callable_options.mutable_fetch()->Reserve(outputs_sorted.size());
1537   for (const string& output : outputs_sorted) {
1538     callable_options.add_fetch(output);
1539   }
1540   callable_options.mutable_target()->Reserve(tn_sorted.size());
1541   for (const string& target : tn_sorted) {
1542     callable_options.add_target(target);
1543   }
1544   *callable_options.mutable_run_options()->mutable_debug_options() =
1545       run_state_args->debug_options;
1546   callable_options.mutable_run_options()
1547       ->mutable_experimental()
1548       ->set_collective_graph_key(run_state_args->collective_graph_key);
1549   std::unique_ptr<ExecutorsAndKeys> ek;
1550   std::unique_ptr<FunctionInfo> func_info;
1551   TF_RETURN_IF_ERROR(
1552       CreateExecutors(callable_options, &ek, &func_info, run_state_args));
1553 
1554   // Reacquire the lock, try to insert into the map.
1555   mutex_lock l(executor_lock_);
1556 
1557   // Another thread may have created the entry before us, in which case we will
1558   // reuse the already created one.
1559   auto insert_result = executors_.emplace(
1560       sorted_key, std::shared_ptr<ExecutorsAndKeys>(std::move(ek)));
1561   if (insert_result.second) {
1562     functions_.push_back(std::move(func_info));
1563   }
1564 
1565   // Insert the value under the original key, so the fast path lookup will work
1566   // if the user uses the same order of inputs, outputs, and targets again.
1567   executors_.emplace(key, insert_result.first->second);
1568   *executors_and_keys = insert_result.first->second.get();
1569 
1570   return Status::OK();
1571 }
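// Worked example of the cache keys built above (hypothetical names): inputs
// {"x:0"}, outputs {"y:0", "z:0"}, targets {"train_op"}, no partial run and
// no debug watches produce roughly
//
//   key        = "x:0->y:0,z:0/train_op/0/"
//   sorted_key = the same string with each comma-separated segment sorted
//
// so a later call that lists the same endpoints in a different order still
// hits the cache via sorted_key, and an identical call hits the fast path
// via key.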
1572 
1573 Status DirectSession::CreateGraphs(
1574     const BuildGraphOptions& subgraph_options,
1575     std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
1576     std::unique_ptr<FunctionLibraryDefinition>* flib_def,
1577     RunStateArgs* run_state_args, DataTypeVector* input_types,
1578     DataTypeVector* output_types, int64* collective_graph_key) {
1579   mutex_lock l(graph_state_lock_);
1580   if (finalized_) {
1581     return errors::FailedPrecondition("Session has been finalized.");
1582   }
1583 
1584   std::unique_ptr<ClientGraph> client_graph;
1585 
1586   std::unique_ptr<GraphExecutionState> temp_exec_state_holder;
1587   GraphExecutionState* execution_state = nullptr;
1588   if (options_.config.graph_options().place_pruned_graph()) {
1589     // Because we are placing pruned graphs, we need to create a
1590     // new GraphExecutionState for every new unseen graph,
1591     // and then place it.
1592     GraphExecutionStateOptions prune_options;
1593     prune_options.device_set = &device_set_;
1594     prune_options.session_options = &options_;
1595     prune_options.stateful_placements = stateful_placements_;
1596     prune_options.session_handle = session_handle_;
1597     TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph(
1598         *execution_state_, prune_options, subgraph_options,
1599         &temp_exec_state_holder, &client_graph));
1600     execution_state = temp_exec_state_holder.get();
1601   } else {
1602     execution_state = execution_state_.get();
1603     TF_RETURN_IF_ERROR(
1604         execution_state->BuildGraph(subgraph_options, &client_graph));
1605   }
1606   *collective_graph_key = client_graph->collective_graph_key;
1607 
1608   if (subgraph_options.callable_options.feed_size() !=
1609       client_graph->feed_types.size()) {
1610     return errors::Internal(
1611         "Graph pruning failed: requested number of feed endpoints = ",
1612         subgraph_options.callable_options.feed_size(),
1613         " versus number of pruned feed endpoints = ",
1614         client_graph->feed_types.size());
1615   }
1616   if (subgraph_options.callable_options.fetch_size() !=
1617       client_graph->fetch_types.size()) {
1618     return errors::Internal(
1619         "Graph pruning failed: requested number of fetch endpoints = ",
1620         subgraph_options.callable_options.fetch_size(),
1621         " versus number of pruned fetch endpoints = ",
1622         client_graph->fetch_types.size());
1623   }
1624 
1625   auto current_stateful_placements = execution_state->GetStatefulPlacements();
1626   // Update our current state based on the execution_state's
1627   // placements.  If there are any mismatches for a node,
1628   // we should fail, as this should never happen.
1629   for (const auto& placement_pair : current_stateful_placements) {
1630     const string& node_name = placement_pair.first;
1631     const string& placement = placement_pair.second;
1632     auto iter = stateful_placements_.find(node_name);
1633     if (iter == stateful_placements_.end()) {
1634       stateful_placements_.insert(std::make_pair(node_name, placement));
1635     } else if (iter->second != placement) {
1636       return errors::Internal(
1637           "Stateful placement mismatch. "
1638           "Current assignment of ",
1639           node_name, " to ", iter->second, " does not match ", placement);
1640     }
1641   }
1642 
1643   stateful_placements_ = execution_state->GetStatefulPlacements();
1644 
1645   // Remember the graph in run state if this is a partial run.
1646   if (run_state_args->is_partial_run) {
1647     run_state_args->graph.reset(new Graph(flib_def_.get()));
1648     CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
1649   }
1650 
1651   // Partition the graph across devices.
1652   PartitionOptions popts;
1653   popts.node_to_loc = [](const Node* node) {
1654     return node->assigned_device_name();
1655   };
1656   popts.new_name = [this](const string& prefix) {
1657     return strings::StrCat(prefix, "/_", edge_name_counter_.fetch_add(1));
1658   };
1659   popts.get_incarnation = [](const string& name) {
1660     // The direct session does not have changing incarnation numbers.
1661     // Just return '1'.
1662     return 1;
1663   };
1664   popts.flib_def = flib_def->get();
1665   popts.control_flow_added = false;
1666 
1667   std::unordered_map<string, GraphDef> partitions;
1668   TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions));
1669 
1670   std::vector<string> device_names;
1671   for (auto device : devices_) {
1672     // Extract the LocalName from the device.
1673     device_names.push_back(DeviceNameUtils::LocalName(device->name()));
1674   }
1675 
1676   // Check for valid partitions.
1677   for (const auto& partition : partitions) {
1678     const string local_partition_name =
1679         DeviceNameUtils::LocalName(partition.first);
1680     if (std::count(device_names.begin(), device_names.end(),
1681                    local_partition_name) == 0) {
1682       return errors::InvalidArgument(
1683           "Creating a partition for ", local_partition_name,
1684           " which doesn't exist in the list of available devices. Available "
1685           "devices: ",
1686           absl::StrJoin(device_names, ","));
1687     }
1688   }
1689 
1690   for (auto& partition : partitions) {
1691     std::unique_ptr<Graph> device_graph(
1692         new Graph(client_graph->flib_def.get()));
1693     device_graph->SetConstructionContext(ConstructionContext::kDirectSession);
1694     GraphConstructorOptions device_opts;
1695     // Partitioning adds internal operations (e.g., _Send/_Recv) that must be allowed here.
1696     device_opts.allow_internal_ops = true;
1697     device_opts.expect_device_spec = true;
1698     TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
1699         device_opts, std::move(partition.second), device_graph.get()));
1700     outputs->emplace(partition.first, std::move(device_graph));
1701   }
1702 
1703   GraphOptimizationPassOptions optimization_options;
1704   optimization_options.session_options = &options_;
1705   optimization_options.flib_def = client_graph->flib_def.get();
1706   optimization_options.partition_graphs = outputs;
1707   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
1708       OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
1709 
1710   Status s;
1711   for (auto& partition : *outputs) {
1712     const string& partition_name = partition.first;
1713     std::unique_ptr<Graph>* graph = &partition.second;
1714 
1715     VLOG(2) << "Created " << DebugString(graph->get()) << " for "
1716             << partition_name;
1717 
1718     // Give the device an opportunity to rewrite its subgraph.
1719     Device* d;
1720     s = device_mgr_->LookupDevice(partition_name, &d);
1721     if (!s.ok()) break;
1722     s = d->MaybeRewriteGraph(graph);
1723     if (!s.ok()) {
1724       break;
1725     }
1726   }
1727   *flib_def = std::move(client_graph->flib_def);
1728   std::swap(*input_types, client_graph->feed_types);
1729   std::swap(*output_types, client_graph->fetch_types);
1730   return s;
1731 }
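// For example (hypothetical placement): if the client graph has nodes
// assigned to ".../device:CPU:0" and ".../device:GPU:0", Partition() above
// returns two GraphDefs keyed by those device names, with matching
// _Send/_Recv pairs inserted on every cross-device edge; each partition is
// then converted back to a Graph, run through the post-partitioning passes,
// and handed to the per-device rewrite hook.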
1732 
1733 ::tensorflow::Status DirectSession::ListDevices(
1734     std::vector<DeviceAttributes>* response) {
1735   response->clear();
1736   response->reserve(devices_.size());
1737   for (Device* d : devices_) {
1738     const DeviceAttributes& attrs = d->attributes();
1739     response->emplace_back(attrs);
1740   }
1741   return ::tensorflow::Status::OK();
1742 }
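// Client-side usage sketch (assumes an already created `session`):
//
//   std::vector<DeviceAttributes> devices;
//   TF_CHECK_OK(session->ListDevices(&devices));
//   for (const DeviceAttributes& d : devices) {
//     LOG(INFO) << d.name() << " (" << d.device_type() << ")";
//   }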
1743 
1744 ::tensorflow::Status DirectSession::Reset(
1745     const std::vector<string>& containers) {
1746   device_mgr_->ClearContainers(containers);
1747   return ::tensorflow::Status::OK();
1748 }
1749 
1750 ::tensorflow::Status DirectSession::Close() {
1751   cancellation_manager_->StartCancel();
1752   {
1753     mutex_lock l(closed_lock_);
1754     if (closed_) return ::tensorflow::Status::OK();
1755     closed_ = true;
1756   }
1757   if (factory_ != nullptr) factory_->Deregister(this);
1758   return ::tensorflow::Status::OK();
1759 }
1760 
1761 DirectSession::RunState::RunState(int64_t step_id,
1762                                   const std::vector<Device*>* devices)
1763     : step_container(step_id, [devices, step_id](const string& name) {
1764         for (auto d : *devices) {
1765           if (!d->resource_manager()->Cleanup(name).ok()) {
1766             // Do nothing...
1767           }
1768           ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
1769           if (sam) sam->Cleanup(step_id);
1770         }
1771       }) {}
1772 
1773 DirectSession::PartialRunState::PartialRunState(
1774     const std::vector<string>& pending_input_names,
1775     const std::vector<string>& pending_output_names, int64_t step_id,
1776     const std::vector<Device*>* devices)
1777     : RunState(step_id, devices) {
1778   // Initially all the feeds and fetches are pending.
1779   for (auto& name : pending_input_names) {
1780     pending_inputs[name] = false;
1781   }
1782   for (auto& name : pending_output_names) {
1783     pending_outputs[name] = false;
1784   }
1785 }
1786 
1787 DirectSession::PartialRunState::~PartialRunState() {
1788   if (rendez != nullptr) {
1789     rendez->StartAbort(errors::Cancelled("PRun cancellation"));
1790     executors_done.WaitForNotification();
1791   }
1792 }
1793 
1794 bool DirectSession::PartialRunState::PendingDone() const {
1795   for (const auto& it : pending_inputs) {
1796     if (!it.second) return false;
1797   }
1798   for (const auto& it : pending_outputs) {
1799     if (!it.second) return false;
1800   }
1801   return true;
1802 }
1803 
1804 void DirectSession::WaitForNotification(Notification* n, RunState* run_state,
1805                                         CancellationManager* cm,
1806                                         int64_t timeout_in_ms) {
1807   const Status status = WaitForNotification(n, timeout_in_ms);
1808   if (!status.ok()) {
1809     {
1810       mutex_lock l(run_state->mu);
1811       run_state->status.Update(status);
1812     }
1813     cm->StartCancel();
1814     // We must wait for the executors to complete, because they have borrowed
1815     // references to `cm` and other per-step state. After this notification, it
1816     // is safe to clean up the step.
1817     n->WaitForNotification();
1818   }
1819 }
1820 
1821 ::tensorflow::Status DirectSession::WaitForNotification(
1822     Notification* notification, int64_t timeout_in_ms) {
1823   if (timeout_in_ms > 0) {
1824     const int64_t timeout_in_us = timeout_in_ms * 1000;
1825     const bool notified =
1826         WaitForNotificationWithTimeout(notification, timeout_in_us);
1827     if (!notified) {
1828       return Status(error::DEADLINE_EXCEEDED,
1829                     "Timed out waiting for notification");
1830     }
1831   } else {
1832     notification->WaitForNotification();
1833   }
1834   return Status::OK();
1835 }
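// Worked example: with a timeout of 5000 ms the wait above converts to
// 5,000,000 microseconds for WaitForNotificationWithTimeout and returns
// DEADLINE_EXCEEDED if the notification has not fired by then; a timeout of
// zero or a negative value blocks until the notification fires.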
1836 
1837 Status DirectSession::MakeCallable(const CallableOptions& callable_options,
1838                                    CallableHandle* out_handle) {
1839   TF_RETURN_IF_ERROR(CheckNotClosed());
1840   TF_RETURN_IF_ERROR(CheckGraphCreated("MakeCallable()"));
1841 
1842   std::unique_ptr<ExecutorsAndKeys> ek;
1843   std::unique_ptr<FunctionInfo> func_info;
1844   RunStateArgs run_state_args(callable_options.run_options().debug_options());
1845   TF_RETURN_IF_ERROR(
1846       CreateExecutors(callable_options, &ek, &func_info, &run_state_args));
1847   {
1848     mutex_lock l(callables_lock_);
1849     *out_handle = next_callable_handle_++;
1850     callables_[*out_handle] = {std::move(ek), std::move(func_info)};
1851   }
1852   return Status::OK();
1853 }
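// Usage sketch for the callable API implemented here and in RunCallable() /
// ReleaseCallable() below (feed/fetch names are hypothetical):
//
//   CallableOptions opts;
//   opts.add_feed("x:0");
//   opts.add_fetch("y:0");
//   Session::CallableHandle handle;
//   TF_CHECK_OK(session->MakeCallable(opts, &handle));
//   std::vector<Tensor> fetches;
//   TF_CHECK_OK(session->RunCallable(handle, {x_tensor}, &fetches,
//                                    /*run_metadata=*/nullptr));
//   TF_CHECK_OK(session->ReleaseCallable(handle));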
1854 
1855 class DirectSession::RunCallableCallFrame : public CallFrameInterface {
1856  public:
1857   RunCallableCallFrame(DirectSession* session,
1858                        ExecutorsAndKeys* executors_and_keys,
1859                        const std::vector<Tensor>* feed_tensors,
1860                        std::vector<Tensor>* fetch_tensors)
1861       : session_(session),
1862         executors_and_keys_(executors_and_keys),
1863         feed_tensors_(feed_tensors),
1864         fetch_tensors_(fetch_tensors) {}
1865 
1866   size_t num_args() const override {
1867     return executors_and_keys_->input_types.size();
1868   }
1869   size_t num_retvals() const override {
1870     return executors_and_keys_->output_types.size();
1871   }
1872 
1873   Status GetArg(int index, const Tensor** val) override {
1874     if (TF_PREDICT_FALSE(index >= feed_tensors_->size())) {
1875       return errors::Internal("Args index out of bounds: ", index);
1876     } else {
1877       *val = &(*feed_tensors_)[index];
1878     }
1879     return Status::OK();
1880   }
1881 
1882   Status SetRetval(int index, const Tensor& val) override {
1883     if (index >= fetch_tensors_->size()) {
1884       return errors::Internal("RetVal index out of bounds: ", index);
1885     }
1886     (*fetch_tensors_)[index] = val;
1887     return Status::OK();
1888   }
1889 
1890  private:
1891   DirectSession* const session_;                   // Not owned.
1892   ExecutorsAndKeys* const executors_and_keys_;     // Not owned.
1893   const std::vector<Tensor>* const feed_tensors_;  // Not owned.
1894   std::vector<Tensor>* const fetch_tensors_;       // Not owned.
1895 };
1896 
1897 ::tensorflow::Status DirectSession::RunCallable(
1898     CallableHandle handle, const std::vector<Tensor>& feed_tensors,
1899     std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata) {
1900   return RunCallable(handle, feed_tensors, fetch_tensors, run_metadata,
1901                      thread::ThreadPoolOptions());
1902 }
1903 
1904 ::tensorflow::Status DirectSession::RunCallable(
1905     CallableHandle handle, const std::vector<Tensor>& feed_tensors,
1906     std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata,
1907     const thread::ThreadPoolOptions& threadpool_options) {
1908   TF_RETURN_IF_ERROR(CheckNotClosed());
1909   TF_RETURN_IF_ERROR(CheckGraphCreated("RunCallable()"));
1910   direct_session_runs->GetCell()->IncrementBy(1);
1911 
1912   // Check if we already have an executor for these arguments.
1913   std::shared_ptr<ExecutorsAndKeys> executors_and_keys;
1914   const int64_t step_id = step_id_counter_.fetch_add(1);
1915 
1916   {
1917     tf_shared_lock l(callables_lock_);
1918     if (handle >= next_callable_handle_) {
1919       return errors::InvalidArgument("No such callable handle: ", handle);
1920     }
1921     executors_and_keys = callables_[handle].executors_and_keys;
1922   }
1923 
1924   if (!executors_and_keys) {
1925     return errors::InvalidArgument(
1926         "Attempted to run callable after handle was released: ", handle);
1927   }
1928 
1929   // NOTE(mrry): Debug options are not currently supported in the
1930   // callable interface.
1931   DebugOptions debug_options;
1932   RunStateArgs run_state_args(debug_options);
1933 
1934   // Configure a call frame for the step, which we use to feed and
1935   // fetch values to and from the executors.
1936   if (feed_tensors.size() != executors_and_keys->input_types.size()) {
1937     return errors::InvalidArgument(
1938         "Expected ", executors_and_keys->input_types.size(),
1939         " feed tensors, but got ", feed_tensors.size());
1940   }
1941   if (fetch_tensors != nullptr) {
1942     fetch_tensors->resize(executors_and_keys->output_types.size());
1943   } else if (!executors_and_keys->output_types.empty()) {
1944     return errors::InvalidArgument(
1945         "`fetch_tensors` must be provided when the callable has one or more "
1946         "outputs.");
1947   }
1948 
1949   size_t input_size = 0;
1950   bool any_resource_feeds = false;
1951   for (auto& tensor : feed_tensors) {
1952     input_size += tensor.AllocatedBytes();
1953     any_resource_feeds = any_resource_feeds || tensor.dtype() == DT_RESOURCE;
1954   }
1955   metrics::RecordGraphInputTensors(input_size);
1956 
1957   std::unique_ptr<std::vector<Tensor>> converted_feed_tensors;
1958   const std::vector<Tensor>* actual_feed_tensors;
1959 
1960   if (TF_PREDICT_FALSE(any_resource_feeds)) {
1961     converted_feed_tensors = absl::make_unique<std::vector<Tensor>>();
1962     converted_feed_tensors->reserve(feed_tensors.size());
1963     for (const Tensor& t : feed_tensors) {
1964       if (t.dtype() == DT_RESOURCE) {
1965         converted_feed_tensors->emplace_back();
1966         Tensor* tensor_from_handle = &converted_feed_tensors->back();
1967         TF_RETURN_IF_ERROR(ResourceHandleToInputTensor(t, tensor_from_handle));
1968       } else {
1969         converted_feed_tensors->emplace_back(t);
1970       }
1971     }
1972     actual_feed_tensors = converted_feed_tensors.get();
1973   } else {
1974     actual_feed_tensors = &feed_tensors;
1975   }
1976 
1977   // A specialized CallFrame implementation that takes advantage of the
1978   // optimized RunCallable interface.
1979   RunCallableCallFrame call_frame(this, executors_and_keys.get(),
1980                                   actual_feed_tensors, fetch_tensors);
1981 
1982   if (LogMemory::IsEnabled()) {
1983     LogMemory::RecordStep(step_id, run_state_args.handle);
1984   }
1985 
1986   TF_RETURN_IF_ERROR(RunInternal(
1987       step_id, executors_and_keys->callable_options.run_options(), &call_frame,
1988       executors_and_keys.get(), run_metadata, threadpool_options));
1989 
1990   if (fetch_tensors != nullptr) {
1991     size_t output_size = 0;
1992     for (auto& tensor : *fetch_tensors) {
1993       output_size += tensor.AllocatedBytes();
1994     }
1995     metrics::RecordGraphOutputTensors(output_size);
1996   }
1997 
1998   return Status::OK();
1999 }
2000 
2001 ::tensorflow::Status DirectSession::ReleaseCallable(CallableHandle handle) {
2002   mutex_lock l(callables_lock_);
2003   if (handle >= next_callable_handle_) {
2004     return errors::InvalidArgument("No such callable handle: ", handle);
2005   }
2006   callables_.erase(handle);
2007   return Status::OK();
2008 }
2009 
2010 Status DirectSession::Finalize() {
2011   mutex_lock l(graph_state_lock_);
2012   if (finalized_) {
2013     return errors::FailedPrecondition("Session already finalized.");
2014   }
2015   if (!graph_created_) {
2016     return errors::FailedPrecondition("Session not yet created.");
2017   }
2018   execution_state_.reset();
2019   flib_def_.reset();
2020   finalized_ = true;
2021   return Status::OK();
2022 }
2023 
2024 DirectSession::Callable::~Callable() {
2025   // We must delete the fields in this order, because the destructor
2026   // of `executors_and_keys` will call into an object owned by
2027   // `function_info` (in particular, when deleting a kernel, it relies
2028   // on the `FunctionLibraryRuntime` to know if the kernel is stateful
2029   // or not).
2030   executors_and_keys.reset();
2031   function_info.reset();
2032 }
2033 
2034 }  // namespace tensorflow
2035