1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/direct_session.h"
17 
18 #include <algorithm>
19 #include <atomic>
20 #include <string>
21 #include <vector>
22 
23 #include "absl/container/flat_hash_set.h"
24 #include "tensorflow/core/common_runtime/collective_executor_mgr.h"
25 #include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
26 #include "tensorflow/core/common_runtime/constant_folding.h"
27 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
28 #include "tensorflow/core/common_runtime/device_factory.h"
29 #include "tensorflow/core/common_runtime/device_resolver_local.h"
30 #include "tensorflow/core/common_runtime/executor.h"
31 #include "tensorflow/core/common_runtime/executor_factory.h"
32 #include "tensorflow/core/common_runtime/function.h"
33 #include "tensorflow/core/common_runtime/graph_constructor.h"
34 #include "tensorflow/core/common_runtime/graph_optimizer.h"
35 #include "tensorflow/core/common_runtime/memory_types.h"
36 #include "tensorflow/core/common_runtime/metrics.h"
37 #include "tensorflow/core/common_runtime/optimization_registry.h"
38 #include "tensorflow/core/common_runtime/process_util.h"
39 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
40 #include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
41 #include "tensorflow/core/common_runtime/step_stats_collector.h"
42 #include "tensorflow/core/framework/function.h"
43 #include "tensorflow/core/framework/graph.pb.h"
44 #include "tensorflow/core/framework/graph_def_util.h"
45 #include "tensorflow/core/framework/log_memory.h"
46 #include "tensorflow/core/framework/logging.h"
47 #include "tensorflow/core/framework/node_def.pb.h"
48 #include "tensorflow/core/framework/run_handler.h"
49 #include "tensorflow/core/framework/tensor.h"
50 #include "tensorflow/core/framework/versions.pb.h"
51 #include "tensorflow/core/graph/algorithm.h"
52 #include "tensorflow/core/graph/graph.h"
53 #include "tensorflow/core/graph/graph_partition.h"
54 #include "tensorflow/core/graph/subgraph.h"
55 #include "tensorflow/core/graph/tensor_id.h"
56 #include "tensorflow/core/lib/core/errors.h"
57 #include "tensorflow/core/lib/core/notification.h"
58 #include "tensorflow/core/lib/core/refcount.h"
59 #include "tensorflow/core/lib/core/status.h"
60 #include "tensorflow/core/lib/core/threadpool.h"
61 #include "tensorflow/core/lib/core/threadpool_options.h"
62 #include "tensorflow/core/lib/gtl/array_slice.h"
63 #include "tensorflow/core/lib/monitoring/counter.h"
64 #include "tensorflow/core/lib/random/random.h"
65 #include "tensorflow/core/lib/strings/numbers.h"
66 #include "tensorflow/core/lib/strings/str_util.h"
67 #include "tensorflow/core/lib/strings/strcat.h"
68 #include "tensorflow/core/nccl/collective_communicator.h"
69 #include "tensorflow/core/platform/byte_order.h"
70 #include "tensorflow/core/platform/cpu_info.h"
71 #include "tensorflow/core/platform/logging.h"
72 #include "tensorflow/core/platform/mutex.h"
73 #include "tensorflow/core/platform/tracing.h"
74 #include "tensorflow/core/platform/types.h"
75 #include "tensorflow/core/profiler/lib/connected_traceme.h"
76 #include "tensorflow/core/profiler/lib/profiler_session.h"
77 #include "tensorflow/core/profiler/lib/traceme_encode.h"
78 #include "tensorflow/core/protobuf/config.pb.h"
79 #include "tensorflow/core/util/device_name_utils.h"
80 #include "tensorflow/core/util/env_var.h"
81 
82 namespace tensorflow {
83 
84 namespace {
85 
86 auto* direct_session_runs = monitoring::Counter<0>::New(
87     "/tensorflow/core/direct_session_runs",
88     "The number of times DirectSession::Run() has been called.");
89 
90 Status NewThreadPoolFromThreadPoolOptions(
91     const SessionOptions& options,
92     const ThreadPoolOptionProto& thread_pool_options, int pool_number,
93     thread::ThreadPool** pool, bool* owned) {
94   int32 num_threads = thread_pool_options.num_threads();
95   if (num_threads == 0) {
96     num_threads = NumInterOpThreadsFromSessionOptions(options);
97   }
98   const string& name = thread_pool_options.global_name();
99   if (name.empty()) {
100     // Session-local threadpool.
101     VLOG(1) << "Direct session inter op parallelism threads for pool "
102             << pool_number << ": " << num_threads;
103     *pool = new thread::ThreadPool(
104         options.env, ThreadOptions(), strings::StrCat("Compute", pool_number),
105         num_threads, !options.config.experimental().disable_thread_spinning(),
106         /*allocator=*/nullptr);
107     *owned = true;
108     return Status::OK();
109   }
110 
111   // Global, named threadpool.
112   typedef std::pair<int32, thread::ThreadPool*> MapValue;
113   static std::map<string, MapValue>* global_pool_map =
114       new std::map<string, MapValue>;
115   static mutex* mu = new mutex();
116   mutex_lock l(*mu);
117   MapValue* mvalue = &(*global_pool_map)[name];
118   if (mvalue->second == nullptr) {
119     mvalue->first = thread_pool_options.num_threads();
120     mvalue->second = new thread::ThreadPool(
121         options.env, ThreadOptions(), strings::StrCat("Compute", pool_number),
122         num_threads, !options.config.experimental().disable_thread_spinning(),
123         /*allocator=*/nullptr);
124   } else {
125     if (mvalue->first != thread_pool_options.num_threads()) {
126       return errors::InvalidArgument(
127           "Pool ", name,
128           " configured previously with num_threads=", mvalue->first,
129           "; cannot re-configure with num_threads=",
130           thread_pool_options.num_threads());
131     }
132   }
133   *owned = false;
134   *pool = mvalue->second;
135   return Status::OK();
136 }
137 
138 thread::ThreadPool* GlobalThreadPool(const SessionOptions& options) {
139   static thread::ThreadPool* const thread_pool =
140       NewThreadPoolFromSessionOptions(options);
141   return thread_pool;
142 }
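// The pool above is a function-local static: it is created on the first call
// and then shared by every DirectSession that does not configure per-session
// or named thread pools.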
143 
144 // TODO(vrv): Figure out how to unify the many different functions
145 // that generate RendezvousKey, since many of them have to be
146 // consistent with each other.
147 string GetRendezvousKey(const string& tensor_name,
148                         const DeviceAttributes& device_info,
149                         const FrameAndIter& frame_iter) {
150   return strings::StrCat(device_info.name(), ";",
151                          strings::FpToString(device_info.incarnation()), ";",
152                          device_info.name(), ";", tensor_name, ";",
153                          frame_iter.frame_id, ":", frame_iter.iter_id);
154 }
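// Illustrative only (hypothetical values): the key produced above has the form
// "<device>;<incarnation>;<device>;<tensor_name>;<frame_id>:<iter_id>", e.g.
// "/job:localhost/replica:0/task:0/device:CPU:0;1a2b3c4d5e6f7a8b;"
// "/job:localhost/replica:0/task:0/device:CPU:0;x:0;0:0".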
155 
156 }  // namespace
157 
158 class DirectSessionFactory : public SessionFactory {
159  public:
160   DirectSessionFactory() {}
161 
162   bool AcceptsOptions(const SessionOptions& options) override {
163     return options.target.empty();
164   }
165 
166   Status NewSession(const SessionOptions& options,
167                     Session** out_session) override {
168     const auto& experimental_config = options.config.experimental();
169     if (experimental_config.has_session_metadata()) {
170       if (experimental_config.session_metadata().version() < 0) {
171         return errors::InvalidArgument(
172             "Session version shouldn't be negative: ",
173             experimental_config.session_metadata().DebugString());
174       }
175       const string key = GetMetadataKey(experimental_config.session_metadata());
176       mutex_lock l(sessions_lock_);
177       if (!session_metadata_keys_.insert(key).second) {
178         return errors::InvalidArgument(
179             "A session with the same name and version has already been "
180             "created: ",
181             experimental_config.session_metadata().DebugString());
182       }
183     }
184 
185     // Must do this before the CPU allocator is created.
186     if (options.config.graph_options().build_cost_model() > 0) {
187       EnableCPUAllocatorFullStats();
188     }
189     std::vector<std::unique_ptr<Device>> devices;
190     TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
191         options, "/job:localhost/replica:0/task:0", &devices));
192 
193     DirectSession* session = new DirectSession(
194         options, new StaticDeviceMgr(std::move(devices)), this);
195     {
196       mutex_lock l(sessions_lock_);
197       sessions_.push_back(session);
198     }
199     *out_session = session;
200     return Status::OK();
201   }
202 
203   Status Reset(const SessionOptions& options,
204                const std::vector<string>& containers) override {
205     std::vector<DirectSession*> sessions_to_reset;
206     {
207       mutex_lock l(sessions_lock_);
208       // We create a copy to ensure that we don't have a deadlock when
209       // session->Close calls the DirectSessionFactory.Deregister, which
210       // acquires sessions_lock_.
211       std::swap(sessions_to_reset, sessions_);
212     }
213     Status s;
214     for (auto session : sessions_to_reset) {
215       s.Update(session->Reset(containers));
216     }
217     // TODO(suharshs): Change the Reset behavior of all SessionFactories so that
218     // it doesn't close the sessions?
219     for (auto session : sessions_to_reset) {
220       s.Update(session->Close());
221     }
222     return s;
223   }
224 
225   void Deregister(const DirectSession* session) {
226     mutex_lock l(sessions_lock_);
227     sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session),
228                     sessions_.end());
229     if (session->options().config.experimental().has_session_metadata()) {
230       session_metadata_keys_.erase(GetMetadataKey(
231           session->options().config.experimental().session_metadata()));
232     }
233   }
234 
235  private:
236   static string GetMetadataKey(const SessionMetadata& metadata) {
237     return absl::StrCat(metadata.name(), "/", metadata.version());
238   }
239 
240   mutex sessions_lock_;
241   std::vector<DirectSession*> sessions_ TF_GUARDED_BY(sessions_lock_);
242   absl::flat_hash_set<string> session_metadata_keys_
243       TF_GUARDED_BY(sessions_lock_);
244 };
245 
246 class DirectSessionRegistrar {
247  public:
248   DirectSessionRegistrar() {
249     SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
250   }
251 };
252 static DirectSessionRegistrar registrar;
253 
254 std::atomic_int_fast64_t DirectSession::step_id_counter_(1);
255 
256 static RunHandlerPool* GetOrCreateRunHandlerPool(
257     const SessionOptions& options) {
258   int num_inter_threads = 0;
259   int num_intra_threads = 0;
260   static const int env_num_inter_threads = NumInterOpThreadsFromEnvironment();
261   static const int env_num_intra_threads = NumIntraOpThreadsFromEnvironment();
262   if (env_num_inter_threads > 0) {
263     num_inter_threads = env_num_inter_threads;
264   }
265   if (env_num_intra_threads > 0) {
266     num_intra_threads = env_num_intra_threads;
267   }
268 
269   if (num_inter_threads == 0) {
270     if (options.config.session_inter_op_thread_pool_size() > 0) {
271       // Note: due to ShouldUseRunHandlerPool(), we are guaranteed that
272       // run_options.inter_op_thread_pool() == 0
273       num_inter_threads =
274           options.config.session_inter_op_thread_pool(0).num_threads();
275     }
276     if (num_inter_threads == 0) {
277       num_inter_threads = NumInterOpThreadsFromSessionOptions(options);
278     }
279   }
280 
281   if (num_intra_threads == 0) {
282     num_intra_threads = options.config.intra_op_parallelism_threads();
283     if (num_intra_threads == 0) {
284       num_intra_threads = port::MaxParallelism();
285     }
286   }
287 
288   static RunHandlerPool* pool =
289       new RunHandlerPool(num_inter_threads, num_intra_threads);
290   return pool;
291 }
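// Because `pool` above is a function-local static, the RunHandlerPool is
// created once and sized from whichever SessionOptions reach this function
// first; later sessions reuse that same pool.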
292 
293 bool DirectSession::ShouldUseRunHandlerPool(
294     const RunOptions& run_options) const {
295   if (options_.config.use_per_session_threads()) return false;
296   if (options_.config.session_inter_op_thread_pool_size() > 0 &&
297       run_options.inter_op_thread_pool() > 0)
298     return false;
299   // Only use RunHandlerPool when:
300   // a. Single global thread pool is used for inter-op parallelism.
301   // b. When multiple inter_op_thread_pool(s) are created, use it only while
302   // running sessions on the default inter_op_thread_pool=0. Typically,
303   // servo-team uses inter_op_thread_pool > 0 for model loading.
304   // TODO(crk): Revisit whether we'd want to create one (static) RunHandlerPool
305   // per entry in session_inter_op_thread_pool() in the future.
306   return true;
307 }
308 
309 DirectSession::DirectSession(const SessionOptions& options,
310                              const DeviceMgr* device_mgr,
311                              DirectSessionFactory* const factory)
312     : options_(options),
313       device_mgr_(device_mgr),
314       factory_(factory),
315       cancellation_manager_(new CancellationManager()),
316       operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
317   const int thread_pool_size =
318       options_.config.session_inter_op_thread_pool_size();
319   if (thread_pool_size > 0) {
320     for (int i = 0; i < thread_pool_size; ++i) {
321       thread::ThreadPool* pool = nullptr;
322       bool owned = false;
323       init_error_.Update(NewThreadPoolFromThreadPoolOptions(
324           options_, options_.config.session_inter_op_thread_pool(i), i, &pool,
325           &owned));
326       thread_pools_.emplace_back(pool, owned);
327     }
328   } else if (options_.config.use_per_session_threads()) {
329     thread_pools_.emplace_back(NewThreadPoolFromSessionOptions(options_),
330                                true /* owned */);
331   } else {
332     thread_pools_.emplace_back(GlobalThreadPool(options), false /* owned */);
333     // Run locally if environment value of TF_NUM_INTEROP_THREADS is negative
334     // and config.inter_op_parallelism_threads is unspecified or negative.
335     static const int env_num_threads = NumInterOpThreadsFromEnvironment();
336     if (options_.config.inter_op_parallelism_threads() < 0 ||
337         (options_.config.inter_op_parallelism_threads() == 0 &&
338          env_num_threads < 0)) {
339       run_in_caller_thread_ = true;
340     }
341   }
342   // The default value of sync_on_finish will be flipped soon and this
343   // environment variable will be removed as well.
344   const Status status =
345       ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
346   if (!status.ok()) {
347     LOG(ERROR) << status.error_message();
348   }
349   session_handle_ =
350       strings::StrCat("direct", strings::FpToString(random::New64()));
351   int devices_added = 0;
352   if (options.config.log_device_placement()) {
353     const string mapping_str = device_mgr_->DeviceMappingString();
354     string msg;
355     if (mapping_str.empty()) {
356       msg = "Device mapping: no known devices.";
357     } else {
358       msg = strings::StrCat("Device mapping:\n", mapping_str);
359     }
360     if (!logging::LogToListeners(msg)) {
361       LOG(INFO) << msg;
362     }
363   }
364   for (auto d : device_mgr_->ListDevices()) {
365     devices_.push_back(d);
366     device_set_.AddDevice(d);
367     d->op_segment()->AddHold(session_handle_);
368 
369     // The first device added is special: it is the 'client device' (a
370     // CPU device) from which we feed and fetch Tensors.
371     if (devices_added == 0) {
372       device_set_.set_client_device(d);
373     }
374     ++devices_added;
375   }
376 }
377 
378 DirectSession::~DirectSession() {
379   if (!closed_) Close().IgnoreError();
380   for (auto& it : partial_runs_) {
381     it.second.reset(nullptr);
382   }
383   for (auto& it : executors_) {
384     it.second.reset();
385   }
386   callables_.clear();
387   for (auto d : device_mgr_->ListDevices()) {
388     d->op_segment()->RemoveHold(session_handle_);
389   }
390   functions_.clear();
391   delete cancellation_manager_;
392   for (const auto& p_and_owned : thread_pools_) {
393     if (p_and_owned.second) delete p_and_owned.first;
394   }
395 
396   execution_state_.reset(nullptr);
397   flib_def_.reset(nullptr);
398 }
399 
400 Status DirectSession::Create(const GraphDef& graph) {
401   return Create(GraphDef(graph));
402 }
403 
404 Status DirectSession::Create(GraphDef&& graph) {
405   TF_RETURN_IF_ERROR(init_error_);
406   if (graph.node_size() > 0) {
407     mutex_lock l(graph_state_lock_);
408     if (graph_created_) {
409       return errors::AlreadyExists(
410           "A Graph has already been created for this session.");
411     }
412     return ExtendLocked(std::move(graph));
413   }
414   return Status::OK();
415 }
416 
417 Status DirectSession::Extend(const GraphDef& graph) {
418   return Extend(GraphDef(graph));
419 }
420 
421 Status DirectSession::Extend(GraphDef&& graph) {
422   TF_RETURN_IF_ERROR(CheckNotClosed());
423   mutex_lock l(graph_state_lock_);
424   return ExtendLocked(std::move(graph));
425 }
426 
427 Status DirectSession::ExtendLocked(GraphDef graph) {
428   if (finalized_) {
429     return errors::FailedPrecondition("Session has been finalized.");
430   }
431   if (!(flib_def_ && execution_state_)) {
432     // If this is the first call, we can initialize the execution state
433     // with `graph` and do not need to call `Extend()`.
434     // NOTE(mrry): The function library created here will be used for
435     // all subsequent extensions of the graph.
436     flib_def_.reset(
437         new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
438     GraphExecutionStateOptions options;
439     options.device_set = &device_set_;
440     options.session_options = &options_;
441     options.session_handle = session_handle_;
442     TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
443         std::move(graph), options, &execution_state_));
444     graph_created_ = true;
445   } else {
446     TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));
447     std::unique_ptr<GraphExecutionState> state;
448     // TODO(mrry): Rewrite GraphExecutionState::Extend() to take `graph` by
449     // value and move `graph` in here.
450     TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
451     execution_state_.swap(state);
452   }
453   return Status::OK();
454 }
455 
456 Status DirectSession::Run(const NamedTensorList& inputs,
457                           const std::vector<string>& output_names,
458                           const std::vector<string>& target_nodes,
459                           std::vector<Tensor>* outputs) {
460   RunMetadata run_metadata;
461   return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
462              &run_metadata);
463 }
464 
465 Status DirectSession::CreateDebuggerState(
466     const CallableOptions& callable_options, int64 global_step,
467     int64 session_run_index, int64 executor_step_index,
468     std::unique_ptr<DebuggerStateInterface>* debugger_state) {
469   TF_RETURN_IF_ERROR(DebuggerStateRegistry::CreateState(
470       callable_options.run_options().debug_options(), debugger_state));
471   std::vector<string> input_names(callable_options.feed().begin(),
472                                   callable_options.feed().end());
473   std::vector<string> output_names(callable_options.fetch().begin(),
474                                    callable_options.fetch().end());
475   std::vector<string> target_names(callable_options.target().begin(),
476                                    callable_options.target().end());
477 
478   TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
479       global_step, session_run_index, executor_step_index, input_names,
480       output_names, target_names));
481   return Status::OK();
482 }
483 
484 Status DirectSession::DecorateAndPublishGraphForDebug(
485     const DebugOptions& debug_options, Graph* graph, Device* device) {
486   std::unique_ptr<DebugGraphDecoratorInterface> decorator;
487   TF_RETURN_IF_ERROR(
488       DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
489 
490   TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
491   TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
492   return Status::OK();
493 }
494 
495 Status DirectSession::RunInternal(
496     int64 step_id, const RunOptions& run_options,
497     CallFrameInterface* call_frame, ExecutorsAndKeys* executors_and_keys,
498     RunMetadata* run_metadata,
499     const thread::ThreadPoolOptions& threadpool_options) {
500   const uint64 start_time_usecs = options_.env->NowMicros();
501   const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1);
502   RunState run_state(step_id, &devices_);
503   const size_t num_executors = executors_and_keys->items.size();
504 
505   profiler::TraceMeProducer activity(
506       // To TraceMeConsumers in ExecutorState::Process/Finish.
507       [&] {
508         if (options_.config.experimental().has_session_metadata()) {
509           const auto& model_metadata =
510               options_.config.experimental().session_metadata();
511           string model_id = strings::StrCat(model_metadata.name(), ":",
512                                             model_metadata.version());
513           return profiler::TraceMeEncode("SessionRun",
514                                          {{"id", step_id},
515                                           {"_r", 1} /*root_event*/,
516                                           {"model_id", model_id}});
517         } else {
518           return profiler::TraceMeEncode(
519               "SessionRun", {{"id", step_id}, {"_r", 1} /*root_event*/});
520         }
521       },
522       profiler::ContextType::kTfExecutor, step_id,
523       profiler::TraceMeLevel::kInfo);
524 
525   std::unique_ptr<DebuggerStateInterface> debugger_state;
526   if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
527     TF_RETURN_IF_ERROR(
528         CreateDebuggerState(executors_and_keys->callable_options,
529                             run_options.debug_options().global_step(), step_id,
530                             executor_step_count, &debugger_state));
531   }
532 
533 #ifndef __ANDROID__
534   // Set up for collectives if ExecutorsAndKeys declares a key.
535   if (executors_and_keys->collective_graph_key !=
536       BuildGraphOptions::kNoCollectiveGraphKey) {
537     if (run_options.experimental().collective_graph_key() !=
538         BuildGraphOptions::kNoCollectiveGraphKey) {
539       // If a collective_graph_key was specified in run_options, ensure that it
540       // matches what came out of GraphExecutionState::BuildGraph().
541       if (run_options.experimental().collective_graph_key() !=
542           executors_and_keys->collective_graph_key) {
543         return errors::Internal(
544             "collective_graph_key in RunOptions ",
545             run_options.experimental().collective_graph_key(),
546             " should match collective_graph_key from optimized graph ",
547             executors_and_keys->collective_graph_key);
548       }
549     }
550     if (!collective_executor_mgr_) {
551       std::unique_ptr<DeviceResolverInterface> drl(
552           new DeviceResolverLocal(device_mgr_.get()));
553       std::unique_ptr<ParamResolverInterface> cprl(
554           new CollectiveParamResolverLocal(options_.config, device_mgr_.get(),
555                                            drl.get(),
556                                            "/job:localhost/replica:0/task:0"));
557       collective_executor_mgr_.reset(new CollectiveExecutorMgr(
558           options_.config, device_mgr_.get(), std::move(drl), std::move(cprl),
559           MaybeCreateNcclCommunicator()));
560     }
561     run_state.collective_executor.reset(new CollectiveExecutor::Handle(
562         collective_executor_mgr_->FindOrCreate(step_id), true /*inherit_ref*/));
563   }
564 #endif
565 
566   thread::ThreadPool* pool;
567   // Use std::unique_ptr so the wrapper is automatically destroyed.
568   std::unique_ptr<thread::ThreadPool> threadpool_wrapper;
569 
570   const bool inline_execution_requested =
571       run_in_caller_thread_ || run_options.inter_op_thread_pool() == -1;
572 
573   if (inline_execution_requested) {
574     // We allow using the caller thread only when a single executor is
575     // specified.
576     if (executors_and_keys->items.size() > 1) {
577       pool = thread_pools_[0].first;
578     } else {
579       VLOG(1) << "Executing Session::Run() synchronously!";
580       pool = nullptr;
581     }
582   } else if (threadpool_options.inter_op_threadpool != nullptr) {
583     threadpool_wrapper = absl::make_unique<thread::ThreadPool>(
584         threadpool_options.inter_op_threadpool);
585     pool = threadpool_wrapper.get();
586   } else {
587     if (run_options.inter_op_thread_pool() < -1 ||
588         run_options.inter_op_thread_pool() >=
589             static_cast<int32>(thread_pools_.size())) {
590       return errors::InvalidArgument("Invalid inter_op_thread_pool: ",
591                                      run_options.inter_op_thread_pool());
592     }
593 
594     pool = thread_pools_[run_options.inter_op_thread_pool()].first;
595   }
596 
597   const int64 call_timeout = run_options.timeout_in_ms() > 0
598                                  ? run_options.timeout_in_ms()
599                                  : operation_timeout_in_ms_;
600 
601   std::unique_ptr<RunHandler> handler;
602   if (ShouldUseRunHandlerPool(run_options) &&
603       run_options.experimental().use_run_handler_pool()) {
604     VLOG(1) << "Using RunHandler to schedule inter-op closures.";
605     handler = GetOrCreateRunHandlerPool(options_)->Get(
606         step_id, call_timeout,
607         run_options.experimental().run_handler_pool_options());
608     if (!handler) {
609       return errors::DeadlineExceeded(
610           "Could not obtain RunHandler for request after waiting for ",
611           call_timeout, "ms.");
612     }
613   }
614   auto* handler_ptr = handler.get();
615 
616   Executor::Args::Runner default_runner = nullptr;
617 
618   if (pool == nullptr) {
619     default_runner = [](const Executor::Args::Closure& c) { c(); };
620   } else if (handler_ptr != nullptr) {
621     default_runner = [handler_ptr](Executor::Args::Closure c) {
622       handler_ptr->ScheduleInterOpClosure(std::move(c));
623     };
624   } else {
625     default_runner = [pool](Executor::Args::Closure c) {
626       pool->Schedule(std::move(c));
627     };
628   }
629 
630   // Start parallel Executors.
631 
632   // We can execute this step synchronously on the calling thread whenever
633   // there is a single device and the timeout mechanism is not used.
634   //
635   // When timeouts are used, we must execute the graph(s) asynchronously, in
636   // order to invoke the cancellation manager on the calling thread if the
637   // timeout expires.
638   const bool can_execute_synchronously =
639       executors_and_keys->items.size() == 1 && call_timeout == 0;
640 
641   Executor::Args args;
642   args.step_id = step_id;
643   args.call_frame = call_frame;
644   args.collective_executor =
645       (run_state.collective_executor ? run_state.collective_executor->get()
646                                      : nullptr);
647   args.session_state = &session_state_;
648   args.session_handle = session_handle_;
649   args.tensor_store = &run_state.tensor_store;
650   args.step_container = &run_state.step_container;
651   args.sync_on_finish = sync_on_finish_;
652   args.user_intra_op_threadpool = threadpool_options.intra_op_threadpool;
653   args.run_all_kernels_inline = pool == nullptr;
654 
655   const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);
656 
657   bool update_cost_model = false;
658   if (options_.config.graph_options().build_cost_model() > 0) {
659     const int64 build_cost_model_every =
660         options_.config.graph_options().build_cost_model();
661     const int64 build_cost_model_after =
662         options_.config.graph_options().build_cost_model_after();
663     int64 measure_step_count = executor_step_count - build_cost_model_after;
664     if (measure_step_count >= 0) {
665       update_cost_model =
666           ((measure_step_count + 1) % build_cost_model_every == 0);
667     }
668   }
669   if (do_trace || update_cost_model ||
670       run_options.report_tensor_allocations_upon_oom()) {
671     run_state.collector.reset(
672         new StepStatsCollector(run_metadata->mutable_step_stats()));
673     args.stats_collector = run_state.collector.get();
674   }
675 
676   std::unique_ptr<ProfilerSession> profiler_session;
677   if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
678     ProfileOptions options = ProfilerSession::DefaultOptions();
679     options.set_host_tracer_level(0);
680     profiler_session = ProfilerSession::Create(options);
681   }
682 
683   // Register this step with session's cancellation manager, so that
684   // `Session::Close()` will cancel the step.
685   CancellationManager step_cancellation_manager(cancellation_manager_);
686   if (step_cancellation_manager.IsCancelled()) {
687     return errors::Cancelled("Run call was cancelled");
688   }
689   args.cancellation_manager = &step_cancellation_manager;
690 
691   Status run_status;
692 
693   auto set_threadpool_args_for_item =
694       [&default_runner, &handler](const PerPartitionExecutorsAndLib& item,
695                                   Executor::Args* args) {
696         // TODO(azaks): support partial run.
697         // TODO(azaks): if the device picks its own threadpool, we need to
698         // assign
699         //     less threads to the main compute pool by default.
700         thread::ThreadPool* device_thread_pool =
701             item.device->tensorflow_device_thread_pool();
702         // TODO(crk): Investigate usage of RunHandlerPool when using device
703         // specific thread pool(s).
704         if (!device_thread_pool) {
705           args->runner = default_runner;
706         } else {
707           args->runner = [device_thread_pool](Executor::Args::Closure c) {
708             device_thread_pool->Schedule(std::move(c));
709           };
710         }
711         if (handler != nullptr) {
712           args->user_intra_op_threadpool =
713               handler->AsIntraThreadPoolInterface();
714         }
715       };
716 
717   if (can_execute_synchronously) {
718     PrivateIntraProcessRendezvous rendezvous(device_mgr_.get());
719     args.rendezvous = &rendezvous;
720 
721     const auto& item = executors_and_keys->items[0];
722     set_threadpool_args_for_item(item, &args);
723     run_status = item.executor->Run(args);
724   } else {
725     core::RefCountPtr<RefCountedIntraProcessRendezvous> rendezvous(
726         new RefCountedIntraProcessRendezvous(device_mgr_.get()));
727     args.rendezvous = rendezvous.get();
728 
729     // `barrier` will delete itself after the final executor finishes.
730     Notification executors_done;
731     ExecutorBarrier* barrier =
732         new ExecutorBarrier(num_executors, rendezvous.get(),
733                             [&run_state, &executors_done](const Status& ret) {
734                               {
735                                 mutex_lock l(run_state.mu);
736                                 run_state.status.Update(ret);
737                               }
738                               executors_done.Notify();
739                             });
740 
741     for (const auto& item : executors_and_keys->items) {
742       set_threadpool_args_for_item(item, &args);
743       item.executor->RunAsync(args, barrier->Get());
744     }
745 
746     WaitForNotification(&executors_done, &run_state, &step_cancellation_manager,
747                         call_timeout);
748     {
749       tf_shared_lock l(run_state.mu);
750       run_status = run_state.status;
751     }
752   }
753 
754   if (step_cancellation_manager.IsCancelled()) {
755     run_status.Update(errors::Cancelled("Run call was cancelled"));
756   }
757 
758   if (profiler_session) {
759     TF_RETURN_IF_ERROR(profiler_session->CollectData(run_metadata));
760   }
761 
762   TF_RETURN_IF_ERROR(run_status);
763 
764   // Save the output tensors of this run we choose to keep.
765   if (!run_state.tensor_store.empty()) {
766     TF_RETURN_IF_ERROR(run_state.tensor_store.SaveTensors(
767         {executors_and_keys->callable_options.fetch().begin(),
768          executors_and_keys->callable_options.fetch().end()},
769         &session_state_));
770   }
771 
772   if (run_state.collector) {
773     run_state.collector->Finalize();
774   }
775 
776   // Build and return the cost model as instructed.
777   if (update_cost_model) {
778     // Build the cost model
779     std::unordered_map<string, const Graph*> device_to_graph;
780     for (const PerPartitionExecutorsAndLib& partition :
781          executors_and_keys->items) {
782       const Graph* graph = partition.graph.get();
783       const string& device = partition.flib->device()->name();
784       device_to_graph[device] = graph;
785     }
786 
787     mutex_lock l(executor_lock_);
788     run_state.collector->BuildCostModel(&cost_model_manager_, device_to_graph);
789 
790     // annotate stats onto cost graph.
791     CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
792     for (const auto& item : executors_and_keys->items) {
793       TF_RETURN_IF_ERROR(
794           cost_model_manager_.AddToCostGraphDef(item.graph.get(), cost_graph));
795     }
796   }
797 
798   // If requested via RunOptions, output the partition graphs.
799   if (run_options.output_partition_graphs()) {
800     if (options_.config.experimental().disable_output_partition_graphs()) {
801       return errors::InvalidArgument(
802           "RunOptions.output_partition_graphs() is not supported when "
803           "disable_output_partition_graphs is true.");
804     } else {
805       protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
806           run_metadata->mutable_partition_graphs();
807       for (const PerPartitionExecutorsAndLib& exec_and_lib :
808            executors_and_keys->items) {
809         GraphDef* partition_graph_def = partition_graph_defs->Add();
810         exec_and_lib.graph->ToGraphDef(partition_graph_def);
811       }
812     }
813   }
814   metrics::UpdateGraphExecTime(options_.env->NowMicros() - start_time_usecs);
815 
816   return Status::OK();
817 }
818 
819 Status DirectSession::Run(const RunOptions& run_options,
820                           const NamedTensorList& inputs,
821                           const std::vector<string>& output_names,
822                           const std::vector<string>& target_nodes,
823                           std::vector<Tensor>* outputs,
824                           RunMetadata* run_metadata) {
825   return Run(run_options, inputs, output_names, target_nodes, outputs,
826              run_metadata, thread::ThreadPoolOptions());
827 }
828 
829 Status DirectSession::Run(const RunOptions& run_options,
830                           const NamedTensorList& inputs,
831                           const std::vector<string>& output_names,
832                           const std::vector<string>& target_nodes,
833                           std::vector<Tensor>* outputs,
834                           RunMetadata* run_metadata,
835                           const thread::ThreadPoolOptions& threadpool_options) {
836   TF_RETURN_IF_ERROR(CheckNotClosed());
837   TF_RETURN_IF_ERROR(CheckGraphCreated("Run()"));
838   direct_session_runs->GetCell()->IncrementBy(1);
839 
840   // Extract the inputs names for this run of the session.
841   std::vector<string> input_tensor_names;
842   input_tensor_names.reserve(inputs.size());
843   size_t input_size = 0;
844   for (const auto& it : inputs) {
845     input_tensor_names.push_back(it.first);
846     input_size += it.second.AllocatedBytes();
847   }
848   metrics::RecordGraphInputTensors(input_size);
849 
850   // Check if we already have an executor for these arguments.
851   ExecutorsAndKeys* executors_and_keys;
852   RunStateArgs run_state_args(run_options.debug_options());
853   run_state_args.collective_graph_key =
854       run_options.experimental().collective_graph_key();
855 
856   TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
857                                           target_nodes, &executors_and_keys,
858                                           &run_state_args));
859   {
860     mutex_lock l(collective_graph_key_lock_);
861     collective_graph_key_ = executors_and_keys->collective_graph_key;
862   }
863 
864   // Configure a call frame for the step, which we use to feed and
865   // fetch values to and from the executors.
866   FunctionCallFrame call_frame(executors_and_keys->input_types,
867                                executors_and_keys->output_types);
868   gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
869   for (const auto& it : inputs) {
870     if (it.second.dtype() == DT_RESOURCE) {
871       Tensor tensor_from_handle;
872       TF_RETURN_IF_ERROR(
873           ResourceHandleToInputTensor(it.second, &tensor_from_handle));
874       feed_args[executors_and_keys->input_name_to_index[it.first]] =
875           tensor_from_handle;
876     } else {
877       feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
878     }
879   }
880   const Status s = call_frame.SetArgs(feed_args);
881   if (errors::IsInternal(s)) {
882     return errors::InvalidArgument(s.error_message());
883   } else if (!s.ok()) {
884     return s;
885   }
886 
887   const int64 step_id = step_id_counter_.fetch_add(1);
888 
889   if (LogMemory::IsEnabled()) {
890     LogMemory::RecordStep(step_id, run_state_args.handle);
891   }
892 
893   TF_RETURN_IF_ERROR(RunInternal(step_id, run_options, &call_frame,
894                                  executors_and_keys, run_metadata,
895                                  threadpool_options));
896 
897   // Receive outputs.
898   if (outputs) {
899     std::vector<Tensor> sorted_outputs;
900     const Status s = call_frame.ConsumeRetvals(
901         &sorted_outputs, /* allow_dead_tensors = */ false);
902     if (errors::IsInternal(s)) {
903       return errors::InvalidArgument(s.error_message());
904     } else if (!s.ok()) {
905       return s;
906     }
907     const bool unique_outputs =
908         output_names.size() == executors_and_keys->output_name_to_index.size();
909     // first_indices[i] = j implies that j is the smallest value for which
910     // output_names[i] == output_names[j].
911     std::vector<int> first_indices;
912     if (!unique_outputs) {
913       first_indices.reserve(output_names.size());
914       for (const auto& name : output_names) {
915         first_indices.push_back(
916             std::find(output_names.begin(), output_names.end(), name) -
917             output_names.begin());
918       }
919     }
920     outputs->clear();
921     size_t output_size = 0;
922     outputs->reserve(sorted_outputs.size());
923     for (int i = 0; i < output_names.size(); ++i) {
924       const string& output_name = output_names[i];
925       if (first_indices.empty() || first_indices[i] == i) {
926         outputs->emplace_back(
927             std::move(sorted_outputs[executors_and_keys
928                                          ->output_name_to_index[output_name]]));
929       } else {
930         outputs->push_back((*outputs)[first_indices[i]]);
931       }
932       output_size += outputs->back().AllocatedBytes();
933     }
934     metrics::RecordGraphOutputTensors(output_size);
935   }
936 
937   return Status::OK();
938 }
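// A minimal client-side sketch (hypothetical names; assumes a Session* `sess`
// obtained from NewSession(SessionOptions()) with a graph already added via
// Create()):
//   std::vector<Tensor> outputs;
//   Status s = sess->Run({{"x", x_tensor}}, {"y:0"}, {}, &outputs);
// Feeds are matched by tensor name, and fetched tensors are returned in
// `outputs` in the same order as the fetch names.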
939 
940 Status DirectSession::PRunSetup(const std::vector<string>& input_names,
941                                 const std::vector<string>& output_names,
942                                 const std::vector<string>& target_nodes,
943                                 string* handle) {
944   TF_RETURN_IF_ERROR(CheckNotClosed());
945   TF_RETURN_IF_ERROR(CheckGraphCreated("PRunSetup()"));
946 
947   // RunOptions is not available in PRunSetup, so use thread pool 0.
948   thread::ThreadPool* pool = thread_pools_[0].first;
949 
950   // Check if we already have an executor for these arguments.
951   ExecutorsAndKeys* executors_and_keys;
952   // TODO(cais): TFDBG support for partial runs.
953   DebugOptions debug_options;
954   RunStateArgs run_state_args(debug_options);
955   run_state_args.is_partial_run = true;
956   TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_names, output_names,
957                                           target_nodes, &executors_and_keys,
958                                           &run_state_args));
959 
960   // Create the run state and save it for future PRun calls.
961   Executor::Args args;
962   args.step_id = step_id_counter_.fetch_add(1);
963   PartialRunState* run_state =
964       new PartialRunState(input_names, output_names, args.step_id, &devices_);
965   run_state->rendez.reset(new IntraProcessRendezvous(device_mgr_.get()));
966   {
967     mutex_lock l(executor_lock_);
968     if (!partial_runs_
969              .emplace(run_state_args.handle,
970                       std::unique_ptr<PartialRunState>(run_state))
971              .second) {
972       return errors::Internal("The handle '", run_state_args.handle,
973                               "' created for this partial run is not unique.");
974     }
975   }
976 
977   // Start parallel Executors.
978   const size_t num_executors = executors_and_keys->items.size();
979   ExecutorBarrier* barrier = new ExecutorBarrier(
980       num_executors, run_state->rendez.get(), [run_state](const Status& ret) {
981         if (!ret.ok()) {
982           mutex_lock l(run_state->mu);
983           run_state->status.Update(ret);
984         }
985         run_state->executors_done.Notify();
986       });
987 
988   args.rendezvous = run_state->rendez.get();
989   args.cancellation_manager = cancellation_manager_;
990   // Note that Collectives are not supported in partial runs
991   // because RunOptions is not passed in so we can't know whether
992   // their use is intended.
993   args.collective_executor = nullptr;
994   args.runner = [this, pool](Executor::Args::Closure c) {
995     pool->Schedule(std::move(c));
996   };
997   args.session_state = &session_state_;
998   args.session_handle = session_handle_;
999   args.tensor_store = &run_state->tensor_store;
1000   args.step_container = &run_state->step_container;
1001   if (LogMemory::IsEnabled()) {
1002     LogMemory::RecordStep(args.step_id, run_state_args.handle);
1003   }
1004   args.sync_on_finish = sync_on_finish_;
1005 
1006   if (options_.config.graph_options().build_cost_model()) {
1007     run_state->collector.reset(new StepStatsCollector(nullptr));
1008     args.stats_collector = run_state->collector.get();
1009   }
1010 
1011   for (auto& item : executors_and_keys->items) {
1012     item.executor->RunAsync(args, barrier->Get());
1013   }
1014 
1015   *handle = run_state_args.handle;
1016   return Status::OK();
1017 }
1018 
1019 Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
1020                            const std::vector<string>& output_names,
1021                            std::vector<Tensor>* outputs) {
1022   TF_RETURN_IF_ERROR(CheckNotClosed());
1023   std::vector<string> parts = str_util::Split(handle, ';');
1024   const string& key = parts[0];
1025   // Get the executors for this partial run.
1026   ExecutorsAndKeys* executors_and_keys;
1027   PartialRunState* run_state;
1028   {
1029     mutex_lock l(executor_lock_);  // could use reader lock
1030     auto exc_it = executors_.find(key);
1031     if (exc_it == executors_.end()) {
1032       return errors::InvalidArgument(
1033           "Must run 'setup' before performing partial runs!");
1034     }
1035     executors_and_keys = exc_it->second.get();
1036 
1037     auto prun_it = partial_runs_.find(handle);
1038     if (prun_it == partial_runs_.end()) {
1039       return errors::InvalidArgument(
1040           "Must run 'setup' before performing partial runs!");
1041     }
1042     run_state = prun_it->second.get();
1043 
1044     // Make sure that this is a new set of feeds that are still pending.
1045     for (const auto& input : inputs) {
1046       auto it = run_state->pending_inputs.find(input.first);
1047       if (it == run_state->pending_inputs.end()) {
1048         return errors::InvalidArgument(
1049             "The feed ", input.first,
1050             " was not specified in partial_run_setup.");
1051       } else if (it->second) {
1052         return errors::InvalidArgument("The feed ", input.first,
1053                                        " has already been fed.");
1054       }
1055     }
1056     // Check that this is a new set of fetches that are still pending.
1057     for (const auto& output : output_names) {
1058       auto it = run_state->pending_outputs.find(output);
1059       if (it == run_state->pending_outputs.end()) {
1060         return errors::InvalidArgument(
1061             "The fetch ", output, " was not specified in partial_run_setup.");
1062       } else if (it->second) {
1063         return errors::InvalidArgument("The fetch ", output,
1064                                        " has already been fetched.");
1065       }
1066     }
1067   }
1068 
1069   // Check that this new set of fetches can be computed from all the
1070   // feeds we have supplied.
1071   TF_RETURN_IF_ERROR(
1072       CheckFetch(inputs, output_names, executors_and_keys, run_state));
1073 
1074   // Send inputs.
1075   Status s =
1076       SendPRunInputs(inputs, executors_and_keys, run_state->rendez.get());
1077 
1078   // Receive outputs.
1079   if (s.ok()) {
1080     s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs);
1081   }
1082 
1083   // Save the output tensors of this run we choose to keep.
1084   if (s.ok()) {
1085     s = run_state->tensor_store.SaveTensors(output_names, &session_state_);
1086   }
1087 
1088   {
1089     mutex_lock l(executor_lock_);
1090     // Delete the run state if there is an error or all fetches are done.
1091     bool done = true;
1092     if (s.ok()) {
1093       {
1094         mutex_lock l(run_state->mu);
1095         if (!run_state->status.ok()) {
1096           LOG(WARNING) << "An error unrelated to this prun has been detected. "
1097                        << run_state->status;
1098         }
1099       }
1100       for (const auto& input : inputs) {
1101         auto it = run_state->pending_inputs.find(input.first);
1102         it->second = true;
1103       }
1104       for (const auto& name : output_names) {
1105         auto it = run_state->pending_outputs.find(name);
1106         it->second = true;
1107       }
1108       done = run_state->PendingDone();
1109     }
1110     if (done) {
1111       WaitForNotification(&run_state->executors_done, run_state,
1112                           cancellation_manager_, operation_timeout_in_ms_);
1113       partial_runs_.erase(handle);
1114     }
1115   }
1116 
1117   return s;
1118 }
1119 
1120 Status DirectSession::ResourceHandleToInputTensor(const Tensor& resource_tensor,
1121                                                   Tensor* retrieved_tensor) {
1122   if (resource_tensor.dtype() != DT_RESOURCE) {
1123     return errors::InvalidArgument(strings::StrCat(
1124         "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ",
1125         resource_tensor.dtype()));
1126   }
1127 
1128   const ResourceHandle& resource_handle =
1129       resource_tensor.scalar<ResourceHandle>()();
1130 
1131   if (resource_handle.container() ==
1132       SessionState::kTensorHandleResourceTypeName) {
1133     return session_state_.GetTensor(resource_handle.name(), retrieved_tensor);
1134   } else {
1135     return errors::InvalidArgument(strings::StrCat(
1136         "Invalid resource type hash code: ", resource_handle.hash_code(),
1137         "(name: ", resource_handle.name(),
1138         " type: ", resource_handle.maybe_type_name(),
1139         "). Perhaps a resource tensor was being provided as a feed? That is "
1140         "not currently allowed. Please file an issue at "
1141         "https://github.com/tensorflow/tensorflow/issues/new, ideally with a "
1142         "short code snippet that leads to this error message."));
1143   }
1144 }
1145 
1146 Status DirectSession::SendPRunInputs(const NamedTensorList& inputs,
1147                                      const ExecutorsAndKeys* executors_and_keys,
1148                                      IntraProcessRendezvous* rendez) {
1149   Status s;
1150   Rendezvous::ParsedKey parsed;
1151   // Insert the input tensors into the local rendezvous by their
1152   // rendezvous key.
1153   for (const auto& input : inputs) {
1154     auto it =
1155         executors_and_keys->input_name_to_rendezvous_key.find(input.first);
1156     if (it == executors_and_keys->input_name_to_rendezvous_key.end()) {
1157       return errors::Internal("'", input.first, "' is not a pre-defined feed.");
1158     }
1159     const string& input_key = it->second;
1160 
1161     s = Rendezvous::ParseKey(input_key, &parsed);
1162     if (!s.ok()) {
1163       rendez->StartAbort(s);
1164       return s;
1165     }
1166 
1167     if (input.second.dtype() == DT_RESOURCE) {
1168       Tensor tensor_from_handle;
1169       s = ResourceHandleToInputTensor(input.second, &tensor_from_handle);
1170       if (s.ok()) {
1171         s = rendez->Send(parsed, Rendezvous::Args(), tensor_from_handle, false);
1172       }
1173     } else {
1174       s = rendez->Send(parsed, Rendezvous::Args(), input.second, false);
1175     }
1176 
1177     if (!s.ok()) {
1178       rendez->StartAbort(s);
1179       return s;
1180     }
1181   }
1182   return Status::OK();
1183 }
1184 
1185 Status DirectSession::RecvPRunOutputs(
1186     const std::vector<string>& output_names,
1187     const ExecutorsAndKeys* executors_and_keys, PartialRunState* run_state,
1188     std::vector<Tensor>* outputs) {
1189   Status s;
1190   if (!output_names.empty()) {
1191     outputs->resize(output_names.size());
1192   }
1193 
1194   Rendezvous::ParsedKey parsed;
1195   // Get the outputs from the rendezvous
1196   for (size_t output_offset = 0; output_offset < output_names.size();
1197        ++output_offset) {
1198     const string& output_name = output_names[output_offset];
1199     auto it =
1200         executors_and_keys->output_name_to_rendezvous_key.find(output_name);
1201     if (it == executors_and_keys->output_name_to_rendezvous_key.end()) {
1202       return errors::Internal("'", output_name,
1203                               "' is not a pre-defined fetch.");
1204     }
1205     const string& output_key = it->second;
1206     Tensor output_tensor;
1207     bool is_dead;
1208 
1209     s = Rendezvous::ParseKey(output_key, &parsed);
1210     if (s.ok()) {
1211       // Fetch data from the Rendezvous.
1212       s = run_state->rendez->Recv(parsed, Rendezvous::Args(), &output_tensor,
1213                                   &is_dead, operation_timeout_in_ms_);
1214       if (is_dead && s.ok()) {
1215         s = errors::InvalidArgument("The tensor returned for ", output_name,
1216                                     " was not valid.");
1217       }
1218     }
1219     if (!s.ok()) {
1220       run_state->rendez->StartAbort(s);
1221       outputs->clear();
1222       return s;
1223     }
1224 
1225     (*outputs)[output_offset] = output_tensor;
1226   }
1227   return Status::OK();
1228 }
1229 
1230 Status DirectSession::CheckFetch(const NamedTensorList& feeds,
1231                                  const std::vector<string>& fetches,
1232                                  const ExecutorsAndKeys* executors_and_keys,
1233                                  const PartialRunState* run_state) {
1234   const Graph* graph = executors_and_keys->graph.get();
1235   const NameNodeMap* name_to_node = &executors_and_keys->name_to_node;
1236 
1237   // Build the set of pending feeds that we haven't seen.
1238   std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
1239   {
1240     mutex_lock l(executor_lock_);
1241     for (const auto& input : run_state->pending_inputs) {
1242       // Skip if the feed has already been fed.
1243       if (input.second) continue;
1244       TensorId id(ParseTensorName(input.first));
1245       auto it = name_to_node->find(id.first);
1246       if (it == name_to_node->end()) {
1247         return errors::NotFound("Feed ", input.first, ": not found");
1248       }
1249       pending_feeds.insert(id);
1250     }
1251   }
1252   for (const auto& it : feeds) {
1253     TensorId id(ParseTensorName(it.first));
1254     pending_feeds.erase(id);
1255   }
1256 
1257   // Initialize the stack with the fetch nodes.
1258   std::vector<const Node*> stack;
1259   for (const string& fetch : fetches) {
1260     TensorId id(ParseTensorName(fetch));
1261     auto it = name_to_node->find(id.first);
1262     if (it == name_to_node->end()) {
1263       return errors::NotFound("Fetch ", fetch, ": not found");
1264     }
1265     stack.push_back(it->second);
1266   }
1267 
1268   // Any tensor needed for fetches can't be in pending_feeds.
1269   std::vector<bool> visited(graph->num_node_ids(), false);
1270   while (!stack.empty()) {
1271     const Node* n = stack.back();
1272     stack.pop_back();
1273 
1274     for (const Edge* in_edge : n->in_edges()) {
1275       const Node* in_node = in_edge->src();
1276       if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
1277         return errors::InvalidArgument("Fetch ", in_node->name(), ":",
1278                                        in_edge->src_output(),
1279                                        " can't be computed from the feeds"
1280                                        " that have been fed so far.");
1281       }
1282       if (!visited[in_node->id()]) {
1283         visited[in_node->id()] = true;
1284         stack.push_back(in_node);
1285       }
1286     }
1287   }
1288   return Status::OK();
1289 }
1290 
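// Example of the invariant enforced by CheckFetch() (illustrative only): if
// the graph computes c = MatMul(a, b) and only "a" has been fed so far in
// this partial run, then fetching "c:0" fails with InvalidArgument, because
// "c" still depends on the pending feed "b" and cannot be computed from the
// feeds fed so far.
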
1291 Status DirectSession::CreateExecutors(
1292     const CallableOptions& callable_options,
1293     std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys,
1294     std::unique_ptr<FunctionInfo>* out_func_info,
1295     RunStateArgs* run_state_args) {
1296   BuildGraphOptions options;
1297   options.callable_options = callable_options;
1298   options.use_function_convention = !run_state_args->is_partial_run;
1299   options.collective_graph_key =
1300       callable_options.run_options().experimental().collective_graph_key();
1301   if (options_.config.experimental()
1302           .collective_deterministic_sequential_execution()) {
1303     options.collective_order = GraphCollectiveOrder::kEdges;
1304   } else if (options_.config.experimental().collective_nccl()) {
1305     options.collective_order = GraphCollectiveOrder::kAttrs;
1306   }
1307 
1308   std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
1309   std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
1310 
1311   ek->callable_options = callable_options;
1312 
1313   std::unordered_map<string, std::unique_ptr<Graph>> graphs;
1314   TF_RETURN_IF_ERROR(CreateGraphs(
1315       options, &graphs, &func_info->flib_def, run_state_args, &ek->input_types,
1316       &ek->output_types, &ek->collective_graph_key));
1317 
1318   if (run_state_args->is_partial_run) {
1319     ek->graph = std::move(run_state_args->graph);
1320     std::unordered_set<StringPiece, StringPieceHasher> names;
1321     for (const string& input : callable_options.feed()) {
1322       TensorId id(ParseTensorName(input));
1323       names.emplace(id.first);
1324     }
1325     for (const string& output : callable_options.fetch()) {
1326       TensorId id(ParseTensorName(output));
1327       names.emplace(id.first);
1328     }
1329     for (Node* n : ek->graph->nodes()) {
1330       if (names.count(n->name()) > 0) {
1331         ek->name_to_node.insert({n->name(), n});
1332       }
1333     }
1334   }
1335   ek->items.reserve(graphs.size());
1336   const auto& optimizer_opts =
1337       options_.config.graph_options().optimizer_options();
1338 
1339   int graph_def_version = graphs.begin()->second->versions().producer();
1340 
1341   const auto* session_metadata =
1342       options_.config.experimental().has_session_metadata()
1343           ? &options_.config.experimental().session_metadata()
1344           : nullptr;
1345   func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
1346       device_mgr_.get(), options_.env, &options_.config, graph_def_version,
1347       func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first,
1348       /*parent=*/nullptr, session_metadata,
1349       Rendezvous::Factory{
1350           [](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
1351             *r = new IntraProcessRendezvous(device_mgr);
1352             return Status::OK();
1353           }}));
1354 
1355   GraphOptimizer optimizer(optimizer_opts);
1356   for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
1357     const string& partition_name = iter->first;
1358     std::unique_ptr<Graph>& partition_graph = iter->second;
1359 
1360     Device* device;
1361     TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device));
1362 
1363     ek->items.resize(ek->items.size() + 1);
1364     auto* item = &(ek->items.back());
1365     auto lib = func_info->proc_flr->GetFLR(partition_name);
1366     if (lib == nullptr) {
1367       return errors::Internal("Could not find device: ", partition_name);
1368     }
1369     item->flib = lib;
1370 
1371     LocalExecutorParams params;
1372     params.device = device;
1373     params.session_metadata = session_metadata;
1374     params.function_library = lib;
1375     auto opseg = device->op_segment();
1376     params.create_kernel =
1377         [this, lib, opseg](const std::shared_ptr<const NodeProperties>& props,
1378                            OpKernel** kernel) {
1379           // NOTE(mrry): We must not share function kernels (implemented
1380           // using `CallOp`) between subgraphs, because `CallOp::handle_`
1381           // is tied to a particular subgraph. Even if the function itself
1382           // is stateful, the `CallOp` that invokes it is not.
1383           if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) {
1384             return lib->CreateKernel(props, kernel);
1385           }
1386           auto create_fn = [lib, &props](OpKernel** kernel) {
1387             return lib->CreateKernel(props, kernel);
1388           };
1389           // Kernels created for subgraph nodes need to be cached.  On
1390           // cache miss, create_fn() is invoked to create a kernel based
1391           // on the function library here + global op registry.
1392           return opseg->FindOrCreate(session_handle_, props->node_def.name(),
1393                                      kernel, create_fn);
1394         };
1395     params.delete_kernel = [lib](OpKernel* kernel) {
1396       if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string()))
1397         delete kernel;
1398     };
1399 
1400     optimizer.Optimize(lib, options_.env, device, &partition_graph,
1401                        /*shape_map=*/nullptr);
1402 
1403     // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
1404     const DebugOptions& debug_options =
1405         options.callable_options.run_options().debug_options();
1406     if (!debug_options.debug_tensor_watch_opts().empty()) {
1407       TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
1408           debug_options, partition_graph.get(), params.device));
1409     }
1410 
1411     TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
1412                                          device->name(),
1413                                          partition_graph.get()));
1414 
1415     item->executor = nullptr;
1416     item->device = device;
1417     auto executor_type = options_.config.experimental().executor_type();
1418     TF_RETURN_IF_ERROR(
1419         NewExecutor(executor_type, params, *partition_graph, &item->executor));
1420     if (!options_.config.experimental().disable_output_partition_graphs() ||
1421         options_.config.graph_options().build_cost_model() > 0) {
1422       item->graph = std::move(partition_graph);
1423     }
1424   }
1425 
1426   // Cache the mapping from input/output names to graph elements to
1427   // avoid recomputing it every time.
1428   if (!run_state_args->is_partial_run) {
1429     // For regular `Run()`, we use the function calling convention, and so
1430     // maintain a mapping from input/output names to
1431     // argument/return-value ordinal index.
1432     for (int i = 0; i < callable_options.feed().size(); ++i) {
1433       const string& input = callable_options.feed(i);
1434       ek->input_name_to_index[input] = i;
1435     }
1436     for (int i = 0; i < callable_options.fetch().size(); ++i) {
1437       const string& output = callable_options.fetch(i);
1438       ek->output_name_to_index[output] = i;
1439     }
1440   } else {
1441     // For `PRun()`, we use the rendezvous calling convention, and so
1442     // maintain a mapping from input/output names to rendezvous keys.
1443     //
1444     // We always use the first device as the device name portion of the
1445     // key, even if we're feeding another graph.
1446     for (int i = 0; i < callable_options.feed().size(); ++i) {
1447       const string& input = callable_options.feed(i);
1448       ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
1449           input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
1450     }
1451     for (int i = 0; i < callable_options.fetch().size(); ++i) {
1452       const string& output = callable_options.fetch(i);
1453       ek->output_name_to_rendezvous_key[output] =
1454           GetRendezvousKey(output, device_set_.client_device()->attributes(),
1455                            FrameAndIter(0, 0));
1456     }
1457   }
1458 
1459   *out_executors_and_keys = std::move(ek);
1460   *out_func_info = std::move(func_info);
1461   return Status::OK();
1462 }
1463 
1464 Status DirectSession::GetOrCreateExecutors(
1465     gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
1466     gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys,
1467     RunStateArgs* run_state_args) {
1468   int64 handle_name_counter_value = -1;
1469   if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
1470     handle_name_counter_value = handle_name_counter_.fetch_add(1);
1471   }
1472 
1473   string debug_tensor_watches_summary;
1474   if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
1475     debug_tensor_watches_summary = SummarizeDebugTensorWatches(
1476         run_state_args->debug_options.debug_tensor_watch_opts());
1477   }
1478 
1479   // Fast lookup path, no sorting.
1480   const string key = strings::StrCat(
1481       absl::StrJoin(inputs, ","), "->", absl::StrJoin(outputs, ","), "/",
1482       absl::StrJoin(target_nodes, ","), "/", run_state_args->is_partial_run,
1483       "/", debug_tensor_watches_summary);
1484   // Set the handle, if it's needed to log memory or for partial run.
1485   if (handle_name_counter_value >= 0) {
1486     run_state_args->handle =
1487         strings::StrCat(key, ";", handle_name_counter_value);
1488   }
1489 
1490   // See if we already have the executors for this run.
1491   {
1492     mutex_lock l(executor_lock_);  // could use reader lock
1493     auto it = executors_.find(key);
1494     if (it != executors_.end()) {
1495       *executors_and_keys = it->second.get();
1496       return Status::OK();
1497     }
1498   }
1499 
1500   // Slow lookup path, the unsorted key missed the cache.
1501   // Sort the inputs and outputs, and look up with the sorted key in case an
1502   // earlier call used a different order of inputs and outputs.
1503   //
1504   // We could consider some other signature instead of sorting that
1505   // preserves the same property to avoid the sort in the future.
1506   std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
1507   std::sort(inputs_sorted.begin(), inputs_sorted.end());
1508   std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
1509   std::sort(outputs_sorted.begin(), outputs_sorted.end());
1510   std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
1511   std::sort(tn_sorted.begin(), tn_sorted.end());
1512 
1513   const string sorted_key = strings::StrCat(
1514       absl::StrJoin(inputs_sorted, ","), "->",
1515       absl::StrJoin(outputs_sorted, ","), "/", absl::StrJoin(tn_sorted, ","),
1516       "/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
1517   // Set the handle, if it's needed to log memory or for partial run.
1518   if (handle_name_counter_value >= 0) {
1519     run_state_args->handle =
1520         strings::StrCat(sorted_key, ";", handle_name_counter_value);
1521   }
1522 
1523   // See if we already have the executors for this run.
1524   {
1525     mutex_lock l(executor_lock_);
1526     auto it = executors_.find(sorted_key);
1527     if (it != executors_.end()) {
1528       *executors_and_keys = it->second.get();
1529       return Status::OK();
1530     }
1531   }
1532 
1533   // Nothing found, so create the executors and store in the cache.
1534   // The executor_lock_ is intentionally released while executors are
1535   // being created.
1536   CallableOptions callable_options;
1537   callable_options.mutable_feed()->Reserve(inputs_sorted.size());
1538   for (const string& input : inputs_sorted) {
1539     callable_options.add_feed(input);
1540   }
1541   callable_options.mutable_fetch()->Reserve(outputs_sorted.size());
1542   for (const string& output : outputs_sorted) {
1543     callable_options.add_fetch(output);
1544   }
1545   callable_options.mutable_target()->Reserve(tn_sorted.size());
1546   for (const string& target : tn_sorted) {
1547     callable_options.add_target(target);
1548   }
1549   *callable_options.mutable_run_options()->mutable_debug_options() =
1550       run_state_args->debug_options;
1551   callable_options.mutable_run_options()
1552       ->mutable_experimental()
1553       ->set_collective_graph_key(run_state_args->collective_graph_key);
1554   std::unique_ptr<ExecutorsAndKeys> ek;
1555   std::unique_ptr<FunctionInfo> func_info;
1556   TF_RETURN_IF_ERROR(
1557       CreateExecutors(callable_options, &ek, &func_info, run_state_args));
1558 
1559   // Reacquire the lock, try to insert into the map.
1560   mutex_lock l(executor_lock_);
1561 
1562   // Another thread may have created the entry before us, in which case we will
1563   // reuse the already created one.
1564   auto insert_result = executors_.emplace(
1565       sorted_key, std::shared_ptr<ExecutorsAndKeys>(std::move(ek)));
1566   if (insert_result.second) {
1567     functions_.push_back(std::move(func_info));
1568   }
1569 
1570   // Insert the value under the original key, so the fast path lookup will work
1571   // if the user uses the same order of inputs, outputs, and targets again.
1572   executors_.emplace(key, insert_result.first->second);
1573   *executors_and_keys = insert_result.first->second.get();
1574 
1575   return Status::OK();
1576 }
1577 
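// Key-format sketch (illustrative only): for a non-partial run with inputs
// {"x:0"}, outputs {"y:0"}, no targets, and no debug watches, the fast-path
// key built above would look roughly like
//   "x:0->y:0//0/"
// and the sorted key has the same shape with feeds/fetches/targets sorted.
// A later call that lists the same names in a different order misses the
// fast path but hits the sorted-key entry inserted by this function.
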
1578 Status DirectSession::CreateGraphs(
1579     const BuildGraphOptions& subgraph_options,
1580     std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
1581     std::unique_ptr<FunctionLibraryDefinition>* flib_def,
1582     RunStateArgs* run_state_args, DataTypeVector* input_types,
1583     DataTypeVector* output_types, int64* collective_graph_key) {
1584   mutex_lock l(graph_state_lock_);
1585   if (finalized_) {
1586     return errors::FailedPrecondition("Session has been finalized.");
1587   }
1588 
1589   std::unique_ptr<ClientGraph> client_graph;
1590 
1591   std::unique_ptr<GraphExecutionState> temp_exec_state_holder;
1592   GraphExecutionState* execution_state = nullptr;
1593   if (options_.config.graph_options().place_pruned_graph()) {
1594     // Because we are placing pruned graphs, we need to create a
1595     // new GraphExecutionState for every new unseen graph,
1596     // and then place it.
1597     GraphExecutionStateOptions prune_options;
1598     prune_options.device_set = &device_set_;
1599     prune_options.session_options = &options_;
1600     prune_options.stateful_placements = stateful_placements_;
1601     prune_options.session_handle = session_handle_;
1602     TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph(
1603         *execution_state_, prune_options, subgraph_options,
1604         &temp_exec_state_holder, &client_graph));
1605     execution_state = temp_exec_state_holder.get();
1606   } else {
1607     execution_state = execution_state_.get();
1608     TF_RETURN_IF_ERROR(
1609         execution_state->BuildGraph(subgraph_options, &client_graph));
1610   }
1611   *collective_graph_key = client_graph->collective_graph_key;
1612 
1613   if (subgraph_options.callable_options.feed_size() !=
1614       client_graph->feed_types.size()) {
1615     return errors::Internal(
1616         "Graph pruning failed: requested number of feed endpoints = ",
1617         subgraph_options.callable_options.feed_size(),
1618         " versus number of pruned feed endpoints = ",
1619         client_graph->feed_types.size());
1620   }
1621   if (subgraph_options.callable_options.fetch_size() !=
1622       client_graph->fetch_types.size()) {
1623     return errors::Internal(
1624         "Graph pruning failed: requested number of fetch endpoints = ",
1625         subgraph_options.callable_options.fetch_size(),
1626         " versus number of pruned fetch endpoints = ",
1627         client_graph->fetch_types.size());
1628   }
1629 
1630   auto current_stateful_placements = execution_state->GetStatefulPlacements();
1631   // Update our current state based on the execution_state's
1632   // placements.  If there are any mismatches for a node,
1633   // we should fail, as this should never happen.
1634   for (const auto& placement_pair : current_stateful_placements) {
1635     const string& node_name = placement_pair.first;
1636     const string& placement = placement_pair.second;
1637     auto iter = stateful_placements_.find(node_name);
1638     if (iter == stateful_placements_.end()) {
1639       stateful_placements_.insert(std::make_pair(node_name, placement));
1640     } else if (iter->second != placement) {
1641       return errors::Internal(
1642           "Stateful placement mismatch. "
1643           "Current assignment of ",
1644           node_name, " to ", iter->second, " does not match ", placement);
1645     }
1646   }
1647 
1648   stateful_placements_ = execution_state->GetStatefulPlacements();
1649 
1650   // Remember the graph in run state if this is a partial run.
1651   if (run_state_args->is_partial_run) {
1652     run_state_args->graph.reset(new Graph(flib_def_.get()));
1653     CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
1654   }
1655 
1656   // Partition the graph across devices.
1657   PartitionOptions popts;
1658   popts.node_to_loc = [](const Node* node) {
1659     return node->assigned_device_name();
1660   };
1661   popts.new_name = [this](const string& prefix) {
1662     return strings::StrCat(prefix, "/_", edge_name_counter_.fetch_add(1));
1663   };
1664   popts.get_incarnation = [](const string& name) {
1665     // The direct session does not have changing incarnation numbers.
1666     // Just return '1'.
1667     return 1;
1668   };
1669   popts.flib_def = &client_graph->graph.flib_def();
1670   popts.control_flow_added = false;
1671 
1672   std::unordered_map<string, GraphDef> partitions;
1673   TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions));
1674 
1675   std::vector<string> device_names;
1676   for (auto device : devices_) {
1677     // Extract the LocalName from the device.
1678     device_names.push_back(DeviceNameUtils::LocalName(device->name()));
1679   }
1680 
1681   // Check for valid partitions.
1682   for (const auto& partition : partitions) {
1683     const string local_partition_name =
1684         DeviceNameUtils::LocalName(partition.first);
1685     if (std::count(device_names.begin(), device_names.end(),
1686                    local_partition_name) == 0) {
1687       return errors::InvalidArgument(
1688           "Creating a partition for ", local_partition_name,
1689           " which doesn't exist in the list of available devices. Available "
1690           "devices: ",
1691           absl::StrJoin(device_names, ","));
1692     }
1693   }
1694 
1695   for (auto& partition : partitions) {
1696     std::unique_ptr<Graph> device_graph(
1697         new Graph(client_graph->flib_def.get()));
1698     device_graph->SetConstructionContext(ConstructionContext::kDirectSession);
1699     GraphConstructorOptions device_opts;
1700     // Partitioned graphs contain internal ops (e.g., send/recv); allow them here.
1701     device_opts.allow_internal_ops = true;
1702     device_opts.expect_device_spec = true;
1703     TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
1704         device_opts, std::move(partition.second), device_graph.get()));
1705     outputs->emplace(partition.first, std::move(device_graph));
1706   }
1707 
1708   GraphOptimizationPassOptions optimization_options;
1709   optimization_options.session_options = &options_;
1710   optimization_options.flib_def = client_graph->flib_def.get();
1711   optimization_options.partition_graphs = outputs;
1712   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
1713       OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
1714 
1715   Status s;
1716   for (auto& partition : *outputs) {
1717     const string& partition_name = partition.first;
1718     std::unique_ptr<Graph>* graph = &partition.second;
1719 
1720     VLOG(2) << "Created " << DebugString(graph->get()) << " for "
1721             << partition_name;
1722 
1723     // Give the device an opportunity to rewrite its subgraph.
1724     Device* d;
1725     s = device_mgr_->LookupDevice(partition_name, &d);
1726     if (!s.ok()) break;
1727     s = d->MaybeRewriteGraph(graph);
1728     if (!s.ok()) {
1729       break;
1730     }
1731   }
1732   *flib_def = std::move(client_graph->flib_def);
1733   std::swap(*input_types, client_graph->feed_types);
1734   std::swap(*output_types, client_graph->fetch_types);
1735   return s;
1736 }
1737 
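// Note (illustrative): the partition map produced above is keyed by assigned
// device name, so on a single-machine session the keys typically look like
// "/job:localhost/replica:0/task:0/device:CPU:0" (plus one per visible GPU);
// these are the names that CreateExecutors() later resolves with
// LookupDevice().
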
1738 ::tensorflow::Status DirectSession::ListDevices(
1739     std::vector<DeviceAttributes>* response) {
1740   response->clear();
1741   response->reserve(devices_.size());
1742   for (Device* d : devices_) {
1743     const DeviceAttributes& attrs = d->attributes();
1744     response->emplace_back(attrs);
1745   }
1746   return ::tensorflow::Status::OK();
1747 }
1748 
1749 ::tensorflow::Status DirectSession::Reset(
1750     const std::vector<string>& containers) {
1751   device_mgr_->ClearContainers(containers);
1752   return ::tensorflow::Status::OK();
1753 }
1754 
1755 ::tensorflow::Status DirectSession::Close() {
1756   cancellation_manager_->StartCancel();
1757   {
1758     mutex_lock l(closed_lock_);
1759     if (closed_) return ::tensorflow::Status::OK();
1760     closed_ = true;
1761   }
1762   if (factory_ != nullptr) factory_->Deregister(this);
1763   return ::tensorflow::Status::OK();
1764 }
1765 
1766 DirectSession::RunState::RunState(int64 step_id,
1767                                   const std::vector<Device*>* devices)
1768     : step_container(step_id, [devices, step_id](const string& name) {
1769         for (auto d : *devices) {
1770           if (!d->resource_manager()->Cleanup(name).ok()) {
1771             // Do nothing...
1772           }
1773           ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
1774           if (sam) sam->Cleanup(step_id);
1775         }
1776       }) {}
1777 
1778 DirectSession::PartialRunState::PartialRunState(
1779     const std::vector<string>& pending_input_names,
1780     const std::vector<string>& pending_output_names, int64 step_id,
1781     const std::vector<Device*>* devices)
1782     : RunState(step_id, devices) {
1783   // Initially all the feeds and fetches are pending.
1784   for (auto& name : pending_input_names) {
1785     pending_inputs[name] = false;
1786   }
1787   for (auto& name : pending_output_names) {
1788     pending_outputs[name] = false;
1789   }
1790 }
1791 
1792 DirectSession::PartialRunState::~PartialRunState() {
1793   if (rendez != nullptr) {
1794     rendez->StartAbort(errors::Cancelled("PRun cancellation"));
1795     executors_done.WaitForNotification();
1796   }
1797 }
1798 
1799 bool DirectSession::PartialRunState::PendingDone() const {
1800   for (const auto& it : pending_inputs) {
1801     if (!it.second) return false;
1802   }
1803   for (const auto& it : pending_outputs) {
1804     if (!it.second) return false;
1805   }
1806   return true;
1807 }
1808 
1809 void DirectSession::WaitForNotification(Notification* n, RunState* run_state,
1810                                         CancellationManager* cm,
1811                                         int64 timeout_in_ms) {
1812   const Status status = WaitForNotification(n, timeout_in_ms);
1813   if (!status.ok()) {
1814     {
1815       mutex_lock l(run_state->mu);
1816       run_state->status.Update(status);
1817     }
1818     cm->StartCancel();
1819     // We must wait for the executors to complete, because they have borrowed
1820     // references to `cm` and other per-step state. After this notification, it
1821     // is safe to clean up the step.
1822     n->WaitForNotification();
1823   }
1824 }
1825 
1826 ::tensorflow::Status DirectSession::WaitForNotification(
1827     Notification* notification, int64 timeout_in_ms) {
1828   if (timeout_in_ms > 0) {
1829     const int64 timeout_in_us = timeout_in_ms * 1000;
1830     const bool notified =
1831         WaitForNotificationWithTimeout(notification, timeout_in_us);
1832     if (!notified) {
1833       return Status(error::DEADLINE_EXCEEDED,
1834                     "Timed out waiting for notification");
1835     }
1836   } else {
1837     notification->WaitForNotification();
1838   }
1839   return Status::OK();
1840 }
1841 
1842 Status DirectSession::MakeCallable(const CallableOptions& callable_options,
1843                                    CallableHandle* out_handle) {
1844   TF_RETURN_IF_ERROR(CheckNotClosed());
1845   TF_RETURN_IF_ERROR(CheckGraphCreated("MakeCallable()"));
1846 
1847   std::unique_ptr<ExecutorsAndKeys> ek;
1848   std::unique_ptr<FunctionInfo> func_info;
1849   RunStateArgs run_state_args(callable_options.run_options().debug_options());
1850   TF_RETURN_IF_ERROR(
1851       CreateExecutors(callable_options, &ek, &func_info, &run_state_args));
1852   {
1853     mutex_lock l(callables_lock_);
1854     *out_handle = next_callable_handle_++;
1855     callables_[*out_handle] = {std::move(ek), std::move(func_info)};
1856   }
1857   return Status::OK();
1858 }
1859 
1860 class DirectSession::RunCallableCallFrame : public CallFrameInterface {
1861  public:
1862   RunCallableCallFrame(DirectSession* session,
1863                        ExecutorsAndKeys* executors_and_keys,
1864                        const std::vector<Tensor>* feed_tensors,
1865                        std::vector<Tensor>* fetch_tensors)
1866       : session_(session),
1867         executors_and_keys_(executors_and_keys),
1868         feed_tensors_(feed_tensors),
1869         fetch_tensors_(fetch_tensors) {}
1870 
1871   size_t num_args() const override {
1872     return executors_and_keys_->input_types.size();
1873   }
1874   size_t num_retvals() const override {
1875     return executors_and_keys_->output_types.size();
1876   }
1877 
1878   Status GetArg(int index, const Tensor** val) override {
1879     if (TF_PREDICT_FALSE(index >= feed_tensors_->size())) {
1880       return errors::Internal("Args index out of bounds: ", index);
1881     } else {
1882       *val = &(*feed_tensors_)[index];
1883     }
1884     return Status::OK();
1885   }
1886 
1887   Status SetRetval(int index, const Tensor& val) override {
1888     if (index >= fetch_tensors_->size()) {
1889       return errors::Internal("RetVal index out of bounds: ", index);
1890     }
1891     (*fetch_tensors_)[index] = val;
1892     return Status::OK();
1893   }
1894 
1895  private:
1896   DirectSession* const session_;                   // Not owned.
1897   ExecutorsAndKeys* const executors_and_keys_;     // Not owned.
1898   const std::vector<Tensor>* const feed_tensors_;  // Not owned.
1899   std::vector<Tensor>* const fetch_tensors_;       // Not owned.
1900 };
1901 
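// Note (illustrative): because regular Run()/RunCallable() use the function
// calling convention (see CreateExecutors() above), the i-th feed tensor is
// consumed by the _Arg node with index i via GetArg(), and the _RetVal node
// with index i stores the i-th fetch via SetRetval().
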
1902 ::tensorflow::Status DirectSession::RunCallable(
1903     CallableHandle handle, const std::vector<Tensor>& feed_tensors,
1904     std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata) {
1905   return RunCallable(handle, feed_tensors, fetch_tensors, run_metadata,
1906                      thread::ThreadPoolOptions());
1907 }
1908 
1909 ::tensorflow::Status DirectSession::RunCallable(
1910     CallableHandle handle, const std::vector<Tensor>& feed_tensors,
1911     std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata,
1912     const thread::ThreadPoolOptions& threadpool_options) {
1913   TF_RETURN_IF_ERROR(CheckNotClosed());
1914   TF_RETURN_IF_ERROR(CheckGraphCreated("RunCallable()"));
1915   direct_session_runs->GetCell()->IncrementBy(1);
1916 
1917   // Check if we already have an executor for these arguments.
1918   std::shared_ptr<ExecutorsAndKeys> executors_and_keys;
1919   const int64 step_id = step_id_counter_.fetch_add(1);
1920 
1921   {
1922     tf_shared_lock l(callables_lock_);
1923     if (handle >= next_callable_handle_) {
1924       return errors::InvalidArgument("No such callable handle: ", handle);
1925     }
1926     executors_and_keys = callables_[handle].executors_and_keys;
1927   }
1928 
1929   if (!executors_and_keys) {
1930     return errors::InvalidArgument(
1931         "Attempted to run callable after handle was released: ", handle);
1932   }
1933 
1934   // NOTE(mrry): Debug options are not currently supported in the
1935   // callable interface.
1936   DebugOptions debug_options;
1937   RunStateArgs run_state_args(debug_options);
1938 
1939   // Configure a call frame for the step, which we use to feed and
1940   // fetch values to and from the executors.
1941   if (feed_tensors.size() != executors_and_keys->input_types.size()) {
1942     return errors::InvalidArgument(
1943         "Expected ", executors_and_keys->input_types.size(),
1944         " feed tensors, but got ", feed_tensors.size());
1945   }
1946   if (fetch_tensors != nullptr) {
1947     fetch_tensors->resize(executors_and_keys->output_types.size());
1948   } else if (!executors_and_keys->output_types.empty()) {
1949     return errors::InvalidArgument(
1950         "`fetch_tensors` must be provided when the callable has one or more "
1951         "outputs.");
1952   }
1953 
1954   size_t input_size = 0;
1955   bool any_resource_feeds = false;
1956   for (auto& tensor : feed_tensors) {
1957     input_size += tensor.AllocatedBytes();
1958     any_resource_feeds = any_resource_feeds || tensor.dtype() == DT_RESOURCE;
1959   }
1960   metrics::RecordGraphInputTensors(input_size);
1961 
1962   std::unique_ptr<std::vector<Tensor>> converted_feed_tensors;
1963   const std::vector<Tensor>* actual_feed_tensors;
1964 
1965   if (TF_PREDICT_FALSE(any_resource_feeds)) {
1966     converted_feed_tensors = absl::make_unique<std::vector<Tensor>>();
1967     converted_feed_tensors->reserve(feed_tensors.size());
1968     for (const Tensor& t : feed_tensors) {
1969       if (t.dtype() == DT_RESOURCE) {
1970         converted_feed_tensors->emplace_back();
1971         Tensor* tensor_from_handle = &converted_feed_tensors->back();
1972         TF_RETURN_IF_ERROR(ResourceHandleToInputTensor(t, tensor_from_handle));
1973       } else {
1974         converted_feed_tensors->emplace_back(t);
1975       }
1976     }
1977     actual_feed_tensors = converted_feed_tensors.get();
1978   } else {
1979     actual_feed_tensors = &feed_tensors;
1980   }
1981 
1982   // A specialized CallFrame implementation that takes advantage of the
1983   // optimized RunCallable interface.
1984   RunCallableCallFrame call_frame(this, executors_and_keys.get(),
1985                                   actual_feed_tensors, fetch_tensors);
1986 
1987   if (LogMemory::IsEnabled()) {
1988     LogMemory::RecordStep(step_id, run_state_args.handle);
1989   }
1990 
1991   TF_RETURN_IF_ERROR(RunInternal(
1992       step_id, executors_and_keys->callable_options.run_options(), &call_frame,
1993       executors_and_keys.get(), run_metadata, threadpool_options));
1994 
1995   if (fetch_tensors != nullptr) {
1996     size_t output_size = 0;
1997     for (auto& tensor : *fetch_tensors) {
1998       output_size += tensor.AllocatedBytes();
1999     }
2000     metrics::RecordGraphOutputTensors(output_size);
2001   }
2002 
2003   return Status::OK();
2004 }
2005 
2006 ::tensorflow::Status DirectSession::ReleaseCallable(CallableHandle handle) {
2007   mutex_lock l(callables_lock_);
2008   if (handle >= next_callable_handle_) {
2009     return errors::InvalidArgument("No such callable handle: ", handle);
2010   }
2011   callables_.erase(handle);
2012   return Status::OK();
2013 }
2014 
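// Usage sketch (illustrative only): the callable lifecycle implemented by
// MakeCallable()/RunCallable()/ReleaseCallable() above, assuming a graph with
// feed "x:0" and fetch "y:0"; x_val is a hypothetical input tensor:
//
//   CallableOptions opts;
//   opts.add_feed("x:0");
//   opts.add_fetch("y:0");
//   Session::CallableHandle handle;
//   TF_CHECK_OK(session->MakeCallable(opts, &handle));
//   std::vector<Tensor> fetches;
//   TF_CHECK_OK(session->RunCallable(handle, {x_val}, &fetches, nullptr));
//   TF_CHECK_OK(session->ReleaseCallable(handle));
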
2015 Status DirectSession::Finalize() {
2016   mutex_lock l(graph_state_lock_);
2017   if (finalized_) {
2018     return errors::FailedPrecondition("Session already finalized.");
2019   }
2020   if (!graph_created_) {
2021     return errors::FailedPrecondition("Session not yet created.");
2022   }
2023   execution_state_.reset();
2024   flib_def_.reset();
2025   finalized_ = true;
2026   return Status::OK();
2027 }
2028 
2029 DirectSession::Callable::~Callable() {
2030   // We must delete the fields in this order, because the destructor
2031   // of `executors_and_keys` will call into an object owned by
2032   // `function_info` (in particular, when deleting a kernel, it relies
2033   // on the `FunctionLibraryRuntime` to know if the kernel is stateful
2034   // or not).
2035   executors_and_keys.reset();
2036   function_info.reset();
2037 }
2038 
2039 }  // namespace tensorflow
2040