1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/direct_session.h"
17 
18 #include <algorithm>
19 #include <atomic>
20 #include <string>
21 #include <vector>
22 
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/time/time.h"
25 #include "absl/types/optional.h"
26 #include "tensorflow/core/common_runtime/collective_executor_mgr.h"
27 #include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
28 #include "tensorflow/core/common_runtime/constant_folding.h"
29 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
30 #include "tensorflow/core/common_runtime/device_factory.h"
31 #include "tensorflow/core/common_runtime/device_resolver_local.h"
32 #include "tensorflow/core/common_runtime/executor.h"
33 #include "tensorflow/core/common_runtime/executor_factory.h"
34 #include "tensorflow/core/common_runtime/function.h"
35 #include "tensorflow/core/common_runtime/graph_constructor.h"
36 #include "tensorflow/core/common_runtime/graph_optimizer.h"
37 #include "tensorflow/core/common_runtime/local_session_selection.h"
38 #include "tensorflow/core/common_runtime/memory_types.h"
39 #include "tensorflow/core/common_runtime/optimization_registry.h"
40 #include "tensorflow/core/common_runtime/process_util.h"
41 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
42 #include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
43 #include "tensorflow/core/common_runtime/step_stats_collector.h"
44 #include "tensorflow/core/framework/function.h"
45 #include "tensorflow/core/framework/graph.pb.h"
46 #include "tensorflow/core/framework/graph_def_util.h"
47 #include "tensorflow/core/framework/log_memory.h"
48 #include "tensorflow/core/framework/logging.h"
49 #include "tensorflow/core/framework/metrics.h"
50 #include "tensorflow/core/framework/node_def.pb.h"
51 #include "tensorflow/core/framework/run_handler.h"
52 #include "tensorflow/core/framework/tensor.h"
53 #include "tensorflow/core/framework/versions.pb.h"
54 #include "tensorflow/core/graph/algorithm.h"
55 #include "tensorflow/core/graph/graph.h"
56 #include "tensorflow/core/graph/graph_partition.h"
57 #include "tensorflow/core/graph/subgraph.h"
58 #include "tensorflow/core/graph/tensor_id.h"
59 #include "tensorflow/core/lib/core/errors.h"
60 #include "tensorflow/core/lib/core/notification.h"
61 #include "tensorflow/core/lib/core/refcount.h"
62 #include "tensorflow/core/lib/core/status.h"
63 #include "tensorflow/core/lib/core/threadpool.h"
64 #include "tensorflow/core/lib/core/threadpool_options.h"
65 #include "tensorflow/core/lib/gtl/array_slice.h"
66 #include "tensorflow/core/lib/monitoring/counter.h"
67 #include "tensorflow/core/lib/random/random.h"
68 #include "tensorflow/core/lib/strings/numbers.h"
69 #include "tensorflow/core/lib/strings/str_util.h"
70 #include "tensorflow/core/lib/strings/strcat.h"
71 #include "tensorflow/core/nccl/collective_communicator.h"
72 #include "tensorflow/core/platform/byte_order.h"
73 #include "tensorflow/core/platform/cpu_info.h"
74 #include "tensorflow/core/platform/logging.h"
75 #include "tensorflow/core/platform/mutex.h"
76 #include "tensorflow/core/platform/tracing.h"
77 #include "tensorflow/core/platform/types.h"
78 #include "tensorflow/core/profiler/lib/connected_traceme.h"
79 #include "tensorflow/core/profiler/lib/device_profiler_session.h"
80 #include "tensorflow/core/profiler/lib/traceme_encode.h"
81 #include "tensorflow/core/protobuf/config.pb.h"
82 #include "tensorflow/core/util/device_name_utils.h"
83 #include "tensorflow/core/util/env_var.h"
84 
85 namespace tensorflow {
86 
87 namespace {
88 
89 auto* direct_session_runs = monitoring::Counter<0>::New(
90     "/tensorflow/core/direct_session_runs",
91     "The number of times DirectSession::Run() has been called.");
92 
93 Status NewThreadPoolFromThreadPoolOptions(
94     const SessionOptions& options,
95     const ThreadPoolOptionProto& thread_pool_options, int pool_number,
96     thread::ThreadPool** pool, bool* owned) {
97   int32_t num_threads = thread_pool_options.num_threads();
98   if (num_threads == 0) {
99     num_threads = NumInterOpThreadsFromSessionOptions(options);
100   }
101   const string& name = thread_pool_options.global_name();
102   if (name.empty()) {
103     // Session-local threadpool.
104     VLOG(1) << "Direct session inter op parallelism threads for pool "
105             << pool_number << ": " << num_threads;
106     *pool = new thread::ThreadPool(
107         options.env, ThreadOptions(), strings::StrCat("Compute", pool_number),
108         num_threads, !options.config.experimental().disable_thread_spinning(),
109         /*allocator=*/nullptr);
110     *owned = true;
111     return OkStatus();
112   }
113 
114   // Global, named threadpool.
115   typedef std::pair<int32, thread::ThreadPool*> MapValue;
116   static std::map<string, MapValue>* global_pool_map =
117       new std::map<string, MapValue>;
118   static mutex* mu = new mutex();
119   mutex_lock l(*mu);
120   MapValue* mvalue = &(*global_pool_map)[name];
121   if (mvalue->second == nullptr) {
122     mvalue->first = thread_pool_options.num_threads();
123     mvalue->second = new thread::ThreadPool(
124         options.env, ThreadOptions(), strings::StrCat("Compute", pool_number),
125         num_threads, !options.config.experimental().disable_thread_spinning(),
126         /*allocator=*/nullptr);
127   } else {
128     if (mvalue->first != thread_pool_options.num_threads()) {
129       return errors::InvalidArgument(
130           "Pool ", name,
131           " configured previously with num_threads=", mvalue->first,
132           "; cannot re-configure with num_threads=",
133           thread_pool_options.num_threads());
134     }
135   }
136   *owned = false;
137   *pool = mvalue->second;
138   return OkStatus();
139 }
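// For reference, a minimal client-side sketch of a ConfigProto that reaches the
// named-pool branch above (the field names mirror the accessors used in this
// function; treat the snippet as illustrative rather than verified):
//
//   SessionOptions opts;
//   ThreadPoolOptionProto* tp = opts.config.add_session_inter_op_thread_pool();
//   tp->set_num_threads(8);              // 0 falls back to the session default.
//   tp->set_global_name("shared_pool");  // Empty name => session-local pool.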
140 
141 thread::ThreadPool* GlobalThreadPool(const SessionOptions& options) {
142   static thread::ThreadPool* const thread_pool =
143       NewThreadPoolFromSessionOptions(options);
144   return thread_pool;
145 }
146 
147 // TODO(vrv): Figure out how to unify the many different functions
148 // that generate RendezvousKey, since many of them have to be
149 // consistent with each other.
150 string GetRendezvousKey(const string& tensor_name,
151                         const DeviceAttributes& device_info,
152                         const FrameAndIter& frame_iter) {
153   return strings::StrCat(device_info.name(), ";",
154                          strings::FpToString(device_info.incarnation()), ";",
155                          device_info.name(), ";", tensor_name, ";",
156                          frame_iter.frame_id, ":", frame_iter.iter_id);
157 }
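// The resulting key has the shape (values illustrative):
//   <device>;<incarnation hex>;<device>;<tensor_name>;<frame_id>:<iter_id>
// e.g. "/job:localhost/replica:0/task:0/device:CPU:0;f0f0f0f0f0f0f0f0;/job:localhost/replica:0/task:0/device:CPU:0;x:0;0:0"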
158 
159 }  // namespace
160 
161 class DirectSessionFactory : public SessionFactory {
162  public:
163   DirectSessionFactory() {}
164 
165   bool AcceptsOptions(const SessionOptions& options) override {
166     return options.target.empty() &&
167            !options.config.experimental().use_tfrt() &&
168            GetDefaultLocalSessionImpl() == LocalSessionImpl::kDirectSession;
169   }
170 
171   Status NewSession(const SessionOptions& options,
172                     Session** out_session) override {
173     const auto& experimental_config = options.config.experimental();
174     if (experimental_config.has_session_metadata()) {
175       if (experimental_config.session_metadata().version() < 0) {
176         return errors::InvalidArgument(
177             "Session version shouldn't be negative: ",
178             experimental_config.session_metadata().DebugString());
179       }
180       const string key = GetMetadataKey(experimental_config.session_metadata());
181       mutex_lock l(sessions_lock_);
182       if (!session_metadata_keys_.insert(key).second) {
183         return errors::InvalidArgument(
184             "A session with the same name and version has already been "
185             "created: ",
186             experimental_config.session_metadata().DebugString());
187       }
188     }
189 
190     // Must do this before the CPU allocator is created.
191     if (options.config.graph_options().build_cost_model() > 0) {
192       EnableCPUAllocatorFullStats();
193     }
194     std::vector<std::unique_ptr<Device>> devices;
195     TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
196         options, "/job:localhost/replica:0/task:0", &devices));
197 
198     DirectSession* session = new DirectSession(
199         options, new StaticDeviceMgr(std::move(devices)), this);
200     {
201       mutex_lock l(sessions_lock_);
202       sessions_.push_back(session);
203     }
204     *out_session = session;
205     return OkStatus();
206   }
207 
208   Status Reset(const SessionOptions& options,
209                const std::vector<string>& containers) override {
210     std::vector<DirectSession*> sessions_to_reset;
211     {
212       mutex_lock l(sessions_lock_);
213       // We create a copy to ensure that we don't have a deadlock when
214       // session->Close calls the DirectSessionFactory.Deregister, which
215       // acquires sessions_lock_.
216       std::swap(sessions_to_reset, sessions_);
217     }
218     Status s;
219     for (auto session : sessions_to_reset) {
220       s.Update(session->Reset(containers));
221     }
222     // TODO(suharshs): Change the Reset behavior of all SessionFactories so that
223     // it doesn't close the sessions?
224     for (auto session : sessions_to_reset) {
225       s.Update(session->Close());
226     }
227     return s;
228   }
229 
230   void Deregister(const DirectSession* session) {
231     mutex_lock l(sessions_lock_);
232     sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session),
233                     sessions_.end());
234     if (session->options().config.experimental().has_session_metadata()) {
235       session_metadata_keys_.erase(GetMetadataKey(
236           session->options().config.experimental().session_metadata()));
237     }
238   }
239 
240  private:
241   static string GetMetadataKey(const SessionMetadata& metadata) {
242     return absl::StrCat(metadata.name(), "/", metadata.version());
243   }
244 
245   mutex sessions_lock_;
246   std::vector<DirectSession*> sessions_ TF_GUARDED_BY(sessions_lock_);
247   absl::flat_hash_set<string> session_metadata_keys_
248       TF_GUARDED_BY(sessions_lock_);
249 };
250 
251 class DirectSessionRegistrar {
252  public:
253   DirectSessionRegistrar() {
254     SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
255   }
256 };
257 static DirectSessionRegistrar registrar;
258 
259 std::atomic_int_fast64_t DirectSession::step_id_counter_(1);
260 
261 static RunHandlerPool* GetOrCreateRunHandlerPool(
262     const SessionOptions& options) {
263   int num_inter_threads = 0;
264   int num_intra_threads = 0;
265   static const int env_num_inter_threads = NumInterOpThreadsFromEnvironment();
266   static const int env_num_intra_threads = NumIntraOpThreadsFromEnvironment();
267   if (env_num_inter_threads > 0) {
268     num_inter_threads = env_num_inter_threads;
269   }
270   if (env_num_intra_threads > 0) {
271     num_intra_threads = env_num_intra_threads;
272   }
273 
274   if (num_inter_threads == 0) {
275     if (options.config.session_inter_op_thread_pool_size() > 0) {
276       // Note that due to ShouldUseRunHandlerPool() we are guaranteed that
277       // run_options.inter_op_thread_pool() == 0.
278       num_inter_threads =
279           options.config.session_inter_op_thread_pool(0).num_threads();
280     }
281     if (num_inter_threads == 0) {
282       num_inter_threads = NumInterOpThreadsFromSessionOptions(options);
283     }
284   }
285 
286   if (num_intra_threads == 0) {
287     num_intra_threads = options.config.intra_op_parallelism_threads();
288     if (num_intra_threads == 0) {
289       num_intra_threads = port::MaxParallelism();
290     }
291   }
292 
293   static RunHandlerPool* pool = [&]() {
294     LOG(INFO) << "Creating run-handler pool with "
295                  "[num_inter_threads, num_intra_threads] as ["
296               << num_inter_threads << "," << num_intra_threads << "]";
297     return new RunHandlerPool(num_inter_threads, num_intra_threads);
298   }();
299   return pool;
300 }
301 
302 bool DirectSession::ShouldUseRunHandlerPool(
303     const RunOptions& run_options) const {
304   if (options_.config.use_per_session_threads()) return false;
305   if (options_.config.session_inter_op_thread_pool_size() > 0 &&
306       run_options.inter_op_thread_pool() > 0)
307     return false;
308   // Only use RunHandlerPool when:
309   // a. Single global thread pool is used for inter-op parallelism.
310   // b. When multiple inter_op_thread_pool(s) are created, use it only while
311   // running sessions on the default inter_op_thread_pool=0. Typically,
312   // servo-team uses inter_op_thread_pool > 0 for model loading.
313   // TODO(crk): Revisit whether we'd want to create one (static) RunHandlerPool
314   // per entry in session_inter_op_thread_pool() in the future.
315   return true;
316 }
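// A caller opts into the run-handler pool per step through RunOptions; a
// minimal sketch (the proto setters are assumed from the accessors used in
// RunInternal() below, and the surrounding variables are placeholders):
//
//   RunOptions run_options;
//   run_options.mutable_experimental()->set_use_run_handler_pool(true);
//   session->Run(run_options, inputs, output_names, target_nodes, &outputs,
//                &run_metadata);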
317 
318 DirectSession::DirectSession(const SessionOptions& options,
319                              const DeviceMgr* device_mgr,
320                              DirectSessionFactory* const factory)
321     : options_(options),
322       device_mgr_(device_mgr),
323       factory_(factory),
324       cancellation_manager_(new CancellationManager()),
325       operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
326   const int thread_pool_size =
327       options_.config.session_inter_op_thread_pool_size();
328   if (thread_pool_size > 0) {
329     for (int i = 0; i < thread_pool_size; ++i) {
330       thread::ThreadPool* pool = nullptr;
331       bool owned = false;
332       init_error_.Update(NewThreadPoolFromThreadPoolOptions(
333           options_, options_.config.session_inter_op_thread_pool(i), i, &pool,
334           &owned));
335       thread_pools_.emplace_back(pool, owned);
336     }
337   } else if (options_.config.use_per_session_threads()) {
338     thread_pools_.emplace_back(NewThreadPoolFromSessionOptions(options_),
339                                true /* owned */);
340   } else {
341     thread_pools_.emplace_back(GlobalThreadPool(options), false /* owned */);
342     // Run locally if environment value of TF_NUM_INTEROP_THREADS is negative
343     // and config.inter_op_parallelism_threads is unspecified or negative.
344     static const int env_num_threads = NumInterOpThreadsFromEnvironment();
345     if (options_.config.inter_op_parallelism_threads() < 0 ||
346         (options_.config.inter_op_parallelism_threads() == 0 &&
347          env_num_threads < 0)) {
348       run_in_caller_thread_ = true;
349     }
350   }
351   // The default value of sync_on_finish will be flipped soon and this
352   // environment variable will be removed as well.
353   const Status status =
354       ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
355   if (!status.ok()) {
356     LOG(ERROR) << status.error_message();
357   }
358   session_handle_ =
359       strings::StrCat("direct", strings::FpToString(random::New64()));
360   int devices_added = 0;
361   if (options.config.log_device_placement()) {
362     const string mapping_str = device_mgr_->DeviceMappingString();
363     string msg;
364     if (mapping_str.empty()) {
365       msg = "Device mapping: no known devices.";
366     } else {
367       msg = strings::StrCat("Device mapping:\n", mapping_str);
368     }
369     if (!logging::LogToListeners(msg)) {
370       LOG(INFO) << msg;
371     }
372   }
373   for (auto d : device_mgr_->ListDevices()) {
374     devices_.push_back(d);
375     device_set_.AddDevice(d);
376     d->op_segment()->AddHold(session_handle_);
377 
378     // The first device added is special: it is the 'client device' (a
379     // CPU device) from which we feed and fetch Tensors.
380     if (devices_added == 0) {
381       device_set_.set_client_device(d);
382     }
383     ++devices_added;
384   }
385 }
386 
387 DirectSession::~DirectSession() {
388   if (!closed_) Close().IgnoreError();
389   for (auto& it : partial_runs_) {
390     it.second.reset(nullptr);
391   }
392   for (auto& it : executors_) {
393     it.second.reset();
394   }
395   callables_.clear();
396   for (auto d : device_mgr_->ListDevices()) {
397     d->op_segment()->RemoveHold(session_handle_);
398   }
399   functions_.clear();
400   delete cancellation_manager_;
401   for (const auto& p_and_owned : thread_pools_) {
402     if (p_and_owned.second) delete p_and_owned.first;
403   }
404 
405   execution_state_.reset(nullptr);
406   flib_def_.reset(nullptr);
407 }
408 
409 Status DirectSession::Create(const GraphDef& graph) {
410   return Create(GraphDef(graph));
411 }
412 
413 Status DirectSession::Create(GraphDef&& graph) {
414   TF_RETURN_IF_ERROR(init_error_);
415   if (graph.node_size() > 0) {
416     mutex_lock l(graph_state_lock_);
417     if (graph_created_) {
418       return errors::AlreadyExists(
419           "A Graph has already been created for this session.");
420     }
421     return ExtendLocked(std::move(graph));
422   }
423   return OkStatus();
424 }
425 
426 Status DirectSession::Extend(const GraphDef& graph) {
427   return Extend(GraphDef(graph));
428 }
429 
430 Status DirectSession::Extend(GraphDef&& graph) {
431   TF_RETURN_IF_ERROR(CheckNotClosed());
432   mutex_lock l(graph_state_lock_);
433   return ExtendLocked(std::move(graph));
434 }
435 
436 Status DirectSession::ExtendLocked(GraphDef&& graph) {
437   if (finalized_) {
438     return errors::FailedPrecondition("Session has been finalized.");
439   }
440   if (!(flib_def_ && execution_state_)) {
441     // If this is the first call, we can initialize the execution state
442     // with `graph` and do not need to call `Extend()`.
443     GraphExecutionStateOptions options;
444     options.device_set = &device_set_;
445     options.session_options = &options_;
446     options.session_handle = session_handle_;
447     TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
448         std::move(graph), options, &execution_state_));
449     // NOTE(mrry): The function library created here will be used for
450     // all subsequent extensions of the graph. Also, note how using the copy
451     // constructor of FunctionLibraryDefinition avoids duplicating the memory
452     // that is occupied by its shared_ptr members.
453     flib_def_.reset(
454         new FunctionLibraryDefinition(execution_state_->flib_def()));
455     graph_created_ = true;
456   } else {
457     std::unique_ptr<GraphExecutionState> state;
458     // TODO(mrry): Rewrite GraphExecutionState::Extend() to take `graph` by
459     // value and move `graph` in here.
460     TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
461     execution_state_.swap(state);
462     TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));
463   }
464   return OkStatus();
465 }
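// From the client's perspective the graph grows incrementally: Create() seeds
// the base graph and later Extend() calls merge additional nodes into it. A
// minimal sketch (the graph defs are placeholders):
//
//   TF_CHECK_OK(session->Create(base_graph_def));
//   TF_CHECK_OK(session->Extend(extra_nodes_graph_def));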
466 
467 Status DirectSession::Run(const NamedTensorList& inputs,
468                           const std::vector<string>& output_names,
469                           const std::vector<string>& target_nodes,
470                           std::vector<Tensor>* outputs) {
471   RunMetadata run_metadata;
472   return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
473              &run_metadata);
474 }
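// Typical client-side flow that ends up in this overload (a sketch; assumes a
// GraphDef `graph_def` containing a fetchable node named "out"):
//
//   std::unique_ptr<Session> session(NewSession(SessionOptions()));
//   TF_CHECK_OK(session->Create(graph_def));
//   std::vector<Tensor> outputs;
//   TF_CHECK_OK(session->Run({/*inputs*/}, {"out:0"}, {/*targets*/}, &outputs));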
475 
476 Status DirectSession::CreateDebuggerState(
477     const CallableOptions& callable_options, int64_t global_step,
478     int64_t session_run_index, int64_t executor_step_index,
479     std::unique_ptr<DebuggerStateInterface>* debugger_state) {
480   TF_RETURN_IF_ERROR(DebuggerStateRegistry::CreateState(
481       callable_options.run_options().debug_options(), debugger_state));
482   std::vector<string> input_names(callable_options.feed().begin(),
483                                   callable_options.feed().end());
484   std::vector<string> output_names(callable_options.fetch().begin(),
485                                    callable_options.fetch().end());
486   std::vector<string> target_names(callable_options.target().begin(),
487                                    callable_options.target().end());
488 
489   TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
490       global_step, session_run_index, executor_step_index, input_names,
491       output_names, target_names));
492   return OkStatus();
493 }
494 
495 Status DirectSession::DecorateAndPublishGraphForDebug(
496     const DebugOptions& debug_options, Graph* graph, Device* device) {
497   std::unique_ptr<DebugGraphDecoratorInterface> decorator;
498   TF_RETURN_IF_ERROR(
499       DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
500 
501   TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
502   TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
503   return OkStatus();
504 }
505 
506 Status DirectSession::RunInternal(
507     int64_t step_id, const RunOptions& run_options,
508     CallFrameInterface* call_frame, ExecutorsAndKeys* executors_and_keys,
509     RunMetadata* run_metadata,
510     const thread::ThreadPoolOptions& threadpool_options) {
511   const uint64 start_time_usecs = options_.env->NowMicros();
512   const int64_t executor_step_count =
513       executors_and_keys->step_count.fetch_add(1);
514   RunState run_state(step_id, &devices_);
515   const size_t num_executors = executors_and_keys->items.size();
516 
517   profiler::TraceMeProducer activity(
518       // To TraceMeConsumers in ExecutorState::Process/Finish.
519       [&] {
520         if (options_.config.experimental().has_session_metadata()) {
521           const auto& model_metadata =
522               options_.config.experimental().session_metadata();
523           string model_id = strings::StrCat(model_metadata.name(), ":",
524                                             model_metadata.version());
525           return profiler::TraceMeEncode("SessionRun",
526                                          {{"id", step_id},
527                                           {"_r", 1} /*root_event*/,
528                                           {"model_id", model_id}});
529         } else {
530           return profiler::TraceMeEncode(
531               "SessionRun", {{"id", step_id}, {"_r", 1} /*root_event*/});
532         }
533       },
534       profiler::ContextType::kTfExecutor, step_id,
535       profiler::TraceMeLevel::kInfo);
536 
537   std::unique_ptr<DebuggerStateInterface> debugger_state;
538   if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
539     TF_RETURN_IF_ERROR(
540         CreateDebuggerState(executors_and_keys->callable_options,
541                             run_options.debug_options().global_step(), step_id,
542                             executor_step_count, &debugger_state));
543   }
544 
545   if (run_metadata != nullptr &&
546       options_.config.experimental().has_session_metadata()) {
547     *run_metadata->mutable_session_metadata() =
548         options_.config.experimental().session_metadata();
549   }
550 
551 #ifndef __ANDROID__
552   // Set up for collectives if ExecutorsAndKeys declares a key.
553   if (executors_and_keys->collective_graph_key !=
554       BuildGraphOptions::kNoCollectiveGraphKey) {
555     if (run_options.experimental().collective_graph_key() !=
556         BuildGraphOptions::kNoCollectiveGraphKey) {
557       // If a collective_graph_key was specified in run_options, ensure that it
558       // matches what came out of GraphExecutionState::BuildGraph().
559       if (run_options.experimental().collective_graph_key() !=
560           executors_and_keys->collective_graph_key) {
561         return errors::Internal(
562             "collective_graph_key in RunOptions ",
563             run_options.experimental().collective_graph_key(),
564             " should match collective_graph_key from optimized graph ",
565             executors_and_keys->collective_graph_key);
566       }
567     }
568     if (!collective_executor_mgr_) {
569       collective_executor_mgr_ = CreateProdLocalCollectiveExecutorMgr(
570           options_.config, device_mgr_.get(),
571           MaybeCreateNcclCommunicator(options_.config));
572     }
573     run_state.collective_executor.reset(new CollectiveExecutor::Handle(
574         collective_executor_mgr_->FindOrCreate(step_id), true /*inherit_ref*/));
575   }
576 #endif
577 
578   thread::ThreadPool* pool;
579   // Use std::unique_ptr so the wrapper is destroyed automatically.
580   std::unique_ptr<thread::ThreadPool> threadpool_wrapper;
581 
582   const bool inline_execution_requested =
583       run_in_caller_thread_ || run_options.inter_op_thread_pool() == -1;
584 
585   if (inline_execution_requested) {
586     // We allow using the caller thread only when a single executor is
587     // specified.
588     if (executors_and_keys->items.size() > 1) {
589       pool = thread_pools_[0].first;
590     } else {
591       VLOG(1) << "Executing Session::Run() synchronously!";
592       pool = nullptr;
593     }
594   } else if (threadpool_options.inter_op_threadpool != nullptr) {
595     threadpool_wrapper = std::make_unique<thread::ThreadPool>(
596         threadpool_options.inter_op_threadpool);
597     pool = threadpool_wrapper.get();
598   } else {
599     if (run_options.inter_op_thread_pool() < -1 ||
600         run_options.inter_op_thread_pool() >=
601             static_cast<int32>(thread_pools_.size())) {
602       return errors::InvalidArgument("Invalid inter_op_thread_pool: ",
603                                      run_options.inter_op_thread_pool());
604     }
605 
606     pool = thread_pools_[run_options.inter_op_thread_pool()].first;
607   }
608 
609   const int64_t call_timeout = run_options.timeout_in_ms() > 0
610                                    ? run_options.timeout_in_ms()
611                                    : operation_timeout_in_ms_;
612   absl::optional<absl::Time> deadline;
613   if (call_timeout > 0) {
614     deadline = absl::Now() + absl::Milliseconds(call_timeout);
615   }
616 
617   std::unique_ptr<RunHandler> handler;
618   if (ShouldUseRunHandlerPool(run_options) &&
619       run_options.experimental().use_run_handler_pool()) {
620     VLOG(1) << "Using RunHandler to scheduler inter-op closures.";
621     handler = GetOrCreateRunHandlerPool(options_)->Get(
622         step_id, call_timeout,
623         run_options.experimental().run_handler_pool_options());
624     if (!handler) {
625       return errors::DeadlineExceeded(
626           "Could not obtain RunHandler for request after waiting for ",
627           call_timeout, "ms.");
628     }
629   }
630   auto* handler_ptr = handler.get();
631 
632   Executor::Args::Runner default_runner = nullptr;
633 
634   if (pool == nullptr) {
635     default_runner = [](const Executor::Args::Closure& c) { c(); };
636   } else if (handler_ptr != nullptr) {
637     default_runner = [handler_ptr](Executor::Args::Closure c) {
638       handler_ptr->ScheduleInterOpClosure(std::move(c));
639     };
640   } else {
641     default_runner = [pool](Executor::Args::Closure c) {
642       pool->Schedule(std::move(c));
643     };
644   }
645 
646   // Start parallel Executors.
647 
648   // We can execute this step synchronously on the calling thread whenever
649   // there is a single device and the timeout mechanism is not used.
650   //
651   // When timeouts are used, we must execute the graph(s) asynchronously, in
652   // order to invoke the cancellation manager on the calling thread if the
653   // timeout expires.
654   const bool can_execute_synchronously =
655       executors_and_keys->items.size() == 1 && call_timeout == 0;
656 
657   Executor::Args args;
658   args.step_id = step_id;
659   args.call_frame = call_frame;
660   args.collective_executor =
661       (run_state.collective_executor ? run_state.collective_executor->get()
662                                      : nullptr);
663   args.session_state = &session_state_;
664   args.session_handle = session_handle_;
665   args.tensor_store = &run_state.tensor_store;
666   args.step_container = &run_state.step_container;
667   args.sync_on_finish = sync_on_finish_;
668   args.user_intra_op_threadpool = threadpool_options.intra_op_threadpool;
669   args.run_all_kernels_inline = pool == nullptr;
670   args.start_time_usecs = start_time_usecs;
671   args.deadline = deadline;
672 
673   const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);
674 
675   bool update_cost_model = false;
676   if (options_.config.graph_options().build_cost_model() > 0) {
677     const int64_t build_cost_model_every =
678         options_.config.graph_options().build_cost_model();
679     const int64_t build_cost_model_after =
680         options_.config.graph_options().build_cost_model_after();
681     int64_t measure_step_count = executor_step_count - build_cost_model_after;
682     if (measure_step_count >= 0) {
683       update_cost_model =
684           ((measure_step_count + 1) % build_cost_model_every == 0);
685     }
686   }
687   if (do_trace || update_cost_model ||
688       run_options.report_tensor_allocations_upon_oom()) {
689     run_state.collector.reset(
690         new StepStatsCollector(run_metadata->mutable_step_stats()));
691     args.stats_collector = run_state.collector.get();
692   }
693 
694   std::unique_ptr<DeviceProfilerSession> device_profiler_session;
695   if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
696     device_profiler_session = DeviceProfilerSession::Create();
697   }
698 
699   // Register this step with session's cancellation manager, so that
700   // `Session::Close()` will cancel the step.
701   CancellationManager step_cancellation_manager(cancellation_manager_);
702   if (step_cancellation_manager.IsCancelled()) {
703     return errors::Cancelled("Run call was cancelled");
704   }
705   args.cancellation_manager = &step_cancellation_manager;
706 
707   Status run_status;
708 
709   auto set_threadpool_args_for_item =
710       [&default_runner, &handler](const PerPartitionExecutorsAndLib& item,
711                                   Executor::Args* args) {
712         // TODO(azaks): support partial run.
713         // TODO(azaks): if the device picks its own threadpool, we need
714         // to assign fewer threads to the main compute pool by
715         // default.
716         thread::ThreadPool* device_thread_pool =
717             item.device->tensorflow_device_thread_pool();
718         // TODO(crk): Investigate usage of RunHandlerPool when using device
719         // specific thread pool(s).
720         if (!device_thread_pool) {
721           args->runner = default_runner;
722         } else {
723           args->runner = [device_thread_pool](Executor::Args::Closure c) {
724             device_thread_pool->Schedule(std::move(c));
725           };
726         }
727         if (handler != nullptr) {
728           args->user_intra_op_threadpool =
729               handler->AsIntraThreadPoolInterface();
730         }
731       };
732 
733   if (can_execute_synchronously) {
734     PrivateIntraProcessRendezvous rendezvous(device_mgr_.get());
735     args.rendezvous = &rendezvous;
736 
737     const auto& item = executors_and_keys->items[0];
738     set_threadpool_args_for_item(item, &args);
739     run_status = item.executor->Run(args);
740   } else {
741     core::RefCountPtr<RefCountedIntraProcessRendezvous> rendezvous(
742         new RefCountedIntraProcessRendezvous(device_mgr_.get()));
743     args.rendezvous = rendezvous.get();
744 
745     // `barrier` will delete itself after the final executor finishes.
746     Notification executors_done;
747     ExecutorBarrier* barrier =
748         new ExecutorBarrier(num_executors, rendezvous.get(),
749                             [&run_state, &executors_done](const Status& ret) {
750                               {
751                                 mutex_lock l(run_state.mu);
752                                 run_state.status.Update(ret);
753                               }
754                               executors_done.Notify();
755                             });
756 
757     for (const auto& item : executors_and_keys->items) {
758       set_threadpool_args_for_item(item, &args);
759       item.executor->RunAsync(args, barrier->Get());
760     }
761 
762     WaitForNotification(&executors_done, &run_state, &step_cancellation_manager,
763                         call_timeout);
764     {
765       tf_shared_lock l(run_state.mu);
766       run_status = run_state.status;
767     }
768   }
769 
770   if (step_cancellation_manager.IsCancelled()) {
771     run_status.Update(errors::Cancelled("Run call was cancelled"));
772   }
773 
774   if (device_profiler_session) {
775     TF_RETURN_IF_ERROR(device_profiler_session->CollectData(
776         run_metadata->mutable_step_stats()));
777   }
778 
779   TF_RETURN_IF_ERROR(run_status);
780 
781   // Save the output tensors of this run we choose to keep.
782   if (!run_state.tensor_store.empty()) {
783     TF_RETURN_IF_ERROR(run_state.tensor_store.SaveTensors(
784         {executors_and_keys->callable_options.fetch().begin(),
785          executors_and_keys->callable_options.fetch().end()},
786         &session_state_));
787   }
788 
789   if (run_state.collector) {
790     run_state.collector->Finalize();
791   }
792 
793   // Build and return the cost model as instructed.
794   if (update_cost_model) {
795     // Build the cost model
796     std::unordered_map<string, const Graph*> device_to_graph;
797     for (const PerPartitionExecutorsAndLib& partition :
798          executors_and_keys->items) {
799       const Graph* graph = partition.graph.get();
800       const string& device = partition.flib->device()->name();
801       device_to_graph[device] = graph;
802     }
803 
804     mutex_lock l(executor_lock_);
805     run_state.collector->BuildCostModel(&cost_model_manager_, device_to_graph);
806 
807     // annotate stats onto cost graph.
808     CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
809     for (const auto& item : executors_and_keys->items) {
810       TF_RETURN_IF_ERROR(
811           cost_model_manager_.AddToCostGraphDef(item.graph.get(), cost_graph));
812     }
813   }
814 
815   // If requested via RunOptions, output the partition graphs.
816   if (run_options.output_partition_graphs()) {
817     if (options_.config.experimental().disable_output_partition_graphs()) {
818       return errors::InvalidArgument(
819           "RunOptions.output_partition_graphs() is not supported when "
820           "disable_output_partition_graphs is true.");
821     } else {
822       protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
823           run_metadata->mutable_partition_graphs();
824       for (const PerPartitionExecutorsAndLib& exec_and_lib :
825            executors_and_keys->items) {
826         GraphDef* partition_graph_def = partition_graph_defs->Add();
827         exec_and_lib.graph->ToGraphDef(partition_graph_def);
828       }
829     }
830   }
831   metrics::UpdateGraphExecTime(options_.env->NowMicros() - start_time_usecs);
832 
833   return OkStatus();
834 }
835 
836 Status DirectSession::Run(const RunOptions& run_options,
837                           const NamedTensorList& inputs,
838                           const std::vector<string>& output_names,
839                           const std::vector<string>& target_nodes,
840                           std::vector<Tensor>* outputs,
841                           RunMetadata* run_metadata) {
842   return Run(run_options, inputs, output_names, target_nodes, outputs,
843              run_metadata, thread::ThreadPoolOptions());
844 }
845 
846 Status DirectSession::Run(const RunOptions& run_options,
847                           const NamedTensorList& inputs,
848                           const std::vector<string>& output_names,
849                           const std::vector<string>& target_nodes,
850                           std::vector<Tensor>* outputs,
851                           RunMetadata* run_metadata,
852                           const thread::ThreadPoolOptions& threadpool_options) {
853   TF_RETURN_IF_ERROR(CheckNotClosed());
854   TF_RETURN_IF_ERROR(CheckGraphCreated("Run()"));
855   direct_session_runs->GetCell()->IncrementBy(1);
856 
857   // Extract the inputs names for this run of the session.
858   std::vector<string> input_tensor_names;
859   input_tensor_names.reserve(inputs.size());
860   size_t input_size = 0;
861   for (const auto& it : inputs) {
862     input_tensor_names.push_back(it.first);
863     input_size += it.second.AllocatedBytes();
864   }
865   metrics::RecordGraphInputTensors(input_size);
866 
867   // Check if we already have an executor for these arguments.
868   ExecutorsAndKeys* executors_and_keys;
869   RunStateArgs run_state_args(run_options.debug_options());
870   run_state_args.collective_graph_key =
871       run_options.experimental().collective_graph_key();
872 
873   TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
874                                           target_nodes, &executors_and_keys,
875                                           &run_state_args));
876   {
877     mutex_lock l(collective_graph_key_lock_);
878     collective_graph_key_ = executors_and_keys->collective_graph_key;
879   }
880 
881   // Configure a call frame for the step, which we use to feed and
882   // fetch values to and from the executors.
883   FunctionCallFrame call_frame(executors_and_keys->input_types,
884                                executors_and_keys->output_types);
885   gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
886   for (const auto& it : inputs) {
887     if (it.second.dtype() == DT_RESOURCE) {
888       Tensor tensor_from_handle;
889       TF_RETURN_IF_ERROR(
890           ResourceHandleToInputTensor(it.second, &tensor_from_handle));
891       feed_args[executors_and_keys->input_name_to_index[it.first]] =
892           tensor_from_handle;
893     } else {
894       feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
895     }
896   }
897   const Status s = call_frame.SetArgs(feed_args);
898   if (errors::IsInternal(s)) {
899     return errors::InvalidArgument(s.error_message());
900   } else if (!s.ok()) {
901     return s;
902   }
903 
904   const int64_t step_id = step_id_counter_.fetch_add(1);
905 
906   if (LogMemory::IsEnabled()) {
907     LogMemory::RecordStep(step_id, run_state_args.handle);
908   }
909 
910   TF_RETURN_IF_ERROR(RunInternal(step_id, run_options, &call_frame,
911                                  executors_and_keys, run_metadata,
912                                  threadpool_options));
913 
914   // Receive outputs.
915   if (outputs) {
916     std::vector<Tensor> sorted_outputs;
917     const Status s = call_frame.ConsumeRetvals(
918         &sorted_outputs, /* allow_dead_tensors = */ false);
919     if (errors::IsInternal(s)) {
920       return errors::InvalidArgument(s.error_message());
921     } else if (!s.ok()) {
922       return s;
923     }
924     const bool unique_outputs =
925         output_names.size() == executors_and_keys->output_name_to_index.size();
926     // first_indices[i] = j implies that j is the smallest value for which
927     // output_names[i] == output_names[j].
928     std::vector<int> first_indices;
929     if (!unique_outputs) {
930       first_indices.reserve(output_names.size());
931       for (const auto& name : output_names) {
932         first_indices.push_back(
933             std::find(output_names.begin(), output_names.end(), name) -
934             output_names.begin());
935       }
936     }
937     outputs->clear();
938     size_t output_size = 0;
939     outputs->reserve(sorted_outputs.size());
940     for (int i = 0; i < output_names.size(); ++i) {
941       const string& output_name = output_names[i];
942       if (first_indices.empty() || first_indices[i] == i) {
943         outputs->emplace_back(
944             std::move(sorted_outputs[executors_and_keys
945                                          ->output_name_to_index[output_name]]));
946       } else {
947         outputs->push_back((*outputs)[first_indices[i]]);
948       }
949       output_size += outputs->back().AllocatedBytes();
950     }
951     metrics::RecordGraphOutputTensors(output_size);
952   }
953 
954   return OkStatus();
955 }
956 
957 Status DirectSession::PRunSetup(const std::vector<string>& input_names,
958                                 const std::vector<string>& output_names,
959                                 const std::vector<string>& target_nodes,
960                                 string* handle) {
961   TF_RETURN_IF_ERROR(CheckNotClosed());
962   TF_RETURN_IF_ERROR(CheckGraphCreated("PRunSetup()"));
963 
964   // RunOptions is not available in PRunSetup, so use thread pool 0.
965   thread::ThreadPool* pool = thread_pools_[0].first;
966 
967   // Check if we already have an executor for these arguments.
968   ExecutorsAndKeys* executors_and_keys;
969   // TODO(cais): TFDBG support for partial runs.
970   DebugOptions debug_options;
971   RunStateArgs run_state_args(debug_options);
972   run_state_args.is_partial_run = true;
973   TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_names, output_names,
974                                           target_nodes, &executors_and_keys,
975                                           &run_state_args));
976 
977   // Create the run state and save it for future PRun calls.
978   Executor::Args args;
979   args.step_id = step_id_counter_.fetch_add(1);
980   PartialRunState* run_state =
981       new PartialRunState(input_names, output_names, args.step_id, &devices_);
982   run_state->rendez.reset(new IntraProcessRendezvous(device_mgr_.get()));
983   {
984     mutex_lock l(executor_lock_);
985     if (!partial_runs_
986              .emplace(run_state_args.handle,
987                       std::unique_ptr<PartialRunState>(run_state))
988              .second) {
989       return errors::Internal("The handle '", run_state_args.handle,
990                               "' created for this partial run is not unique.");
991     }
992   }
993 
994   // Start parallel Executors.
995   const size_t num_executors = executors_and_keys->items.size();
996   ExecutorBarrier* barrier = new ExecutorBarrier(
997       num_executors, run_state->rendez.get(), [run_state](const Status& ret) {
998         if (!ret.ok()) {
999           mutex_lock l(run_state->mu);
1000           run_state->status.Update(ret);
1001         }
1002         run_state->executors_done.Notify();
1003       });
1004 
1005   args.rendezvous = run_state->rendez.get();
1006   args.cancellation_manager = cancellation_manager_;
1007   // Note that Collectives are not supported in partial runs
1008   // because RunOptions is not passed in so we can't know whether
1009   // their use is intended.
1010   args.collective_executor = nullptr;
1011   args.runner = [this, pool](Executor::Args::Closure c) {
1012     pool->Schedule(std::move(c));
1013   };
1014   args.session_state = &session_state_;
1015   args.session_handle = session_handle_;
1016   args.tensor_store = &run_state->tensor_store;
1017   args.step_container = &run_state->step_container;
1018   if (LogMemory::IsEnabled()) {
1019     LogMemory::RecordStep(args.step_id, run_state_args.handle);
1020   }
1021   args.sync_on_finish = sync_on_finish_;
1022 
1023   if (options_.config.graph_options().build_cost_model()) {
1024     run_state->collector.reset(new StepStatsCollector(nullptr));
1025     args.stats_collector = run_state->collector.get();
1026   }
1027 
1028   for (auto& item : executors_and_keys->items) {
1029     item.executor->RunAsync(args, barrier->Get());
1030   }
1031 
1032   *handle = run_state_args.handle;
1033   return OkStatus();
1034 }
1035 
1036 Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
1037                            const std::vector<string>& output_names,
1038                            std::vector<Tensor>* outputs) {
1039   TF_RETURN_IF_ERROR(CheckNotClosed());
1040   std::vector<string> parts = str_util::Split(handle, ';');
1041   const string& key = parts[0];
1042   // Get the executors for this partial run.
1043   ExecutorsAndKeys* executors_and_keys;
1044   PartialRunState* run_state;
1045   {
1046     mutex_lock l(executor_lock_);  // could use reader lock
1047     auto exc_it = executors_.find(key);
1048     if (exc_it == executors_.end()) {
1049       return errors::InvalidArgument(
1050           "Must run 'setup' before performing partial runs!");
1051     }
1052     executors_and_keys = exc_it->second.get();
1053 
1054     auto prun_it = partial_runs_.find(handle);
1055     if (prun_it == partial_runs_.end()) {
1056       return errors::InvalidArgument(
1057           "Must run 'setup' before performing partial runs!");
1058     }
1059     run_state = prun_it->second.get();
1060 
1061     // Make sure that this is a new set of feeds that are still pending.
1062     for (const auto& input : inputs) {
1063       auto it = run_state->pending_inputs.find(input.first);
1064       if (it == run_state->pending_inputs.end()) {
1065         return errors::InvalidArgument(
1066             "The feed ", input.first,
1067             " was not specified in partial_run_setup.");
1068       } else if (it->second) {
1069         return errors::InvalidArgument("The feed ", input.first,
1070                                        " has already been fed.");
1071       }
1072     }
1073     // Check that this is a new set of fetches that are still pending.
1074     for (const auto& output : output_names) {
1075       auto it = run_state->pending_outputs.find(output);
1076       if (it == run_state->pending_outputs.end()) {
1077         return errors::InvalidArgument(
1078             "The fetch ", output, " was not specified in partial_run_setup.");
1079       } else if (it->second) {
1080         return errors::InvalidArgument("The fetch ", output,
1081                                        " has already been fetched.");
1082       }
1083     }
1084   }
1085 
1086   // Check that this new set of fetches can be computed from all the
1087   // feeds we have supplied.
1088   TF_RETURN_IF_ERROR(
1089       CheckFetch(inputs, output_names, executors_and_keys, run_state));
1090 
1091   // Send inputs.
1092   Status s =
1093       SendPRunInputs(inputs, executors_and_keys, run_state->rendez.get());
1094 
1095   // Receive outputs.
1096   if (s.ok()) {
1097     s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs);
1098   }
1099 
1100   // Save the output tensors of this run we choose to keep.
1101   if (s.ok()) {
1102     s = run_state->tensor_store.SaveTensors(output_names, &session_state_);
1103   }
1104 
1105   {
1106     mutex_lock l(executor_lock_);
1107     // Delete the run state if there is an error or all fetches are done.
1108     bool done = true;
1109     if (s.ok()) {
1110       {
1111         mutex_lock l(run_state->mu);
1112         if (!run_state->status.ok()) {
1113           LOG(WARNING) << "An error unrelated to this prun has been detected. "
1114                        << run_state->status;
1115         }
1116       }
1117       for (const auto& input : inputs) {
1118         auto it = run_state->pending_inputs.find(input.first);
1119         it->second = true;
1120       }
1121       for (const auto& name : output_names) {
1122         auto it = run_state->pending_outputs.find(name);
1123         it->second = true;
1124       }
1125       done = run_state->PendingDone();
1126     }
1127     if (done) {
1128       WaitForNotification(&run_state->executors_done, run_state,
1129                           cancellation_manager_, operation_timeout_in_ms_);
1130       partial_runs_.erase(handle);
1131     }
1132   }
1133 
1134   return s;
1135 }
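// Partial runs are driven from the client as a setup call followed by one or
// more PRun() calls against the returned handle; a minimal sketch (tensor and
// node names are placeholders):
//
//   string handle;
//   TF_CHECK_OK(session->PRunSetup({"a:0", "b:0"}, {"out:0"}, {}, &handle));
//   std::vector<Tensor> outputs;
//   TF_CHECK_OK(session->PRun(handle, {{"a:0", tensor_a}, {"b:0", tensor_b}},
//                             {"out:0"}, &outputs));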
1136 
1137 Status DirectSession::ResourceHandleToInputTensor(const Tensor& resource_tensor,
1138                                                   Tensor* retrieved_tensor) {
1139   if (resource_tensor.dtype() != DT_RESOURCE) {
1140     return errors::InvalidArgument(strings::StrCat(
1141         "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ",
1142         resource_tensor.dtype()));
1143   }
1144 
1145   const ResourceHandle& resource_handle =
1146       resource_tensor.scalar<ResourceHandle>()();
1147 
1148   if (resource_handle.container() ==
1149       SessionState::kTensorHandleResourceTypeName) {
1150     return session_state_.GetTensor(resource_handle.name(), retrieved_tensor);
1151   } else {
1152     return errors::InvalidArgument(strings::StrCat(
1153         "Invalid resource type hash code: ", resource_handle.hash_code(),
1154         "(name: ", resource_handle.name(),
1155         " type: ", resource_handle.maybe_type_name(),
1156         "). Perhaps a resource tensor was being provided as a feed? That is "
1157         "not currently allowed. Please file an issue at "
1158         "https://github.com/tensorflow/tensorflow/issues/new, ideally with a "
1159         "short code snippet that leads to this error message."));
1160   }
1161 }
1162 
1163 Status DirectSession::SendPRunInputs(const NamedTensorList& inputs,
1164                                      const ExecutorsAndKeys* executors_and_keys,
1165                                      IntraProcessRendezvous* rendez) {
1166   Status s;
1167   Rendezvous::ParsedKey parsed;
1168   // Insert the input tensors into the local rendezvous by their
1169   // rendezvous key.
1170   for (const auto& input : inputs) {
1171     auto it =
1172         executors_and_keys->input_name_to_rendezvous_key.find(input.first);
1173     if (it == executors_and_keys->input_name_to_rendezvous_key.end()) {
1174       return errors::Internal("'", input.first, "' is not a pre-defined feed.");
1175     }
1176     const string& input_key = it->second;
1177 
1178     s = Rendezvous::ParseKey(input_key, &parsed);
1179     if (!s.ok()) {
1180       rendez->StartAbort(s);
1181       return s;
1182     }
1183 
1184     if (input.second.dtype() == DT_RESOURCE) {
1185       Tensor tensor_from_handle;
1186       s = ResourceHandleToInputTensor(input.second, &tensor_from_handle);
1187       if (s.ok()) {
1188         s = rendez->Send(parsed, Rendezvous::Args(), tensor_from_handle, false);
1189       }
1190     } else {
1191       s = rendez->Send(parsed, Rendezvous::Args(), input.second, false);
1192     }
1193 
1194     if (!s.ok()) {
1195       rendez->StartAbort(s);
1196       return s;
1197     }
1198   }
1199   return OkStatus();
1200 }
1201 
1202 Status DirectSession::RecvPRunOutputs(
1203     const std::vector<string>& output_names,
1204     const ExecutorsAndKeys* executors_and_keys, PartialRunState* run_state,
1205     std::vector<Tensor>* outputs) {
1206   Status s;
1207   if (!output_names.empty()) {
1208     outputs->resize(output_names.size());
1209   }
1210 
1211   Rendezvous::ParsedKey parsed;
1212   // Get the outputs from the rendezvous
1213   for (size_t output_offset = 0; output_offset < output_names.size();
1214        ++output_offset) {
1215     const string& output_name = output_names[output_offset];
1216     auto it =
1217         executors_and_keys->output_name_to_rendezvous_key.find(output_name);
1218     if (it == executors_and_keys->output_name_to_rendezvous_key.end()) {
1219       return errors::Internal("'", output_name,
1220                               "' is not a pre-defined fetch.");
1221     }
1222     const string& output_key = it->second;
1223     Tensor output_tensor;
1224     bool is_dead;
1225 
1226     s = Rendezvous::ParseKey(output_key, &parsed);
1227     if (s.ok()) {
1228       // Fetch data from the Rendezvous.
1229       s = run_state->rendez->Recv(parsed, Rendezvous::Args(), &output_tensor,
1230                                   &is_dead, operation_timeout_in_ms_);
1231       if (is_dead && s.ok()) {
1232         s = errors::InvalidArgument("The tensor returned for ", output_name,
1233                                     " was not valid.");
1234       }
1235     }
1236     if (!s.ok()) {
1237       run_state->rendez->StartAbort(s);
1238       outputs->clear();
1239       return s;
1240     }
1241 
1242     (*outputs)[output_offset] = output_tensor;
1243   }
1244   return OkStatus();
1245 }
1246 
1247 Status DirectSession::CheckFetch(const NamedTensorList& feeds,
1248                                  const std::vector<string>& fetches,
1249                                  const ExecutorsAndKeys* executors_and_keys,
1250                                  const PartialRunState* run_state) {
1251   const Graph* graph = executors_and_keys->graph.get();
1252   const NameNodeMap* name_to_node = &executors_and_keys->name_to_node;
1253 
1254   // Build the set of pending feeds that we haven't seen.
1255   std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
1256   {
1257     mutex_lock l(executor_lock_);
1258     for (const auto& input : run_state->pending_inputs) {
1259       // Skip if the feed has already been fed.
1260       if (input.second) continue;
1261       TensorId id(ParseTensorName(input.first));
1262       auto it = name_to_node->find(id.first);
1263       if (it == name_to_node->end()) {
1264         return errors::NotFound("Feed ", input.first, ": not found");
1265       }
1266       pending_feeds.insert(id);
1267     }
1268   }
1269   for (const auto& it : feeds) {
1270     TensorId id(ParseTensorName(it.first));
1271     pending_feeds.erase(id);
1272   }
1273 
1274   // Initialize the stack with the fetch nodes.
1275   std::vector<const Node*> stack;
1276   for (const string& fetch : fetches) {
1277     TensorId id(ParseTensorName(fetch));
1278     auto it = name_to_node->find(id.first);
1279     if (it == name_to_node->end()) {
1280       return errors::NotFound("Fetch ", fetch, ": not found");
1281     }
1282     stack.push_back(it->second);
1283   }
1284 
1285   // Any tensor needed for fetches can't be in pending_feeds.
1286   std::vector<bool> visited(graph->num_node_ids(), false);
1287   while (!stack.empty()) {
1288     const Node* n = stack.back();
1289     stack.pop_back();
1290 
1291     for (const Edge* in_edge : n->in_edges()) {
1292       const Node* in_node = in_edge->src();
1293       if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
1294         return errors::InvalidArgument("Fetch ", in_node->name(), ":",
1295                                        in_edge->src_output(),
1296                                        " can't be computed from the feeds"
1297                                        " that have been fed so far.");
1298       }
1299       if (!visited[in_node->id()]) {
1300         visited[in_node->id()] = true;
1301         stack.push_back(in_node);
1302       }
1303     }
1304   }
1305   return OkStatus();
1306 }
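// Illustrative example of the reachability check above (hypothetical graph):
// if c = MatMul(a, b), both a and b were declared as feeds for this partial
// run, and only a has been fed so far, then fetching c fails, because b:0 is
// still in pending_feeds and c cannot be computed from the feeds fed so far.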
1307 
1308 Status DirectSession::CreateExecutors(
1309     const CallableOptions& callable_options,
1310     std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys,
1311     std::unique_ptr<FunctionInfo>* out_func_info,
1312     RunStateArgs* run_state_args) {
1313   BuildGraphOptions options;
1314   options.callable_options = callable_options;
1315   options.use_function_convention = !run_state_args->is_partial_run;
1316   options.collective_graph_key =
1317       callable_options.run_options().experimental().collective_graph_key();
1318   if (options_.config.experimental()
1319           .collective_deterministic_sequential_execution()) {
1320     options.collective_order = GraphCollectiveOrder::kEdges;
1321   } else if (options_.config.experimental().collective_nccl()) {
1322     options.collective_order = GraphCollectiveOrder::kAttrs;
1323   }
1324 
1325   std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
1326   std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
1327 
1328   ek->callable_options = callable_options;
1329 
1330   std::unordered_map<string, std::unique_ptr<Graph>> graphs;
1331   TF_RETURN_IF_ERROR(CreateGraphs(
1332       options, &graphs, &func_info->flib_def, run_state_args, &ek->input_types,
1333       &ek->output_types, &ek->collective_graph_key));
1334 
1335   if (run_state_args->is_partial_run) {
1336     ek->graph = std::move(run_state_args->graph);
1337     std::unordered_set<StringPiece, StringPieceHasher> names;
1338     for (const string& input : callable_options.feed()) {
1339       TensorId id(ParseTensorName(input));
1340       names.emplace(id.first);
1341     }
1342     for (const string& output : callable_options.fetch()) {
1343       TensorId id(ParseTensorName(output));
1344       names.emplace(id.first);
1345     }
1346     for (Node* n : ek->graph->nodes()) {
1347       if (names.count(n->name()) > 0) {
1348         ek->name_to_node.insert({n->name(), n});
1349       }
1350     }
1351   }
1352   ek->items.reserve(graphs.size());
1353   const auto& optimizer_opts =
1354       options_.config.graph_options().optimizer_options();
1355 
1356   int graph_def_version = graphs.begin()->second->versions().producer();
1357 
1358   const auto* session_metadata =
1359       options_.config.experimental().has_session_metadata()
1360           ? &options_.config.experimental().session_metadata()
1361           : nullptr;
1362   func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
1363       device_mgr_.get(), options_.env, &options_.config, graph_def_version,
1364       func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first,
1365       /*parent=*/nullptr, session_metadata,
1366       Rendezvous::Factory{
1367           [](const int64_t, const DeviceMgr* device_mgr, Rendezvous** r) {
1368             *r = new IntraProcessRendezvous(device_mgr);
1369             return OkStatus();
1370           }}));
1371 
1372   GraphOptimizer optimizer(optimizer_opts);
1373   for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
1374     const string& partition_name = iter->first;
1375     std::unique_ptr<Graph>& partition_graph = iter->second;
1376 
1377     Device* device;
1378     TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device));
1379 
1380     ek->items.resize(ek->items.size() + 1);
1381     auto* item = &(ek->items.back());
1382     auto lib = func_info->proc_flr->GetFLR(partition_name);
1383     if (lib == nullptr) {
1384       return errors::Internal("Could not find device: ", partition_name);
1385     }
1386     item->flib = lib;
1387 
1388     LocalExecutorParams params;
1389     params.device = device;
1390     params.session_metadata = session_metadata;
1391     params.function_library = lib;
1392     auto opseg = device->op_segment();
1393     params.create_kernel =
1394         [this, lib, opseg](const std::shared_ptr<const NodeProperties>& props,
1395                            OpKernel** kernel) {
1396           // NOTE(mrry): We must not share function kernels (implemented
1397           // using `CallOp`) between subgraphs, because `CallOp::handle_`
1398           // is tied to a particular subgraph. Even if the function itself
1399           // is stateful, the `CallOp` that invokes it is not.
1400           if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) {
1401             return lib->CreateKernel(props, kernel);
1402           }
1403           auto create_fn = [lib, &props](OpKernel** kernel) {
1404             return lib->CreateKernel(props, kernel);
1405           };
1406           // Kernels created for subgraph nodes need to be cached.  On
1407           // cache miss, create_fn() is invoked to create a kernel based
1408           // on the function library here + global op registry.
1409           return opseg->FindOrCreate(session_handle_, props->node_def.name(),
1410                                      kernel, create_fn);
1411         };
1412     params.delete_kernel = [lib](OpKernel* kernel) {
1413       if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string()))
1414         delete kernel;
1415     };
1416 
1417     optimizer.Optimize(lib, options_.env, device, &partition_graph,
1418                        GraphOptimizer::Options());
1419 
1420     // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
1421     const DebugOptions& debug_options =
1422         options.callable_options.run_options().debug_options();
1423     if (!debug_options.debug_tensor_watch_opts().empty()) {
1424       TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
1425           debug_options, partition_graph.get(), params.device));
1426     }
1427 
1428     TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
1429                                          device->name(),
1430                                          partition_graph.get()));
1431 
1432     item->executor = nullptr;
1433     item->device = device;
1434     auto executor_type = options_.config.experimental().executor_type();
1435     TF_RETURN_IF_ERROR(
1436         NewExecutor(executor_type, params, *partition_graph, &item->executor));
1437     if (!options_.config.experimental().disable_output_partition_graphs() ||
1438         options_.config.graph_options().build_cost_model() > 0) {
1439       item->graph = std::move(partition_graph);
1440     }
1441   }
1442 
1443   // Cache the mapping from input/output names to graph elements to
1444   // avoid recomputing it every time.
1445   if (!run_state_args->is_partial_run) {
1446     // For regular `Run()`, we use the function calling convention, and so
1447     // maintain a mapping from input/output names to
1448     // argument/return-value ordinal index.
1449     for (int i = 0; i < callable_options.feed().size(); ++i) {
1450       const string& input = callable_options.feed(i);
1451       ek->input_name_to_index[input] = i;
1452     }
1453     for (int i = 0; i < callable_options.fetch().size(); ++i) {
1454       const string& output = callable_options.fetch(i);
1455       ek->output_name_to_index[output] = i;
1456     }
1457   } else {
1458     // For `PRun()`, we use the rendezvous calling convention, and so
1459     // maintain a mapping from input/output names to rendezvous keys.
1460     //
1461     // We always use the first device as the device name portion of the
1462     // key, even if we're feeding another graph.
1463     for (int i = 0; i < callable_options.feed().size(); ++i) {
1464       const string& input = callable_options.feed(i);
1465       ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
1466           input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
1467     }
1468     for (int i = 0; i < callable_options.fetch().size(); ++i) {
1469       const string& output = callable_options.fetch(i);
1470       ek->output_name_to_rendezvous_key[output] =
1471           GetRendezvousKey(output, device_set_.client_device()->attributes(),
1472                            FrameAndIter(0, 0));
1473     }
1474   }
1475 
1476   *out_executors_and_keys = std::move(ek);
1477   *out_func_info = std::move(func_info);
1478   return OkStatus();
1479 }
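// Summary of the state built above: `ek` holds one executor item per device
// partition, plus lookup tables for every feed and fetch; these are ordinal
// indices for regular Run() (function calling convention) or rendezvous keys
// for partial runs.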
1480 
1481 Status DirectSession::GetOrCreateExecutors(
1482     gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
1483     gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys,
1484     RunStateArgs* run_state_args) {
1485   int64_t handle_name_counter_value = -1;
1486   if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
1487     handle_name_counter_value = handle_name_counter_.fetch_add(1);
1488   }
1489 
1490   string debug_tensor_watches_summary;
1491   if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
1492     debug_tensor_watches_summary = SummarizeDebugTensorWatches(
1493         run_state_args->debug_options.debug_tensor_watch_opts());
1494   }
1495 
1496   // Fast lookup path, no sorting.
1497   const string key = strings::StrCat(
1498       absl::StrJoin(inputs, ","), "->", absl::StrJoin(outputs, ","), "/",
1499       absl::StrJoin(target_nodes, ","), "/", run_state_args->is_partial_run,
1500       "/", debug_tensor_watches_summary);
1501   // Set the handle, if it's needed to log memory or for partial run.
1502   if (handle_name_counter_value >= 0) {
1503     run_state_args->handle =
1504         strings::StrCat(key, ";", handle_name_counter_value);
1505   }
1506 
1507   // See if we already have the executors for this run.
1508   {
1509     mutex_lock l(executor_lock_);  // could use reader lock
1510     auto it = executors_.find(key);
1511     if (it != executors_.end()) {
1512       *executors_and_keys = it->second.get();
1513       return OkStatus();
1514     }
1515   }
1516 
1517   // Slow lookup path, the unsorted key missed the cache.
1518   // Sort the inputs and outputs, and look up with the sorted key in case an
1519   // earlier call used a different order of inputs and outputs.
1520   //
1521   // We could consider some other signature instead of sorting that
1522   // preserves the same property to avoid the sort in the future.
1523   std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
1524   std::sort(inputs_sorted.begin(), inputs_sorted.end());
1525   std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
1526   std::sort(outputs_sorted.begin(), outputs_sorted.end());
1527   std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
1528   std::sort(tn_sorted.begin(), tn_sorted.end());
1529 
1530   const string sorted_key = strings::StrCat(
1531       absl::StrJoin(inputs_sorted, ","), "->",
1532       absl::StrJoin(outputs_sorted, ","), "/", absl::StrJoin(tn_sorted, ","),
1533       "/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
1534   // Set the handle, if it's needed to log memory or for partial run.
1535   if (handle_name_counter_value >= 0) {
1536     run_state_args->handle =
1537         strings::StrCat(sorted_key, ";", handle_name_counter_value);
1538   }
1539 
1540   // See if we already have the executors for this run.
1541   {
1542     mutex_lock l(executor_lock_);
1543     auto it = executors_.find(sorted_key);
1544     if (it != executors_.end()) {
1545       *executors_and_keys = it->second.get();
1546       return OkStatus();
1547     }
1548   }
1549 
1550   // Nothing found, so create the executors and store in the cache.
1551   // The executor_lock_ is intentionally released while executors are
1552   // being created.
1553   CallableOptions callable_options;
1554   callable_options.mutable_feed()->Reserve(inputs_sorted.size());
1555   for (const string& input : inputs_sorted) {
1556     callable_options.add_feed(input);
1557   }
1558   callable_options.mutable_fetch()->Reserve(outputs_sorted.size());
1559   for (const string& output : outputs_sorted) {
1560     callable_options.add_fetch(output);
1561   }
1562   callable_options.mutable_target()->Reserve(tn_sorted.size());
1563   for (const string& target : tn_sorted) {
1564     callable_options.add_target(target);
1565   }
1566   *callable_options.mutable_run_options()->mutable_debug_options() =
1567       run_state_args->debug_options;
1568   callable_options.mutable_run_options()
1569       ->mutable_experimental()
1570       ->set_collective_graph_key(run_state_args->collective_graph_key);
1571   std::unique_ptr<ExecutorsAndKeys> ek;
1572   std::unique_ptr<FunctionInfo> func_info;
1573   TF_RETURN_IF_ERROR(
1574       CreateExecutors(callable_options, &ek, &func_info, run_state_args));
1575 
1576   // Reacquire the lock, try to insert into the map.
1577   mutex_lock l(executor_lock_);
1578 
1579   // Another thread may have created the entry before us, in which case we will
1580   // reuse the already created one.
1581   auto insert_result = executors_.emplace(
1582       sorted_key, std::shared_ptr<ExecutorsAndKeys>(std::move(ek)));
1583   if (insert_result.second) {
1584     functions_.push_back(std::move(func_info));
1585   }
1586 
1587   // Insert the value under the original key, so the fast path lookup will work
1588   // if the user uses the same order of inputs, outputs, and targets again.
1589   executors_.emplace(key, insert_result.first->second);
1590   *executors_and_keys = insert_result.first->second.get();
1591 
1592   return OkStatus();
1593 }
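// For illustration (hypothetical tensor names): Run({"y:0", "x:0"}, {"z:0"})
// and Run({"x:0", "y:0"}, {"z:0"}) produce different unsorted keys but the
// same sorted key, so both calls share one ExecutorsAndKeys; the second
// emplace above only adds a fast-path alias under the caller's ordering.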
1594 
1595 Status DirectSession::CreateGraphs(
1596     const BuildGraphOptions& subgraph_options,
1597     std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
1598     std::unique_ptr<FunctionLibraryDefinition>* flib_def,
1599     RunStateArgs* run_state_args, DataTypeVector* input_types,
1600     DataTypeVector* output_types, int64_t* collective_graph_key) {
1601   mutex_lock l(graph_state_lock_);
1602   if (finalized_) {
1603     return errors::FailedPrecondition("Session has been finalized.");
1604   }
1605 
1606   std::unique_ptr<ClientGraph> client_graph;
1607 
1608   std::unique_ptr<GraphExecutionState> temp_exec_state_holder;
1609   GraphExecutionState* execution_state = nullptr;
1610   if (options_.config.graph_options().place_pruned_graph()) {
1611     // Because we are placing pruned graphs, we need to create a
1612     // new GraphExecutionState for every new unseen graph,
1613     // and then place it.
1614     GraphExecutionStateOptions prune_options;
1615     prune_options.device_set = &device_set_;
1616     prune_options.session_options = &options_;
1617     prune_options.stateful_placements = stateful_placements_;
1618     prune_options.session_handle = session_handle_;
1619     TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph(
1620         *execution_state_, prune_options, subgraph_options,
1621         &temp_exec_state_holder, &client_graph));
1622     execution_state = temp_exec_state_holder.get();
1623   } else {
1624     execution_state = execution_state_.get();
1625     TF_RETURN_IF_ERROR(
1626         execution_state->BuildGraph(subgraph_options, &client_graph));
1627   }
1628   *collective_graph_key = client_graph->collective_graph_key;
1629 
1630   if (subgraph_options.callable_options.feed_size() !=
1631       client_graph->feed_types.size()) {
1632     return errors::Internal(
1633         "Graph pruning failed: requested number of feed endpoints = ",
1634         subgraph_options.callable_options.feed_size(),
1635         " versus number of pruned feed endpoints = ",
1636         client_graph->feed_types.size());
1637   }
1638   if (subgraph_options.callable_options.fetch_size() !=
1639       client_graph->fetch_types.size()) {
1640     return errors::Internal(
1641         "Graph pruning failed: requested number of fetch endpoints = ",
1642         subgraph_options.callable_options.fetch_size(),
1643         " versus number of pruned fetch endpoints = ",
1644         client_graph->fetch_types.size());
1645   }
1646 
1647   auto current_stateful_placements = execution_state->GetStatefulPlacements();
1648   // Update our current state based on the execution_state's
1649   // placements.  If there are any mismatches for a node,
1650   // we should fail, as this should never happen.
1651   for (const auto& placement_pair : current_stateful_placements) {
1652     const string& node_name = placement_pair.first;
1653     const string& placement = placement_pair.second;
1654     auto iter = stateful_placements_.find(node_name);
1655     if (iter == stateful_placements_.end()) {
1656       stateful_placements_.insert(std::make_pair(node_name, placement));
1657     } else if (iter->second != placement) {
1658       return errors::Internal(
1659           "Stateful placement mismatch. "
1660           "Current assignment of ",
1661           node_name, " to ", iter->second, " does not match ", placement);
1662     }
1663   }
1664 
1665   stateful_placements_ = execution_state->GetStatefulPlacements();
1666 
1667   // Remember the graph in run state if this is a partial run.
1668   if (run_state_args->is_partial_run) {
1669     run_state_args->graph.reset(new Graph(flib_def_.get()));
1670     CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
1671   }
1672 
1673   // Partition the graph across devices.
1674   PartitionOptions popts;
1675   popts.node_to_loc = [](const Node* node) {
1676     return node->assigned_device_name();
1677   };
1678   popts.new_name = [this](const string& prefix) {
1679     return strings::StrCat(prefix, "/_", edge_name_counter_.fetch_add(1));
1680   };
1681   popts.get_incarnation = [](const string& name) {
1682     // The direct session does not have changing incarnation numbers.
1683     // Just return '1'.
1684     return 1;
1685   };
1686   popts.flib_def = flib_def->get();
1687   popts.control_flow_added = false;
1688 
1689   std::unordered_map<string, GraphDef> partitions;
1690   TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions));
1691 
1692   std::vector<string> device_names;
1693   for (auto device : devices_) {
1694     // Extract the LocalName from the device.
1695     device_names.push_back(DeviceNameUtils::LocalName(device->name()));
1696   }
1697 
1698   // Check for valid partitions.
1699   for (const auto& partition : partitions) {
1700     const string local_partition_name =
1701         DeviceNameUtils::LocalName(partition.first);
1702     if (std::count(device_names.begin(), device_names.end(),
1703                    local_partition_name) == 0) {
1704       return errors::InvalidArgument(
1705           "Creating a partition for ", local_partition_name,
1706           " which doesn't exist in the list of available devices. Available "
1707           "devices: ",
1708           absl::StrJoin(device_names, ","));
1709     }
1710   }
1711 
1712   for (auto& partition : partitions) {
1713     std::unique_ptr<Graph> device_graph(
1714         new Graph(client_graph->flib_def.get()));
1715     device_graph->SetConstructionContext(ConstructionContext::kDirectSession);
1716     GraphConstructorOptions device_opts;
1717     // There are internal operations (e.g., send/recv) that we now allow.
1718     device_opts.allow_internal_ops = true;
1719     device_opts.expect_device_spec = true;
1720     TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
1721         device_opts, std::move(partition.second), device_graph.get()));
1722     outputs->emplace(partition.first, std::move(device_graph));
1723   }
1724 
1725   GraphOptimizationPassOptions optimization_options;
1726   optimization_options.session_options = &options_;
1727   optimization_options.flib_def = client_graph->flib_def.get();
1728   optimization_options.partition_graphs = outputs;
1729   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
1730       OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
1731 
1732   Status s;
1733   for (auto& partition : *outputs) {
1734     const string& partition_name = partition.first;
1735     std::unique_ptr<Graph>* graph = &partition.second;
1736 
1737     VLOG(2) << "Created " << DebugString(graph->get()) << " for "
1738             << partition_name;
1739 
1740     // Give the device an opportunity to rewrite its subgraph.
1741     Device* d;
1742     s = device_mgr_->LookupDevice(partition_name, &d);
1743     if (!s.ok()) break;
1744     s = d->MaybeRewriteGraph(graph);
1745     if (!s.ok()) {
1746       break;
1747     }
1748   }
1749   *flib_def = std::move(client_graph->flib_def);
1750   std::swap(*input_types, client_graph->feed_types);
1751   std::swap(*output_types, client_graph->fetch_types);
1752   return s;
1753 }
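// For example (hypothetical placement): a client graph assigned across
// "/job:localhost/replica:0/task:0/device:CPU:0" and
// "/job:localhost/replica:0/task:0/device:GPU:0" yields two entries in
// `outputs`, keyed by those device names, with _Send/_Recv node pairs added
// by Partition() on every edge that crosses the device boundary.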
1754 
1755 ::tensorflow::Status DirectSession::ListDevices(
1756     std::vector<DeviceAttributes>* response) {
1757   response->clear();
1758   response->reserve(devices_.size());
1759   for (Device* d : devices_) {
1760     const DeviceAttributes& attrs = d->attributes();
1761     response->emplace_back(attrs);
1762   }
1763   return OkStatus();
1764 }
1765 
1766 ::tensorflow::Status DirectSession::Reset(
1767     const std::vector<string>& containers) {
1768   device_mgr_->ClearContainers(containers);
1769   return OkStatus();
1770 }
1771 
1772 ::tensorflow::Status DirectSession::Close() {
1773   cancellation_manager_->StartCancel();
1774   {
1775     mutex_lock l(closed_lock_);
1776     if (closed_) return OkStatus();
1777     closed_ = true;
1778   }
1779   if (factory_ != nullptr) factory_->Deregister(this);
1780   return OkStatus();
1781 }
1782 
1783 DirectSession::RunState::RunState(int64_t step_id,
1784                                   const std::vector<Device*>* devices)
1785     : step_container(step_id, [devices, step_id](const string& name) {
1786         for (auto d : *devices) {
1787           if (!d->resource_manager()->Cleanup(name).ok()) {
1788             // Do nothing...
1789           }
1790           ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
1791           if (sam) sam->Cleanup(step_id);
1792         }
1793       }) {}
1794 
1795 DirectSession::PartialRunState::PartialRunState(
1796     const std::vector<string>& pending_input_names,
1797     const std::vector<string>& pending_output_names, int64_t step_id,
1798     const std::vector<Device*>* devices)
1799     : RunState(step_id, devices) {
1800   // Initially all the feeds and fetches are pending.
1801   for (auto& name : pending_input_names) {
1802     pending_inputs[name] = false;
1803   }
1804   for (auto& name : pending_output_names) {
1805     pending_outputs[name] = false;
1806   }
1807 }
1808 
1809 DirectSession::PartialRunState::~PartialRunState() {
1810   if (rendez != nullptr) {
1811     rendez->StartAbort(errors::Cancelled("PRun cancellation"));
1812     executors_done.WaitForNotification();
1813   }
1814 }
1815 
1816 bool DirectSession::PartialRunState::PendingDone() const {
1817   for (const auto& it : pending_inputs) {
1818     if (!it.second) return false;
1819   }
1820   for (const auto& it : pending_outputs) {
1821     if (!it.second) return false;
1822   }
1823   return true;
1824 }
1825 
1826 void DirectSession::WaitForNotification(Notification* n, RunState* run_state,
1827                                         CancellationManager* cm,
1828                                         int64_t timeout_in_ms) {
1829   const Status status = WaitForNotification(n, timeout_in_ms);
1830   if (!status.ok()) {
1831     {
1832       mutex_lock l(run_state->mu);
1833       run_state->status.Update(status);
1834     }
1835     cm->StartCancel();
1836     // We must wait for the executors to complete, because they have borrowed
1837     // references to `cm` and other per-step state. After this notification, it
1838     // is safe to clean up the step.
1839     n->WaitForNotification();
1840   }
1841 }
1842 
1843 ::tensorflow::Status DirectSession::WaitForNotification(
1844     Notification* notification, int64_t timeout_in_ms) {
1845   if (timeout_in_ms > 0) {
1846     const int64_t timeout_in_us = timeout_in_ms * 1000;
1847     const bool notified =
1848         WaitForNotificationWithTimeout(notification, timeout_in_us);
1849     if (!notified) {
1850       return Status(error::DEADLINE_EXCEEDED,
1851                     "Timed out waiting for notification");
1852     }
1853   } else {
1854     notification->WaitForNotification();
1855   }
1856   return OkStatus();
1857 }
1858 
1859 Status DirectSession::MakeCallable(const CallableOptions& callable_options,
1860                                    CallableHandle* out_handle) {
1861   TF_RETURN_IF_ERROR(CheckNotClosed());
1862   TF_RETURN_IF_ERROR(CheckGraphCreated("MakeCallable()"));
1863 
1864   std::unique_ptr<ExecutorsAndKeys> ek;
1865   std::unique_ptr<FunctionInfo> func_info;
1866   RunStateArgs run_state_args(callable_options.run_options().debug_options());
1867   TF_RETURN_IF_ERROR(
1868       CreateExecutors(callable_options, &ek, &func_info, &run_state_args));
1869   {
1870     mutex_lock l(callables_lock_);
1871     *out_handle = next_callable_handle_++;
1872     callables_[*out_handle] = {std::move(ek), std::move(func_info)};
1873   }
1874   return OkStatus();
1875 }
1876 
1877 class DirectSession::RunCallableCallFrame : public CallFrameInterface {
1878  public:
1879   RunCallableCallFrame(DirectSession* session,
1880                        ExecutorsAndKeys* executors_and_keys,
1881                        const std::vector<Tensor>* feed_tensors,
1882                        std::vector<Tensor>* fetch_tensors)
1883       : session_(session),
1884         executors_and_keys_(executors_and_keys),
1885         feed_tensors_(feed_tensors),
1886         fetch_tensors_(fetch_tensors) {}
1887 
1888   size_t num_args() const override {
1889     return executors_and_keys_->input_types.size();
1890   }
1891   size_t num_retvals() const override {
1892     return executors_and_keys_->output_types.size();
1893   }
1894 
1895   Status GetArg(int index, const Tensor** val) override {
1896     if (TF_PREDICT_FALSE(index >= feed_tensors_->size())) {
1897       return errors::Internal("Args index out of bounds: ", index);
1898     } else {
1899       *val = &(*feed_tensors_)[index];
1900     }
1901     return OkStatus();
1902   }
1903 
1904   Status SetRetval(int index, const Tensor& val) override {
1905     if (index >= fetch_tensors_->size()) {
1906       return errors::Internal("RetVal index out of bounds: ", index);
1907     }
1908     (*fetch_tensors_)[index] = val;
1909     return OkStatus();
1910   }
1911 
1912  private:
1913   DirectSession* const session_;                   // Not owned.
1914   ExecutorsAndKeys* const executors_and_keys_;     // Not owned.
1915   const std::vector<Tensor>* const feed_tensors_;  // Not owned.
1916   std::vector<Tensor>* const fetch_tensors_;       // Not owned.
1917 };
1918 
1919 ::tensorflow::Status DirectSession::RunCallable(
1920     CallableHandle handle, const std::vector<Tensor>& feed_tensors,
1921     std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata) {
1922   return RunCallable(handle, feed_tensors, fetch_tensors, run_metadata,
1923                      thread::ThreadPoolOptions());
1924 }
1925 
1926 ::tensorflow::Status DirectSession::RunCallable(
1927     CallableHandle handle, const std::vector<Tensor>& feed_tensors,
1928     std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata,
1929     const thread::ThreadPoolOptions& threadpool_options) {
1930   TF_RETURN_IF_ERROR(CheckNotClosed());
1931   TF_RETURN_IF_ERROR(CheckGraphCreated("RunCallable()"));
1932   direct_session_runs->GetCell()->IncrementBy(1);
1933 
1934   // Check if we already have an executor for these arguments.
1935   std::shared_ptr<ExecutorsAndKeys> executors_and_keys;
1936   const int64_t step_id = step_id_counter_.fetch_add(1);
1937 
1938   {
1939     tf_shared_lock l(callables_lock_);
1940     if (handle >= next_callable_handle_) {
1941       return errors::InvalidArgument("No such callable handle: ", handle);
1942     }
1943     executors_and_keys = callables_[handle].executors_and_keys;
1944   }
1945 
1946   if (!executors_and_keys) {
1947     return errors::InvalidArgument(
1948         "Attempted to run callable after handle was released: ", handle);
1949   }
1950 
1951   // NOTE(mrry): Debug options are not currently supported in the
1952   // callable interface.
1953   DebugOptions debug_options;
1954   RunStateArgs run_state_args(debug_options);
1955 
1956   // Configure a call frame for the step, which we use to feed and
1957   // fetch values to and from the executors.
1958   if (feed_tensors.size() != executors_and_keys->input_types.size()) {
1959     return errors::InvalidArgument(
1960         "Expected ", executors_and_keys->input_types.size(),
1961         " feed tensors, but got ", feed_tensors.size());
1962   }
1963   if (fetch_tensors != nullptr) {
1964     fetch_tensors->resize(executors_and_keys->output_types.size());
1965   } else if (!executors_and_keys->output_types.empty()) {
1966     return errors::InvalidArgument(
1967         "`fetch_tensors` must be provided when the callable has one or more "
1968         "outputs.");
1969   }
1970 
1971   size_t input_size = 0;
1972   bool any_resource_feeds = false;
1973   for (auto& tensor : feed_tensors) {
1974     input_size += tensor.AllocatedBytes();
1975     any_resource_feeds = any_resource_feeds || tensor.dtype() == DT_RESOURCE;
1976   }
1977   metrics::RecordGraphInputTensors(input_size);
1978 
1979   std::unique_ptr<std::vector<Tensor>> converted_feed_tensors;
1980   const std::vector<Tensor>* actual_feed_tensors;
1981 
1982   if (TF_PREDICT_FALSE(any_resource_feeds)) {
1983     converted_feed_tensors = std::make_unique<std::vector<Tensor>>();
1984     converted_feed_tensors->reserve(feed_tensors.size());
1985     for (const Tensor& t : feed_tensors) {
1986       if (t.dtype() == DT_RESOURCE) {
1987         converted_feed_tensors->emplace_back();
1988         Tensor* tensor_from_handle = &converted_feed_tensors->back();
1989         TF_RETURN_IF_ERROR(ResourceHandleToInputTensor(t, tensor_from_handle));
1990       } else {
1991         converted_feed_tensors->emplace_back(t);
1992       }
1993     }
1994     actual_feed_tensors = converted_feed_tensors.get();
1995   } else {
1996     actual_feed_tensors = &feed_tensors;
1997   }
1998 
1999   // A specialized CallFrame implementation that takes advantage of the
2000   // optimized RunCallable interface.
2001   RunCallableCallFrame call_frame(this, executors_and_keys.get(),
2002                                   actual_feed_tensors, fetch_tensors);
2003 
2004   if (LogMemory::IsEnabled()) {
2005     LogMemory::RecordStep(step_id, run_state_args.handle);
2006   }
2007 
2008   TF_RETURN_IF_ERROR(RunInternal(
2009       step_id, executors_and_keys->callable_options.run_options(), &call_frame,
2010       executors_and_keys.get(), run_metadata, threadpool_options));
2011 
2012   if (fetch_tensors != nullptr) {
2013     size_t output_size = 0;
2014     for (auto& tensor : *fetch_tensors) {
2015       output_size += tensor.AllocatedBytes();
2016     }
2017     metrics::RecordGraphOutputTensors(output_size);
2018   }
2019 
2020   return OkStatus();
2021 }
2022 
2023 ::tensorflow::Status DirectSession::ReleaseCallable(CallableHandle handle) {
2024   mutex_lock l(callables_lock_);
2025   if (handle >= next_callable_handle_) {
2026     return errors::InvalidArgument("No such callable handle: ", handle);
2027   }
2028   callables_.erase(handle);
2029   return OkStatus();
2030 }
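// Illustrative client-side use of the callable API implemented above
// (hypothetical `session`, `callable_options`, and `input_tensor`; error
// handling via TF_CHECK_OK for brevity):
//
//   Session::CallableHandle handle;
//   TF_CHECK_OK(session->MakeCallable(callable_options, &handle));
//   std::vector<Tensor> fetches;
//   TF_CHECK_OK(session->RunCallable(handle, {input_tensor}, &fetches,
//                                    /*run_metadata=*/nullptr));
//   TF_CHECK_OK(session->ReleaseCallable(handle));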
2031 
2032 Status DirectSession::Finalize() {
2033   mutex_lock l(graph_state_lock_);
2034   if (finalized_) {
2035     return errors::FailedPrecondition("Session already finalized.");
2036   }
2037   if (!graph_created_) {
2038     return errors::FailedPrecondition("Session not yet created.");
2039   }
2040   execution_state_.reset();
2041   flib_def_.reset();
2042   finalized_ = true;
2043   return OkStatus();
2044 }
2045 
2046 DirectSession::Callable::~Callable() {
2047   // We must delete the fields in this order, because the destructor
2048   // of `executors_and_keys` will call into an object owned by
2049   // `function_info` (in particular, when deleting a kernel, it relies
2050   // on the `FunctionLibraryRuntime` to know if the kernel is stateful
2051   // or not).
2052   executors_and_keys.reset();
2053   function_info.reset();
2054 }
2055 
2056 }  // namespace tensorflow
2057