1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/direct_session.h"
17
18 #include <algorithm>
19 #include <atomic>
20 #include <string>
21 #include <vector>
22
23 #include "absl/container/flat_hash_set.h"
24 #include "tensorflow/core/common_runtime/collective_executor_mgr.h"
25 #include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
26 #include "tensorflow/core/common_runtime/constant_folding.h"
27 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
28 #include "tensorflow/core/common_runtime/device_factory.h"
29 #include "tensorflow/core/common_runtime/device_resolver_local.h"
30 #include "tensorflow/core/common_runtime/executor.h"
31 #include "tensorflow/core/common_runtime/executor_factory.h"
32 #include "tensorflow/core/common_runtime/function.h"
33 #include "tensorflow/core/common_runtime/graph_constructor.h"
34 #include "tensorflow/core/common_runtime/graph_optimizer.h"
35 #include "tensorflow/core/common_runtime/memory_types.h"
36 #include "tensorflow/core/common_runtime/metrics.h"
37 #include "tensorflow/core/common_runtime/optimization_registry.h"
38 #include "tensorflow/core/common_runtime/process_util.h"
39 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
40 #include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
41 #include "tensorflow/core/common_runtime/step_stats_collector.h"
42 #include "tensorflow/core/framework/function.h"
43 #include "tensorflow/core/framework/graph.pb.h"
44 #include "tensorflow/core/framework/graph_def_util.h"
45 #include "tensorflow/core/framework/log_memory.h"
46 #include "tensorflow/core/framework/logging.h"
47 #include "tensorflow/core/framework/node_def.pb.h"
48 #include "tensorflow/core/framework/run_handler.h"
49 #include "tensorflow/core/framework/tensor.h"
50 #include "tensorflow/core/framework/versions.pb.h"
51 #include "tensorflow/core/graph/algorithm.h"
52 #include "tensorflow/core/graph/graph.h"
53 #include "tensorflow/core/graph/graph_partition.h"
54 #include "tensorflow/core/graph/subgraph.h"
55 #include "tensorflow/core/graph/tensor_id.h"
56 #include "tensorflow/core/lib/core/errors.h"
57 #include "tensorflow/core/lib/core/notification.h"
58 #include "tensorflow/core/lib/core/refcount.h"
59 #include "tensorflow/core/lib/core/status.h"
60 #include "tensorflow/core/lib/core/threadpool.h"
61 #include "tensorflow/core/lib/core/threadpool_options.h"
62 #include "tensorflow/core/lib/gtl/array_slice.h"
63 #include "tensorflow/core/lib/monitoring/counter.h"
64 #include "tensorflow/core/lib/random/random.h"
65 #include "tensorflow/core/lib/strings/numbers.h"
66 #include "tensorflow/core/lib/strings/str_util.h"
67 #include "tensorflow/core/lib/strings/strcat.h"
68 #include "tensorflow/core/nccl/collective_communicator.h"
69 #include "tensorflow/core/platform/byte_order.h"
70 #include "tensorflow/core/platform/cpu_info.h"
71 #include "tensorflow/core/platform/logging.h"
72 #include "tensorflow/core/platform/mutex.h"
73 #include "tensorflow/core/platform/tracing.h"
74 #include "tensorflow/core/platform/types.h"
75 #include "tensorflow/core/profiler/lib/connected_traceme.h"
76 #include "tensorflow/core/profiler/lib/device_profiler_session.h"
77 #include "tensorflow/core/profiler/lib/traceme_encode.h"
78 #include "tensorflow/core/protobuf/config.pb.h"
79 #include "tensorflow/core/util/device_name_utils.h"
80 #include "tensorflow/core/util/env_var.h"
81
82 namespace tensorflow {
83
84 namespace {
85
86 auto* direct_session_runs = monitoring::Counter<0>::New(
87 "/tensorflow/core/direct_session_runs",
88 "The number of times DirectSession::Run() has been called.");
89
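// Creates (or looks up) the inter-op thread pool described by
// `thread_pool_options`. A pool with an empty global_name is owned by the
// calling session; a named pool is a process-global singleton and must be
// configured with the same num_threads by every session that requests it.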
90 Status NewThreadPoolFromThreadPoolOptions(
91 const SessionOptions& options,
92 const ThreadPoolOptionProto& thread_pool_options, int pool_number,
93 thread::ThreadPool** pool, bool* owned) {
94 int32_t num_threads = thread_pool_options.num_threads();
95 if (num_threads == 0) {
96 num_threads = NumInterOpThreadsFromSessionOptions(options);
97 }
98 const string& name = thread_pool_options.global_name();
99 if (name.empty()) {
100 // Session-local threadpool.
101 VLOG(1) << "Direct session inter op parallelism threads for pool "
102 << pool_number << ": " << num_threads;
103 *pool = new thread::ThreadPool(
104 options.env, ThreadOptions(), strings::StrCat("Compute", pool_number),
105 num_threads, !options.config.experimental().disable_thread_spinning(),
106 /*allocator=*/nullptr);
107 *owned = true;
108 return Status::OK();
109 }
110
111 // Global, named threadpool.
112 typedef std::pair<int32, thread::ThreadPool*> MapValue;
113 static std::map<string, MapValue>* global_pool_map =
114 new std::map<string, MapValue>;
115 static mutex* mu = new mutex();
116 mutex_lock l(*mu);
117 MapValue* mvalue = &(*global_pool_map)[name];
118 if (mvalue->second == nullptr) {
119 mvalue->first = thread_pool_options.num_threads();
120 mvalue->second = new thread::ThreadPool(
121 options.env, ThreadOptions(), strings::StrCat("Compute", pool_number),
122 num_threads, !options.config.experimental().disable_thread_spinning(),
123 /*allocator=*/nullptr);
124 } else {
125 if (mvalue->first != thread_pool_options.num_threads()) {
126 return errors::InvalidArgument(
127 "Pool ", name,
128 " configured previously with num_threads=", mvalue->first,
129 "; cannot re-configure with num_threads=",
130 thread_pool_options.num_threads());
131 }
132 }
133 *owned = false;
134 *pool = mvalue->second;
135 return Status::OK();
136 }
137
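// Returns the process-wide inter-op thread pool. It is created lazily from
// the options of the first session that reaches this point and is then
// shared by every later DirectSession that uses neither per-session threads
// nor explicit session_inter_op_thread_pool entries.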
138 thread::ThreadPool* GlobalThreadPool(const SessionOptions& options) {
139 static thread::ThreadPool* const thread_pool =
140 NewThreadPoolFromSessionOptions(options);
141 return thread_pool;
142 }
143
144 // TODO(vrv): Figure out how to unify the many different functions
145 // that generate RendezvousKey, since many of them have to be
146 // consistent with each other.
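// The key produced below has the form
//   "<device>;<incarnation>;<device>;<tensor_name>;<frame_id>:<iter_id>"
// where <device> appears twice and <incarnation> is a hex-formatted uint64
// (illustrative; see the StrCat call for the authoritative layout).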
147 string GetRendezvousKey(const string& tensor_name,
148 const DeviceAttributes& device_info,
149 const FrameAndIter& frame_iter) {
150 return strings::StrCat(device_info.name(), ";",
151 strings::FpToString(device_info.incarnation()), ";",
152 device_info.name(), ";", tensor_name, ";",
153 frame_iter.frame_id, ":", frame_iter.iter_id);
154 }
155
156 } // namespace
157
158 class DirectSessionFactory : public SessionFactory {
159 public:
160 DirectSessionFactory() {}
161
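// An empty target selects the in-process runtime handled here; non-empty
// targets (e.g. "grpc://host:port") are claimed by other registered
// SessionFactory implementations.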
162 bool AcceptsOptions(const SessionOptions& options) override {
163 return options.target.empty();
164 }
165
166 Status NewSession(const SessionOptions& options,
167 Session** out_session) override {
168 const auto& experimental_config = options.config.experimental();
169 if (experimental_config.has_session_metadata()) {
170 if (experimental_config.session_metadata().version() < 0) {
171 return errors::InvalidArgument(
172 "Session version shouldn't be negative: ",
173 experimental_config.session_metadata().DebugString());
174 }
175 const string key = GetMetadataKey(experimental_config.session_metadata());
176 mutex_lock l(sessions_lock_);
177 if (!session_metadata_keys_.insert(key).second) {
178 return errors::InvalidArgument(
179 "A session with the same name and version has already been "
180 "created: ",
181 experimental_config.session_metadata().DebugString());
182 }
183 }
184
185 // Must do this before the CPU allocator is created.
186 if (options.config.graph_options().build_cost_model() > 0) {
187 EnableCPUAllocatorFullStats();
188 }
189 std::vector<std::unique_ptr<Device>> devices;
190 TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
191 options, "/job:localhost/replica:0/task:0", &devices));
192
193 DirectSession* session = new DirectSession(
194 options, new StaticDeviceMgr(std::move(devices)), this);
195 {
196 mutex_lock l(sessions_lock_);
197 sessions_.push_back(session);
198 }
199 *out_session = session;
200 return Status::OK();
201 }
202
203 Status Reset(const SessionOptions& options,
204 const std::vector<string>& containers) override {
205 std::vector<DirectSession*> sessions_to_reset;
206 {
207 mutex_lock l(sessions_lock_);
208 // We create a copy to ensure that we don't have a deadlock when
209 // session->Close() calls DirectSessionFactory::Deregister(), which
210 // acquires sessions_lock_.
211 std::swap(sessions_to_reset, sessions_);
212 }
213 Status s;
214 for (auto session : sessions_to_reset) {
215 s.Update(session->Reset(containers));
216 }
217 // TODO(suharshs): Change the Reset behavior of all SessionFactories so that
218 // it doesn't close the sessions?
219 for (auto session : sessions_to_reset) {
220 s.Update(session->Close());
221 }
222 return s;
223 }
224
225 void Deregister(const DirectSession* session) {
226 mutex_lock l(sessions_lock_);
227 sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session),
228 sessions_.end());
229 if (session->options().config.experimental().has_session_metadata()) {
230 session_metadata_keys_.erase(GetMetadataKey(
231 session->options().config.experimental().session_metadata()));
232 }
233 }
234
235 private:
236 static string GetMetadataKey(const SessionMetadata& metadata) {
237 return absl::StrCat(metadata.name(), "/", metadata.version());
238 }
239
240 mutex sessions_lock_;
241 std::vector<DirectSession*> sessions_ TF_GUARDED_BY(sessions_lock_);
242 absl::flat_hash_set<string> session_metadata_keys_
243 TF_GUARDED_BY(sessions_lock_);
244 };
245
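// Registering under "DIRECT_SESSION" makes this factory discoverable through
// SessionFactory::GetFactory(), which selects it whenever AcceptsOptions()
// returns true, i.e. whenever options.target is empty.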
246 class DirectSessionRegistrar {
247 public:
248 DirectSessionRegistrar() {
249 SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
250 }
251 };
252 static DirectSessionRegistrar registrar;
253
254 std::atomic_int_fast64_t DirectSession::step_id_counter_(1);
255
256 static RunHandlerPool* GetOrCreateRunHandlerPool(
257 const SessionOptions& options) {
258 int num_inter_threads = 0;
259 int num_intra_threads = 0;
260 static const int env_num_inter_threads = NumInterOpThreadsFromEnvironment();
261 static const int env_num_intra_threads = NumIntraOpThreadsFromEnvironment();
262 if (env_num_inter_threads > 0) {
263 num_inter_threads = env_num_inter_threads;
264 }
265 if (env_num_intra_threads > 0) {
266 num_intra_threads = env_num_intra_threads;
267 }
268
269 if (num_inter_threads == 0) {
270 if (options.config.session_inter_op_thread_pool_size() > 0) {
271 // Note: due to ShouldUseRunHandlerPool() we are guaranteed that
272 // run_options.inter_op_thread_pool() == 0
273 num_inter_threads =
274 options.config.session_inter_op_thread_pool(0).num_threads();
275 }
276 if (num_inter_threads == 0) {
277 num_inter_threads = NumInterOpThreadsFromSessionOptions(options);
278 }
279 }
280
281 if (num_intra_threads == 0) {
282 num_intra_threads = options.config.intra_op_parallelism_threads();
283 if (num_intra_threads == 0) {
284 num_intra_threads = port::MaxParallelism();
285 }
286 }
287
288 static RunHandlerPool* pool =
289 new RunHandlerPool(num_inter_threads, num_intra_threads);
290 return pool;
291 }
292
293 bool DirectSession::ShouldUseRunHandlerPool(
294 const RunOptions& run_options) const {
295 if (options_.config.use_per_session_threads()) return false;
296 if (options_.config.session_inter_op_thread_pool_size() > 0 &&
297 run_options.inter_op_thread_pool() > 0)
298 return false;
299 // Only use RunHandlerPool when:
300 // a. Single global thread pool is used for inter-op parallelism.
301 // b. When multiple inter_op_thread_pool(s) are created, use it only while
302 // running sessions on the default inter_op_thread_pool=0. Typically,
303 // servo-team uses inter_op_thread_pool > 0 for model loading.
304 // TODO(crk): Revisit whether we'd want to create one (static) RunHandlerPool
305 // per entry in session_inter_op_thread_pool() in the future.
306 return true;
307 }
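// Client-side opt-in is a separate flag; a minimal sketch, assuming the
// standard protobuf setters generated for RunOptions:
//   RunOptions run_options;
//   run_options.mutable_experimental()->set_use_run_handler_pool(true);
//   session->Run(run_options, inputs, output_names, target_nodes, &outputs,
//                &run_metadata);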
308
309 DirectSession::DirectSession(const SessionOptions& options,
310 const DeviceMgr* device_mgr,
311 DirectSessionFactory* const factory)
312 : options_(options),
313 device_mgr_(device_mgr),
314 factory_(factory),
315 cancellation_manager_(new CancellationManager()),
316 operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
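// Thread-pool selection, in priority order: explicit
// session_inter_op_thread_pool entries, then a per-session pool, then the
// process-global pool (possibly downgraded to running closures in the
// caller's thread, see run_in_caller_thread_ below).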
317 const int thread_pool_size =
318 options_.config.session_inter_op_thread_pool_size();
319 if (thread_pool_size > 0) {
320 for (int i = 0; i < thread_pool_size; ++i) {
321 thread::ThreadPool* pool = nullptr;
322 bool owned = false;
323 init_error_.Update(NewThreadPoolFromThreadPoolOptions(
324 options_, options_.config.session_inter_op_thread_pool(i), i, &pool,
325 &owned));
326 thread_pools_.emplace_back(pool, owned);
327 }
328 } else if (options_.config.use_per_session_threads()) {
329 thread_pools_.emplace_back(NewThreadPoolFromSessionOptions(options_),
330 true /* owned */);
331 } else {
332 thread_pools_.emplace_back(GlobalThreadPool(options), false /* owned */);
333 // Run locally if environment value of TF_NUM_INTEROP_THREADS is negative
334 // and config.inter_op_parallelism_threads is unspecified or negative.
335 static const int env_num_threads = NumInterOpThreadsFromEnvironment();
336 if (options_.config.inter_op_parallelism_threads() < 0 ||
337 (options_.config.inter_op_parallelism_threads() == 0 &&
338 env_num_threads < 0)) {
339 run_in_caller_thread_ = true;
340 }
341 }
342 // The default value of sync_on_finish will be flipped soon and this
343 // environment variable will be removed as well.
344 const Status status =
345 ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
346 if (!status.ok()) {
347 LOG(ERROR) << status.error_message();
348 }
349 session_handle_ =
350 strings::StrCat("direct", strings::FpToString(random::New64()));
351 int devices_added = 0;
352 if (options.config.log_device_placement()) {
353 const string mapping_str = device_mgr_->DeviceMappingString();
354 string msg;
355 if (mapping_str.empty()) {
356 msg = "Device mapping: no known devices.";
357 } else {
358 msg = strings::StrCat("Device mapping:\n", mapping_str);
359 }
360 if (!logging::LogToListeners(msg)) {
361 LOG(INFO) << msg;
362 }
363 }
364 for (auto d : device_mgr_->ListDevices()) {
365 devices_.push_back(d);
366 device_set_.AddDevice(d);
367 d->op_segment()->AddHold(session_handle_);
368
369 // The first device added is special: it is the 'client device' (a
370 // CPU device) from which we feed and fetch Tensors.
371 if (devices_added == 0) {
372 device_set_.set_client_device(d);
373 }
374 ++devices_added;
375 }
376 }
377
378 DirectSession::~DirectSession() {
379 if (!closed_) Close().IgnoreError();
380 for (auto& it : partial_runs_) {
381 it.second.reset(nullptr);
382 }
383 for (auto& it : executors_) {
384 it.second.reset();
385 }
386 callables_.clear();
387 for (auto d : device_mgr_->ListDevices()) {
388 d->op_segment()->RemoveHold(session_handle_);
389 }
390 functions_.clear();
391 delete cancellation_manager_;
392 for (const auto& p_and_owned : thread_pools_) {
393 if (p_and_owned.second) delete p_and_owned.first;
394 }
395
396 execution_state_.reset(nullptr);
397 flib_def_.reset(nullptr);
398 }
399
400 Status DirectSession::Create(const GraphDef& graph) {
401 return Create(GraphDef(graph));
402 }
403
404 Status DirectSession::Create(GraphDef&& graph) {
405 TF_RETURN_IF_ERROR(init_error_);
406 if (graph.node_size() > 0) {
407 mutex_lock l(graph_state_lock_);
408 if (graph_created_) {
409 return errors::AlreadyExists(
410 "A Graph has already been created for this session.");
411 }
412 return ExtendLocked(std::move(graph));
413 }
414 return Status::OK();
415 }
416
417 Status DirectSession::Extend(const GraphDef& graph) {
418 return Extend(GraphDef(graph));
419 }
420
421 Status DirectSession::Extend(GraphDef&& graph) {
422 TF_RETURN_IF_ERROR(CheckNotClosed());
423 mutex_lock l(graph_state_lock_);
424 return ExtendLocked(std::move(graph));
425 }
426
427 Status DirectSession::ExtendLocked(GraphDef graph) {
428 if (finalized_) {
429 return errors::FailedPrecondition("Session has been finalized.");
430 }
431 if (!(flib_def_ && execution_state_)) {
432 // If this is the first call, we can initialize the execution state
433 // with `graph` and do not need to call `Extend()`.
434 // NOTE(mrry): The function library created here will be used for
435 // all subsequent extensions of the graph.
436 flib_def_.reset(
437 new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
438 GraphExecutionStateOptions options;
439 options.device_set = &device_set_;
440 options.session_options = &options_;
441 options.session_handle = session_handle_;
442 TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
443 std::move(graph), options, &execution_state_));
444 graph_created_ = true;
445 } else {
446 TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));
447 std::unique_ptr<GraphExecutionState> state;
448 // TODO(mrry): Rewrite GraphExecutionState::Extend() to take `graph` by
449 // value and move `graph` in here.
450 TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
451 execution_state_.swap(state);
452 }
453 return Status::OK();
454 }
455
456 Status DirectSession::Run(const NamedTensorList& inputs,
457 const std::vector<string>& output_names,
458 const std::vector<string>& target_nodes,
459 std::vector<Tensor>* outputs) {
460 RunMetadata run_metadata;
461 return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
462 &run_metadata);
463 }
464
465 Status DirectSession::CreateDebuggerState(
466 const CallableOptions& callable_options, int64_t global_step,
467 int64_t session_run_index, int64_t executor_step_index,
468 std::unique_ptr<DebuggerStateInterface>* debugger_state) {
469 TF_RETURN_IF_ERROR(DebuggerStateRegistry::CreateState(
470 callable_options.run_options().debug_options(), debugger_state));
471 std::vector<string> input_names(callable_options.feed().begin(),
472 callable_options.feed().end());
473 std::vector<string> output_names(callable_options.fetch().begin(),
474 callable_options.fetch().end());
475 std::vector<string> target_names(callable_options.target().begin(),
476 callable_options.target().end());
477
478 TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
479 global_step, session_run_index, executor_step_index, input_names,
480 output_names, target_names));
481 return Status::OK();
482 }
483
484 Status DirectSession::DecorateAndPublishGraphForDebug(
485 const DebugOptions& debug_options, Graph* graph, Device* device) {
486 std::unique_ptr<DebugGraphDecoratorInterface> decorator;
487 TF_RETURN_IF_ERROR(
488 DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
489
490 TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
491 TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
492 return Status::OK();
493 }
494
495 Status DirectSession::RunInternal(
496 int64_t step_id, const RunOptions& run_options,
497 CallFrameInterface* call_frame, ExecutorsAndKeys* executors_and_keys,
498 RunMetadata* run_metadata,
499 const thread::ThreadPoolOptions& threadpool_options) {
500 const uint64 start_time_usecs = options_.env->NowMicros();
501 const int64_t executor_step_count =
502 executors_and_keys->step_count.fetch_add(1);
503 RunState run_state(step_id, &devices_);
504 const size_t num_executors = executors_and_keys->items.size();
505
506 profiler::TraceMeProducer activity(
507 // To TraceMeConsumers in ExecutorState::Process/Finish.
508 [&] {
509 if (options_.config.experimental().has_session_metadata()) {
510 const auto& model_metadata =
511 options_.config.experimental().session_metadata();
512 string model_id = strings::StrCat(model_metadata.name(), ":",
513 model_metadata.version());
514 return profiler::TraceMeEncode("SessionRun",
515 {{"id", step_id},
516 {"_r", 1} /*root_event*/,
517 {"model_id", model_id}});
518 } else {
519 return profiler::TraceMeEncode(
520 "SessionRun", {{"id", step_id}, {"_r", 1} /*root_event*/});
521 }
522 },
523 profiler::ContextType::kTfExecutor, step_id,
524 profiler::TraceMeLevel::kInfo);
525
526 std::unique_ptr<DebuggerStateInterface> debugger_state;
527 if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
528 TF_RETURN_IF_ERROR(
529 CreateDebuggerState(executors_and_keys->callable_options,
530 run_options.debug_options().global_step(), step_id,
531 executor_step_count, &debugger_state));
532 }
533
534 #ifndef __ANDROID__
535 // Set up for collectives if ExecutorsAndKeys declares a key.
536 if (executors_and_keys->collective_graph_key !=
537 BuildGraphOptions::kNoCollectiveGraphKey) {
538 if (run_options.experimental().collective_graph_key() !=
539 BuildGraphOptions::kNoCollectiveGraphKey) {
540 // If a collective_graph_key was specified in run_options, ensure that it
541 // matches what came out of GraphExecutionState::BuildGraph().
542 if (run_options.experimental().collective_graph_key() !=
543 executors_and_keys->collective_graph_key) {
544 return errors::Internal(
545 "collective_graph_key in RunOptions ",
546 run_options.experimental().collective_graph_key(),
547 " should match collective_graph_key from optimized graph ",
548 executors_and_keys->collective_graph_key);
549 }
550 }
551 if (!collective_executor_mgr_) {
552 collective_executor_mgr_ = CreateProdLocalCollectiveExecutorMgr(
553 options_.config, device_mgr_.get(),
554 MaybeCreateNcclCommunicator(options_.config));
555 }
556 run_state.collective_executor.reset(new CollectiveExecutor::Handle(
557 collective_executor_mgr_->FindOrCreate(step_id), true /*inherit_ref*/));
558 }
559 #endif
560
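// Select the inter-op pool for this step: the caller's thread when inline
// execution was requested, a caller-supplied pool from ThreadPoolOptions,
// or one of the session-owned pools indexed by
// run_options.inter_op_thread_pool().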
561 thread::ThreadPool* pool;
562 // Use std::unique_ptr so the wrapper around a caller-supplied threadpool is deleted automatically.
563 std::unique_ptr<thread::ThreadPool> threadpool_wrapper;
564
565 const bool inline_execution_requested =
566 run_in_caller_thread_ || run_options.inter_op_thread_pool() == -1;
567
568 if (inline_execution_requested) {
569 // We allow using the caller thread only when having a single executor
570 // specified.
571 if (executors_and_keys->items.size() > 1) {
572 pool = thread_pools_[0].first;
573 } else {
574 VLOG(1) << "Executing Session::Run() synchronously!";
575 pool = nullptr;
576 }
577 } else if (threadpool_options.inter_op_threadpool != nullptr) {
578 threadpool_wrapper = absl::make_unique<thread::ThreadPool>(
579 threadpool_options.inter_op_threadpool);
580 pool = threadpool_wrapper.get();
581 } else {
582 if (run_options.inter_op_thread_pool() < -1 ||
583 run_options.inter_op_thread_pool() >=
584 static_cast<int32>(thread_pools_.size())) {
585 return errors::InvalidArgument("Invalid inter_op_thread_pool: ",
586 run_options.inter_op_thread_pool());
587 }
588
589 pool = thread_pools_[run_options.inter_op_thread_pool()].first;
590 }
591
592 const int64_t call_timeout = run_options.timeout_in_ms() > 0
593 ? run_options.timeout_in_ms()
594 : operation_timeout_in_ms_;
595
596 std::unique_ptr<RunHandler> handler;
597 if (ShouldUseRunHandlerPool(run_options) &&
598 run_options.experimental().use_run_handler_pool()) {
599 VLOG(1) << "Using RunHandler to schedule inter-op closures.";
600 handler = GetOrCreateRunHandlerPool(options_)->Get(
601 step_id, call_timeout,
602 run_options.experimental().run_handler_pool_options());
603 if (!handler) {
604 return errors::DeadlineExceeded(
605 "Could not obtain RunHandler for request after waiting for ",
606 call_timeout, "ms.");
607 }
608 }
609 auto* handler_ptr = handler.get();
610
611 Executor::Args::Runner default_runner = nullptr;
612
613 if (pool == nullptr) {
614 default_runner = [](const Executor::Args::Closure& c) { c(); };
615 } else if (handler_ptr != nullptr) {
616 default_runner = [handler_ptr](Executor::Args::Closure c) {
617 handler_ptr->ScheduleInterOpClosure(std::move(c));
618 };
619 } else {
620 default_runner = [pool](Executor::Args::Closure c) {
621 pool->Schedule(std::move(c));
622 };
623 }
624
625 // Start parallel Executors.
626
627 // We can execute this step synchronously on the calling thread whenever
628 // there is a single device and the timeout mechanism is not used.
629 //
630 // When timeouts are used, we must execute the graph(s) asynchronously, in
631 // order to invoke the cancellation manager on the calling thread if the
632 // timeout expires.
633 const bool can_execute_synchronously =
634 executors_and_keys->items.size() == 1 && call_timeout == 0;
635
636 Executor::Args args;
637 args.step_id = step_id;
638 args.call_frame = call_frame;
639 args.collective_executor =
640 (run_state.collective_executor ? run_state.collective_executor->get()
641 : nullptr);
642 args.session_state = &session_state_;
643 args.session_handle = session_handle_;
644 args.tensor_store = &run_state.tensor_store;
645 args.step_container = &run_state.step_container;
646 args.sync_on_finish = sync_on_finish_;
647 args.user_intra_op_threadpool = threadpool_options.intra_op_threadpool;
648 args.run_all_kernels_inline = pool == nullptr;
649 args.start_time_usecs = start_time_usecs;
650
651 const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);
652
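// Cost-model bookkeeping: with build_cost_model=N and
// build_cost_model_after=M, stats are collected and the cost model is
// rebuilt on executor steps M+N-1, M+2N-1, ... (a reading of the arithmetic
// below, counting steps from 0).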
653 bool update_cost_model = false;
654 if (options_.config.graph_options().build_cost_model() > 0) {
655 const int64_t build_cost_model_every =
656 options_.config.graph_options().build_cost_model();
657 const int64_t build_cost_model_after =
658 options_.config.graph_options().build_cost_model_after();
659 int64_t measure_step_count = executor_step_count - build_cost_model_after;
660 if (measure_step_count >= 0) {
661 update_cost_model =
662 ((measure_step_count + 1) % build_cost_model_every == 0);
663 }
664 }
665 if (do_trace || update_cost_model ||
666 run_options.report_tensor_allocations_upon_oom()) {
667 run_state.collector.reset(
668 new StepStatsCollector(run_metadata->mutable_step_stats()));
669 args.stats_collector = run_state.collector.get();
670 }
671
672 std::unique_ptr<DeviceProfilerSession> device_profiler_session;
673 if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
674 device_profiler_session = DeviceProfilerSession::Create();
675 }
676
677 // Register this step with session's cancellation manager, so that
678 // `Session::Close()` will cancel the step.
679 CancellationManager step_cancellation_manager(cancellation_manager_);
680 if (step_cancellation_manager.IsCancelled()) {
681 return errors::Cancelled("Run call was cancelled");
682 }
683 args.cancellation_manager = &step_cancellation_manager;
684
685 Status run_status;
686
687 auto set_threadpool_args_for_item =
688 [&default_runner, &handler](const PerPartitionExecutorsAndLib& item,
689 Executor::Args* args) {
690 // TODO(azaks): support partial run.
691 // TODO(azaks): if the device picks its own threadpool, we need to
692 // assign fewer threads to the main compute pool by default.
694 thread::ThreadPool* device_thread_pool =
695 item.device->tensorflow_device_thread_pool();
696 // TODO(crk): Investigate usage of RunHandlerPool when using device
697 // specific thread pool(s).
698 if (!device_thread_pool) {
699 args->runner = default_runner;
700 } else {
701 args->runner = [device_thread_pool](Executor::Args::Closure c) {
702 device_thread_pool->Schedule(std::move(c));
703 };
704 }
705 if (handler != nullptr) {
706 args->user_intra_op_threadpool =
707 handler->AsIntraThreadPoolInterface();
708 }
709 };
710
711 if (can_execute_synchronously) {
712 PrivateIntraProcessRendezvous rendezvous(device_mgr_.get());
713 args.rendezvous = &rendezvous;
714
715 const auto& item = executors_and_keys->items[0];
716 set_threadpool_args_for_item(item, &args);
717 run_status = item.executor->Run(args);
718 } else {
719 core::RefCountPtr<RefCountedIntraProcessRendezvous> rendezvous(
720 new RefCountedIntraProcessRendezvous(device_mgr_.get()));
721 args.rendezvous = rendezvous.get();
722
723 // `barrier` will delete itself after the final executor finishes.
724 Notification executors_done;
725 ExecutorBarrier* barrier =
726 new ExecutorBarrier(num_executors, rendezvous.get(),
727 [&run_state, &executors_done](const Status& ret) {
728 {
729 mutex_lock l(run_state.mu);
730 run_state.status.Update(ret);
731 }
732 executors_done.Notify();
733 });
734
735 for (const auto& item : executors_and_keys->items) {
736 set_threadpool_args_for_item(item, &args);
737 item.executor->RunAsync(args, barrier->Get());
738 }
739
740 WaitForNotification(&executors_done, &run_state, &step_cancellation_manager,
741 call_timeout);
742 {
743 tf_shared_lock l(run_state.mu);
744 run_status = run_state.status;
745 }
746 }
747
748 if (step_cancellation_manager.IsCancelled()) {
749 run_status.Update(errors::Cancelled("Run call was cancelled"));
750 }
751
752 if (device_profiler_session) {
753 TF_RETURN_IF_ERROR(device_profiler_session->CollectData(
754 run_metadata->mutable_step_stats()));
755 }
756
757 TF_RETURN_IF_ERROR(run_status);
758
759 // Save the output tensors of this run we choose to keep.
760 if (!run_state.tensor_store.empty()) {
761 TF_RETURN_IF_ERROR(run_state.tensor_store.SaveTensors(
762 {executors_and_keys->callable_options.fetch().begin(),
763 executors_and_keys->callable_options.fetch().end()},
764 &session_state_));
765 }
766
767 if (run_state.collector) {
768 run_state.collector->Finalize();
769 }
770
771 // Build and return the cost model as instructed.
772 if (update_cost_model) {
773 // Build the cost model
774 std::unordered_map<string, const Graph*> device_to_graph;
775 for (const PerPartitionExecutorsAndLib& partition :
776 executors_and_keys->items) {
777 const Graph* graph = partition.graph.get();
778 const string& device = partition.flib->device()->name();
779 device_to_graph[device] = graph;
780 }
781
782 mutex_lock l(executor_lock_);
783 run_state.collector->BuildCostModel(&cost_model_manager_, device_to_graph);
784
785 // Annotate stats onto the cost graph.
786 CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
787 for (const auto& item : executors_and_keys->items) {
788 TF_RETURN_IF_ERROR(
789 cost_model_manager_.AddToCostGraphDef(item.graph.get(), cost_graph));
790 }
791 }
792
793 // If requested via RunOptions, output the partition graphs.
794 if (run_options.output_partition_graphs()) {
795 if (options_.config.experimental().disable_output_partition_graphs()) {
796 return errors::InvalidArgument(
797 "RunOptions.output_partition_graphs() is not supported when "
798 "disable_output_partition_graphs is true.");
799 } else {
800 protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
801 run_metadata->mutable_partition_graphs();
802 for (const PerPartitionExecutorsAndLib& exec_and_lib :
803 executors_and_keys->items) {
804 GraphDef* partition_graph_def = partition_graph_defs->Add();
805 exec_and_lib.graph->ToGraphDef(partition_graph_def);
806 }
807 }
808 }
809 metrics::UpdateGraphExecTime(options_.env->NowMicros() - start_time_usecs);
810
811 return Status::OK();
812 }
813
814 Status DirectSession::Run(const RunOptions& run_options,
815 const NamedTensorList& inputs,
816 const std::vector<string>& output_names,
817 const std::vector<string>& target_nodes,
818 std::vector<Tensor>* outputs,
819 RunMetadata* run_metadata) {
820 return Run(run_options, inputs, output_names, target_nodes, outputs,
821 run_metadata, thread::ThreadPoolOptions());
822 }
823
824 Status DirectSession::Run(const RunOptions& run_options,
825 const NamedTensorList& inputs,
826 const std::vector<string>& output_names,
827 const std::vector<string>& target_nodes,
828 std::vector<Tensor>* outputs,
829 RunMetadata* run_metadata,
830 const thread::ThreadPoolOptions& threadpool_options) {
831 TF_RETURN_IF_ERROR(CheckNotClosed());
832 TF_RETURN_IF_ERROR(CheckGraphCreated("Run()"));
833 direct_session_runs->GetCell()->IncrementBy(1);
834
835 // Extract the inputs names for this run of the session.
836 std::vector<string> input_tensor_names;
837 input_tensor_names.reserve(inputs.size());
838 size_t input_size = 0;
839 for (const auto& it : inputs) {
840 input_tensor_names.push_back(it.first);
841 input_size += it.second.AllocatedBytes();
842 }
843 metrics::RecordGraphInputTensors(input_size);
844
845 // Check if we already have an executor for these arguments.
846 ExecutorsAndKeys* executors_and_keys;
847 RunStateArgs run_state_args(run_options.debug_options());
848 run_state_args.collective_graph_key =
849 run_options.experimental().collective_graph_key();
850
851 TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
852 target_nodes, &executors_and_keys,
853 &run_state_args));
854 {
855 mutex_lock l(collective_graph_key_lock_);
856 collective_graph_key_ = executors_and_keys->collective_graph_key;
857 }
858
859 // Configure a call frame for the step, which we use to feed and
860 // fetch values to and from the executors.
861 FunctionCallFrame call_frame(executors_and_keys->input_types,
862 executors_and_keys->output_types);
863 gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
864 for (const auto& it : inputs) {
865 if (it.second.dtype() == DT_RESOURCE) {
866 Tensor tensor_from_handle;
867 TF_RETURN_IF_ERROR(
868 ResourceHandleToInputTensor(it.second, &tensor_from_handle));
869 feed_args[executors_and_keys->input_name_to_index[it.first]] =
870 tensor_from_handle;
871 } else {
872 feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
873 }
874 }
875 const Status s = call_frame.SetArgs(feed_args);
876 if (errors::IsInternal(s)) {
877 return errors::InvalidArgument(s.error_message());
878 } else if (!s.ok()) {
879 return s;
880 }
881
882 const int64_t step_id = step_id_counter_.fetch_add(1);
883
884 if (LogMemory::IsEnabled()) {
885 LogMemory::RecordStep(step_id, run_state_args.handle);
886 }
887
888 TF_RETURN_IF_ERROR(RunInternal(step_id, run_options, &call_frame,
889 executors_and_keys, run_metadata,
890 threadpool_options));
891
892 // Receive outputs.
893 if (outputs) {
894 std::vector<Tensor> sorted_outputs;
895 const Status s = call_frame.ConsumeRetvals(
896 &sorted_outputs, /* allow_dead_tensors = */ false);
897 if (errors::IsInternal(s)) {
898 return errors::InvalidArgument(s.error_message());
899 } else if (!s.ok()) {
900 return s;
901 }
902 const bool unique_outputs =
903 output_names.size() == executors_and_keys->output_name_to_index.size();
904 // first_indices[i] = j implies that j is the smallest value for which
905 // output_names[i] == output_names[j].
906 std::vector<int> first_indices;
907 if (!unique_outputs) {
908 first_indices.reserve(output_names.size());
909 for (const auto& name : output_names) {
910 first_indices.push_back(
911 std::find(output_names.begin(), output_names.end(), name) -
912 output_names.begin());
913 }
914 }
915 outputs->clear();
916 size_t output_size = 0;
917 outputs->reserve(sorted_outputs.size());
918 for (int i = 0; i < output_names.size(); ++i) {
919 const string& output_name = output_names[i];
920 if (first_indices.empty() || first_indices[i] == i) {
921 outputs->emplace_back(
922 std::move(sorted_outputs[executors_and_keys
923 ->output_name_to_index[output_name]]));
924 } else {
925 outputs->push_back((*outputs)[first_indices[i]]);
926 }
927 output_size += outputs->back().AllocatedBytes();
928 }
929 metrics::RecordGraphOutputTensors(output_size);
930 }
931
932 return Status::OK();
933 }
934
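// A minimal client-side partial-run sketch (hypothetical feed tensors `a`
// and `b`, fetching "c:0"):
//   string handle;
//   Status s = session->PRunSetup({"a:0", "b:0"}, {"c:0"}, {}, &handle);
//   std::vector<Tensor> outs;
//   if (s.ok()) s = session->PRun(handle, {{"a:0", a}, {"b:0", b}}, {"c:0"}, &outs);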
935 Status DirectSession::PRunSetup(const std::vector<string>& input_names,
936 const std::vector<string>& output_names,
937 const std::vector<string>& target_nodes,
938 string* handle) {
939 TF_RETURN_IF_ERROR(CheckNotClosed());
940 TF_RETURN_IF_ERROR(CheckGraphCreated("PRunSetup()"));
941
942 // RunOptions is not available in PRunSetup, so use thread pool 0.
943 thread::ThreadPool* pool = thread_pools_[0].first;
944
945 // Check if we already have an executor for these arguments.
946 ExecutorsAndKeys* executors_and_keys;
947 // TODO(cais): TFDBG support for partial runs.
948 DebugOptions debug_options;
949 RunStateArgs run_state_args(debug_options);
950 run_state_args.is_partial_run = true;
951 TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_names, output_names,
952 target_nodes, &executors_and_keys,
953 &run_state_args));
954
955 // Create the run state and save it for future PRun calls.
956 Executor::Args args;
957 args.step_id = step_id_counter_.fetch_add(1);
958 PartialRunState* run_state =
959 new PartialRunState(input_names, output_names, args.step_id, &devices_);
960 run_state->rendez.reset(new IntraProcessRendezvous(device_mgr_.get()));
961 {
962 mutex_lock l(executor_lock_);
963 if (!partial_runs_
964 .emplace(run_state_args.handle,
965 std::unique_ptr<PartialRunState>(run_state))
966 .second) {
967 return errors::Internal("The handle '", run_state_args.handle,
968 "' created for this partial run is not unique.");
969 }
970 }
971
972 // Start parallel Executors.
973 const size_t num_executors = executors_and_keys->items.size();
974 ExecutorBarrier* barrier = new ExecutorBarrier(
975 num_executors, run_state->rendez.get(), [run_state](const Status& ret) {
976 if (!ret.ok()) {
977 mutex_lock l(run_state->mu);
978 run_state->status.Update(ret);
979 }
980 run_state->executors_done.Notify();
981 });
982
983 args.rendezvous = run_state->rendez.get();
984 args.cancellation_manager = cancellation_manager_;
985 // Note that Collectives are not supported in partial runs
986 // because RunOptions is not passed in so we can't know whether
987 // their use is intended.
988 args.collective_executor = nullptr;
989 args.runner = [this, pool](Executor::Args::Closure c) {
990 pool->Schedule(std::move(c));
991 };
992 args.session_state = &session_state_;
993 args.session_handle = session_handle_;
994 args.tensor_store = &run_state->tensor_store;
995 args.step_container = &run_state->step_container;
996 if (LogMemory::IsEnabled()) {
997 LogMemory::RecordStep(args.step_id, run_state_args.handle);
998 }
999 args.sync_on_finish = sync_on_finish_;
1000
1001 if (options_.config.graph_options().build_cost_model()) {
1002 run_state->collector.reset(new StepStatsCollector(nullptr));
1003 args.stats_collector = run_state->collector.get();
1004 }
1005
1006 for (auto& item : executors_and_keys->items) {
1007 item.executor->RunAsync(args, barrier->Get());
1008 }
1009
1010 *handle = run_state_args.handle;
1011 return Status::OK();
1012 }
1013
1014 Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
1015 const std::vector<string>& output_names,
1016 std::vector<Tensor>* outputs) {
1017 TF_RETURN_IF_ERROR(CheckNotClosed());
1018 std::vector<string> parts = str_util::Split(handle, ';');
1019 const string& key = parts[0];
1020 // Get the executors for this partial run.
1021 ExecutorsAndKeys* executors_and_keys;
1022 PartialRunState* run_state;
1023 {
1024 mutex_lock l(executor_lock_); // could use reader lock
1025 auto exc_it = executors_.find(key);
1026 if (exc_it == executors_.end()) {
1027 return errors::InvalidArgument(
1028 "Must run 'setup' before performing partial runs!");
1029 }
1030 executors_and_keys = exc_it->second.get();
1031
1032 auto prun_it = partial_runs_.find(handle);
1033 if (prun_it == partial_runs_.end()) {
1034 return errors::InvalidArgument(
1035 "Must run 'setup' before performing partial runs!");
1036 }
1037 run_state = prun_it->second.get();
1038
1039 // Make sure that this is a new set of feeds that are still pending.
1040 for (const auto& input : inputs) {
1041 auto it = run_state->pending_inputs.find(input.first);
1042 if (it == run_state->pending_inputs.end()) {
1043 return errors::InvalidArgument(
1044 "The feed ", input.first,
1045 " was not specified in partial_run_setup.");
1046 } else if (it->second) {
1047 return errors::InvalidArgument("The feed ", input.first,
1048 " has already been fed.");
1049 }
1050 }
1051 // Check that this is a new set of fetches that are still pending.
1052 for (const auto& output : output_names) {
1053 auto it = run_state->pending_outputs.find(output);
1054 if (it == run_state->pending_outputs.end()) {
1055 return errors::InvalidArgument(
1056 "The fetch ", output, " was not specified in partial_run_setup.");
1057 } else if (it->second) {
1058 return errors::InvalidArgument("The fetch ", output,
1059 " has already been fetched.");
1060 }
1061 }
1062 }
1063
1064 // Check that this new set of fetches can be computed from all the
1065 // feeds we have supplied.
1066 TF_RETURN_IF_ERROR(
1067 CheckFetch(inputs, output_names, executors_and_keys, run_state));
1068
1069 // Send inputs.
1070 Status s =
1071 SendPRunInputs(inputs, executors_and_keys, run_state->rendez.get());
1072
1073 // Receive outputs.
1074 if (s.ok()) {
1075 s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs);
1076 }
1077
1078 // Save the output tensors of this run we choose to keep.
1079 if (s.ok()) {
1080 s = run_state->tensor_store.SaveTensors(output_names, &session_state_);
1081 }
1082
1083 {
1084 mutex_lock l(executor_lock_);
1085 // Delete the run state if there is an error or all fetches are done.
1086 bool done = true;
1087 if (s.ok()) {
1088 {
1089 mutex_lock l(run_state->mu);
1090 if (!run_state->status.ok()) {
1091 LOG(WARNING) << "An error unrelated to this prun has been detected. "
1092 << run_state->status;
1093 }
1094 }
1095 for (const auto& input : inputs) {
1096 auto it = run_state->pending_inputs.find(input.first);
1097 it->second = true;
1098 }
1099 for (const auto& name : output_names) {
1100 auto it = run_state->pending_outputs.find(name);
1101 it->second = true;
1102 }
1103 done = run_state->PendingDone();
1104 }
1105 if (done) {
1106 WaitForNotification(&run_state->executors_done, run_state,
1107 cancellation_manager_, operation_timeout_in_ms_);
1108 partial_runs_.erase(handle);
1109 }
1110 }
1111
1112 return s;
1113 }
1114
1115 Status DirectSession::ResourceHandleToInputTensor(const Tensor& resource_tensor,
1116 Tensor* retrieved_tensor) {
1117 if (resource_tensor.dtype() != DT_RESOURCE) {
1118 return errors::InvalidArgument(strings::StrCat(
1119 "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ",
1120 resource_tensor.dtype()));
1121 }
1122
1123 const ResourceHandle& resource_handle =
1124 resource_tensor.scalar<ResourceHandle>()();
1125
1126 if (resource_handle.container() ==
1127 SessionState::kTensorHandleResourceTypeName) {
1128 return session_state_.GetTensor(resource_handle.name(), retrieved_tensor);
1129 } else {
1130 return errors::InvalidArgument(strings::StrCat(
1131 "Invalid resource type hash code: ", resource_handle.hash_code(),
1132 "(name: ", resource_handle.name(),
1133 " type: ", resource_handle.maybe_type_name(),
1134 "). Perhaps a resource tensor was being provided as a feed? That is "
1135 "not currently allowed. Please file an issue at "
1136 "https://github.com/tensorflow/tensorflow/issues/new, ideally with a "
1137 "short code snippet that leads to this error message."));
1138 }
1139 }
1140
1141 Status DirectSession::SendPRunInputs(const NamedTensorList& inputs,
1142 const ExecutorsAndKeys* executors_and_keys,
1143 IntraProcessRendezvous* rendez) {
1144 Status s;
1145 Rendezvous::ParsedKey parsed;
1146 // Insert the input tensors into the local rendezvous by their
1147 // rendezvous key.
1148 for (const auto& input : inputs) {
1149 auto it =
1150 executors_and_keys->input_name_to_rendezvous_key.find(input.first);
1151 if (it == executors_and_keys->input_name_to_rendezvous_key.end()) {
1152 return errors::Internal("'", input.first, "' is not a pre-defined feed.");
1153 }
1154 const string& input_key = it->second;
1155
1156 s = Rendezvous::ParseKey(input_key, &parsed);
1157 if (!s.ok()) {
1158 rendez->StartAbort(s);
1159 return s;
1160 }
1161
1162 if (input.second.dtype() == DT_RESOURCE) {
1163 Tensor tensor_from_handle;
1164 s = ResourceHandleToInputTensor(input.second, &tensor_from_handle);
1165 if (s.ok()) {
1166 s = rendez->Send(parsed, Rendezvous::Args(), tensor_from_handle, false);
1167 }
1168 } else {
1169 s = rendez->Send(parsed, Rendezvous::Args(), input.second, false);
1170 }
1171
1172 if (!s.ok()) {
1173 rendez->StartAbort(s);
1174 return s;
1175 }
1176 }
1177 return Status::OK();
1178 }
1179
1180 Status DirectSession::RecvPRunOutputs(
1181 const std::vector<string>& output_names,
1182 const ExecutorsAndKeys* executors_and_keys, PartialRunState* run_state,
1183 std::vector<Tensor>* outputs) {
1184 Status s;
1185 if (!output_names.empty()) {
1186 outputs->resize(output_names.size());
1187 }
1188
1189 Rendezvous::ParsedKey parsed;
1190 // Get the outputs from the rendezvous
1191 for (size_t output_offset = 0; output_offset < output_names.size();
1192 ++output_offset) {
1193 const string& output_name = output_names[output_offset];
1194 auto it =
1195 executors_and_keys->output_name_to_rendezvous_key.find(output_name);
1196 if (it == executors_and_keys->output_name_to_rendezvous_key.end()) {
1197 return errors::Internal("'", output_name,
1198 "' is not a pre-defined fetch.");
1199 }
1200 const string& output_key = it->second;
1201 Tensor output_tensor;
1202 bool is_dead;
1203
1204 s = Rendezvous::ParseKey(output_key, &parsed);
1205 if (s.ok()) {
1206 // Fetch data from the Rendezvous.
1207 s = run_state->rendez->Recv(parsed, Rendezvous::Args(), &output_tensor,
1208 &is_dead, operation_timeout_in_ms_);
1209 if (is_dead && s.ok()) {
1210 s = errors::InvalidArgument("The tensor returned for ", output_name,
1211 " was not valid.");
1212 }
1213 }
1214 if (!s.ok()) {
1215 run_state->rendez->StartAbort(s);
1216 outputs->clear();
1217 return s;
1218 }
1219
1220 (*outputs)[output_offset] = output_tensor;
1221 }
1222 return Status::OK();
1223 }
1224
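// Verifies that every requested fetch is computable from the feeds supplied
// so far: walks the graph backwards from the fetch nodes and fails if any
// path reaches a feed that is still pending.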
1225 Status DirectSession::CheckFetch(const NamedTensorList& feeds,
1226 const std::vector<string>& fetches,
1227 const ExecutorsAndKeys* executors_and_keys,
1228 const PartialRunState* run_state) {
1229 const Graph* graph = executors_and_keys->graph.get();
1230 const NameNodeMap* name_to_node = &executors_and_keys->name_to_node;
1231
1232 // Build the set of pending feeds that we haven't seen.
1233 std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
1234 {
1235 mutex_lock l(executor_lock_);
1236 for (const auto& input : run_state->pending_inputs) {
1237 // Skip if the feed has already been fed.
1238 if (input.second) continue;
1239 TensorId id(ParseTensorName(input.first));
1240 auto it = name_to_node->find(id.first);
1241 if (it == name_to_node->end()) {
1242 return errors::NotFound("Feed ", input.first, ": not found");
1243 }
1244 pending_feeds.insert(id);
1245 }
1246 }
1247 for (const auto& it : feeds) {
1248 TensorId id(ParseTensorName(it.first));
1249 pending_feeds.erase(id);
1250 }
1251
1252 // Initialize the stack with the fetch nodes.
1253 std::vector<const Node*> stack;
1254 for (const string& fetch : fetches) {
1255 TensorId id(ParseTensorName(fetch));
1256 auto it = name_to_node->find(id.first);
1257 if (it == name_to_node->end()) {
1258 return errors::NotFound("Fetch ", fetch, ": not found");
1259 }
1260 stack.push_back(it->second);
1261 }
1262
1263 // Any tensor needed for fetches can't be in pending_feeds.
1264 std::vector<bool> visited(graph->num_node_ids(), false);
1265 while (!stack.empty()) {
1266 const Node* n = stack.back();
1267 stack.pop_back();
1268
1269 for (const Edge* in_edge : n->in_edges()) {
1270 const Node* in_node = in_edge->src();
1271 if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
1272 return errors::InvalidArgument("Fetch ", in_node->name(), ":",
1273 in_edge->src_output(),
1274 " can't be computed from the feeds"
1275 " that have been fed so far.");
1276 }
1277 if (!visited[in_node->id()]) {
1278 visited[in_node->id()] = true;
1279 stack.push_back(in_node);
1280 }
1281 }
1282 }
1283 return Status::OK();
1284 }
1285
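// Builds the per-device executors for one (feeds, fetches, targets)
// signature: partitions the pruned graph across devices, runs the graph
// optimizer on each partition, and routes kernel creation through each
// device's OpSegment so cacheable kernels are reused across steps.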
1286 Status DirectSession::CreateExecutors(
1287 const CallableOptions& callable_options,
1288 std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys,
1289 std::unique_ptr<FunctionInfo>* out_func_info,
1290 RunStateArgs* run_state_args) {
1291 BuildGraphOptions options;
1292 options.callable_options = callable_options;
1293 options.use_function_convention = !run_state_args->is_partial_run;
1294 options.collective_graph_key =
1295 callable_options.run_options().experimental().collective_graph_key();
1296 if (options_.config.experimental()
1297 .collective_deterministic_sequential_execution()) {
1298 options.collective_order = GraphCollectiveOrder::kEdges;
1299 } else if (options_.config.experimental().collective_nccl()) {
1300 options.collective_order = GraphCollectiveOrder::kAttrs;
1301 }
1302
1303 std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
1304 std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
1305
1306 ek->callable_options = callable_options;
1307
1308 std::unordered_map<string, std::unique_ptr<Graph>> graphs;
1309 TF_RETURN_IF_ERROR(CreateGraphs(
1310 options, &graphs, &func_info->flib_def, run_state_args, &ek->input_types,
1311 &ek->output_types, &ek->collective_graph_key));
1312
1313 if (run_state_args->is_partial_run) {
1314 ek->graph = std::move(run_state_args->graph);
1315 std::unordered_set<StringPiece, StringPieceHasher> names;
1316 for (const string& input : callable_options.feed()) {
1317 TensorId id(ParseTensorName(input));
1318 names.emplace(id.first);
1319 }
1320 for (const string& output : callable_options.fetch()) {
1321 TensorId id(ParseTensorName(output));
1322 names.emplace(id.first);
1323 }
1324 for (Node* n : ek->graph->nodes()) {
1325 if (names.count(n->name()) > 0) {
1326 ek->name_to_node.insert({n->name(), n});
1327 }
1328 }
1329 }
1330 ek->items.reserve(graphs.size());
1331 const auto& optimizer_opts =
1332 options_.config.graph_options().optimizer_options();
1333
1334 int graph_def_version = graphs.begin()->second->versions().producer();
1335
1336 const auto* session_metadata =
1337 options_.config.experimental().has_session_metadata()
1338 ? &options_.config.experimental().session_metadata()
1339 : nullptr;
1340 func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
1341 device_mgr_.get(), options_.env, &options_.config, graph_def_version,
1342 func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first,
1343 /*parent=*/nullptr, session_metadata,
1344 Rendezvous::Factory{
1345 [](const int64_t, const DeviceMgr* device_mgr, Rendezvous** r) {
1346 *r = new IntraProcessRendezvous(device_mgr);
1347 return Status::OK();
1348 }}));
1349
1350 GraphOptimizer optimizer(optimizer_opts);
1351 for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
1352 const string& partition_name = iter->first;
1353 std::unique_ptr<Graph>& partition_graph = iter->second;
1354
1355 Device* device;
1356 TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device));
1357
1358 ek->items.resize(ek->items.size() + 1);
1359 auto* item = &(ek->items.back());
1360 auto lib = func_info->proc_flr->GetFLR(partition_name);
1361 if (lib == nullptr) {
1362 return errors::Internal("Could not find device: ", partition_name);
1363 }
1364 item->flib = lib;
1365
1366 LocalExecutorParams params;
1367 params.device = device;
1368 params.session_metadata = session_metadata;
1369 params.function_library = lib;
1370 auto opseg = device->op_segment();
1371 params.create_kernel =
1372 [this, lib, opseg](const std::shared_ptr<const NodeProperties>& props,
1373 OpKernel** kernel) {
1374 // NOTE(mrry): We must not share function kernels (implemented
1375 // using `CallOp`) between subgraphs, because `CallOp::handle_`
1376 // is tied to a particular subgraph. Even if the function itself
1377 // is stateful, the `CallOp` that invokes it is not.
1378 if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) {
1379 return lib->CreateKernel(props, kernel);
1380 }
1381 auto create_fn = [lib, &props](OpKernel** kernel) {
1382 return lib->CreateKernel(props, kernel);
1383 };
1384 // Kernels created for subgraph nodes need to be cached. On
1385 // cache miss, create_fn() is invoked to create a kernel based
1386 // on the function library here + global op registry.
1387 return opseg->FindOrCreate(session_handle_, props->node_def.name(),
1388 kernel, create_fn);
1389 };
1390 params.delete_kernel = [lib](OpKernel* kernel) {
1391 if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string()))
1392 delete kernel;
1393 };
1394
1395 optimizer.Optimize(lib, options_.env, device, &partition_graph,
1396 /*shape_map=*/nullptr);
1397
1398 // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
1399 const DebugOptions& debug_options =
1400 options.callable_options.run_options().debug_options();
1401 if (!debug_options.debug_tensor_watch_opts().empty()) {
1402 TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
1403 debug_options, partition_graph.get(), params.device));
1404 }
1405
1406 TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
1407 device->name(),
1408 partition_graph.get()));
1409
1410 item->executor = nullptr;
1411 item->device = device;
1412 auto executor_type = options_.config.experimental().executor_type();
1413 TF_RETURN_IF_ERROR(
1414 NewExecutor(executor_type, params, *partition_graph, &item->executor));
1415 if (!options_.config.experimental().disable_output_partition_graphs() ||
1416 options_.config.graph_options().build_cost_model() > 0) {
1417 item->graph = std::move(partition_graph);
1418 }
1419 }
1420
1421 // Cache the mapping from input/output names to graph elements to
1422 // avoid recomputing it every time.
1423 if (!run_state_args->is_partial_run) {
1424 // For regular `Run()`, we use the function calling convention, and so
1425 // maintain a mapping from input/output names to
1426 // argument/return-value ordinal index.
1427 for (int i = 0; i < callable_options.feed().size(); ++i) {
1428 const string& input = callable_options.feed(i);
1429 ek->input_name_to_index[input] = i;
1430 }
1431 for (int i = 0; i < callable_options.fetch().size(); ++i) {
1432 const string& output = callable_options.fetch(i);
1433 ek->output_name_to_index[output] = i;
1434 }
1435 } else {
1436 // For `PRun()`, we use the rendezvous calling convention, and so
1437 // maintain a mapping from input/output names to rendezvous keys.
1438 //
1439 // We always use the first device as the device name portion of the
1440 // key, even if we're feeding another graph.
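    //
    // The keys produced by GetRendezvousKey() follow the usual rendezvous
    // format (roughly
    // "<device>;<incarnation>;<device>;<tensor_name>;<frame_id>:<iter_id>"),
    // with the client device used for both endpoints.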
1441 for (int i = 0; i < callable_options.feed().size(); ++i) {
1442 const string& input = callable_options.feed(i);
1443 ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
1444 input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
1445 }
1446 for (int i = 0; i < callable_options.fetch().size(); ++i) {
1447 const string& output = callable_options.fetch(i);
1448 ek->output_name_to_rendezvous_key[output] =
1449 GetRendezvousKey(output, device_set_.client_device()->attributes(),
1450 FrameAndIter(0, 0));
1451 }
1452 }
1453
1454 *out_executors_and_keys = std::move(ek);
1455 *out_func_info = std::move(func_info);
1456 return Status::OK();
1457 }
1458
1459 Status DirectSession::GetOrCreateExecutors(
1460 gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
1461 gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys,
1462 RunStateArgs* run_state_args) {
1463 int64_t handle_name_counter_value = -1;
1464 if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
1465 handle_name_counter_value = handle_name_counter_.fetch_add(1);
1466 }
1467
1468 string debug_tensor_watches_summary;
1469 if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
1470 debug_tensor_watches_summary = SummarizeDebugTensorWatches(
1471 run_state_args->debug_options.debug_tensor_watch_opts());
1472 }
1473
1474 // Fast lookup path, no sorting.
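  // The key concatenates the feeds, fetches, targets, the partial-run flag,
  // and the debug-watch summary; for example, feeds {a:0, b:0}, fetch c:0,
  // no targets, and no debug watches yield roughly "a:0,b:0->c:0//0/".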
1475 const string key = strings::StrCat(
1476 absl::StrJoin(inputs, ","), "->", absl::StrJoin(outputs, ","), "/",
1477 absl::StrJoin(target_nodes, ","), "/", run_state_args->is_partial_run,
1478 "/", debug_tensor_watches_summary);
1479 // Set the handle, if it's needed to log memory or for partial run.
1480 if (handle_name_counter_value >= 0) {
1481 run_state_args->handle =
1482 strings::StrCat(key, ";", handle_name_counter_value);
1483 }
1484
1485 // See if we already have the executors for this run.
1486 {
1487 mutex_lock l(executor_lock_); // could use reader lock
1488 auto it = executors_.find(key);
1489 if (it != executors_.end()) {
1490 *executors_and_keys = it->second.get();
1491 return Status::OK();
1492 }
1493 }
1494
1495   // Slow lookup path: the unsorted key missed the cache.
1496 // Sort the inputs and outputs, and look up with the sorted key in case an
1497 // earlier call used a different order of inputs and outputs.
1498 //
1499 // We could consider some other signature instead of sorting that
1500 // preserves the same property to avoid the sort in the future.
1501 std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
1502 std::sort(inputs_sorted.begin(), inputs_sorted.end());
1503 std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
1504 std::sort(outputs_sorted.begin(), outputs_sorted.end());
1505 std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
1506 std::sort(tn_sorted.begin(), tn_sorted.end());
1507
1508 const string sorted_key = strings::StrCat(
1509 absl::StrJoin(inputs_sorted, ","), "->",
1510 absl::StrJoin(outputs_sorted, ","), "/", absl::StrJoin(tn_sorted, ","),
1511 "/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
1512   // Set the handle, if it's needed to log memory or for partial run.
1513 if (handle_name_counter_value >= 0) {
1514 run_state_args->handle =
1515 strings::StrCat(sorted_key, ";", handle_name_counter_value);
1516 }
1517
1518 // See if we already have the executors for this run.
1519 {
1520 mutex_lock l(executor_lock_);
1521 auto it = executors_.find(sorted_key);
1522 if (it != executors_.end()) {
1523 *executors_and_keys = it->second.get();
1524 return Status::OK();
1525 }
1526 }
1527
1528 // Nothing found, so create the executors and store in the cache.
1529 // The executor_lock_ is intentionally released while executors are
1530 // being created.
1531 CallableOptions callable_options;
1532 callable_options.mutable_feed()->Reserve(inputs_sorted.size());
1533 for (const string& input : inputs_sorted) {
1534 callable_options.add_feed(input);
1535 }
1536 callable_options.mutable_fetch()->Reserve(outputs_sorted.size());
1537 for (const string& output : outputs_sorted) {
1538 callable_options.add_fetch(output);
1539 }
1540 callable_options.mutable_target()->Reserve(tn_sorted.size());
1541 for (const string& target : tn_sorted) {
1542 callable_options.add_target(target);
1543 }
1544 *callable_options.mutable_run_options()->mutable_debug_options() =
1545 run_state_args->debug_options;
1546 callable_options.mutable_run_options()
1547 ->mutable_experimental()
1548 ->set_collective_graph_key(run_state_args->collective_graph_key);
1549 std::unique_ptr<ExecutorsAndKeys> ek;
1550 std::unique_ptr<FunctionInfo> func_info;
1551 TF_RETURN_IF_ERROR(
1552 CreateExecutors(callable_options, &ek, &func_info, run_state_args));
1553
1554 // Reacquire the lock, try to insert into the map.
1555 mutex_lock l(executor_lock_);
1556
1557 // Another thread may have created the entry before us, in which case we will
1558 // reuse the already created one.
1559 auto insert_result = executors_.emplace(
1560 sorted_key, std::shared_ptr<ExecutorsAndKeys>(std::move(ek)));
1561 if (insert_result.second) {
1562 functions_.push_back(std::move(func_info));
1563 }
1564
1565 // Insert the value under the original key, so the fast path lookup will work
1566 // if the user uses the same order of inputs, outputs, and targets again.
1567 executors_.emplace(key, insert_result.first->second);
1568 *executors_and_keys = insert_result.first->second.get();
1569
1570 return Status::OK();
1571 }
1572
1573 Status DirectSession::CreateGraphs(
1574 const BuildGraphOptions& subgraph_options,
1575 std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
1576 std::unique_ptr<FunctionLibraryDefinition>* flib_def,
1577 RunStateArgs* run_state_args, DataTypeVector* input_types,
1578 DataTypeVector* output_types, int64* collective_graph_key) {
1579 mutex_lock l(graph_state_lock_);
1580 if (finalized_) {
1581 return errors::FailedPrecondition("Session has been finalized.");
1582 }
1583
1584 std::unique_ptr<ClientGraph> client_graph;
1585
1586 std::unique_ptr<GraphExecutionState> temp_exec_state_holder;
1587 GraphExecutionState* execution_state = nullptr;
1588 if (options_.config.graph_options().place_pruned_graph()) {
1589 // Because we are placing pruned graphs, we need to create a
1590 // new GraphExecutionState for every new unseen graph,
1591 // and then place it.
1592 GraphExecutionStateOptions prune_options;
1593 prune_options.device_set = &device_set_;
1594 prune_options.session_options = &options_;
1595 prune_options.stateful_placements = stateful_placements_;
1596 prune_options.session_handle = session_handle_;
1597 TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph(
1598 *execution_state_, prune_options, subgraph_options,
1599 &temp_exec_state_holder, &client_graph));
1600 execution_state = temp_exec_state_holder.get();
1601 } else {
1602 execution_state = execution_state_.get();
1603 TF_RETURN_IF_ERROR(
1604 execution_state->BuildGraph(subgraph_options, &client_graph));
1605 }
1606 *collective_graph_key = client_graph->collective_graph_key;
1607
1608 if (subgraph_options.callable_options.feed_size() !=
1609 client_graph->feed_types.size()) {
1610 return errors::Internal(
1611 "Graph pruning failed: requested number of feed endpoints = ",
1612 subgraph_options.callable_options.feed_size(),
1613 " versus number of pruned feed endpoints = ",
1614 client_graph->feed_types.size());
1615 }
1616 if (subgraph_options.callable_options.fetch_size() !=
1617 client_graph->fetch_types.size()) {
1618 return errors::Internal(
1619 "Graph pruning failed: requested number of fetch endpoints = ",
1620 subgraph_options.callable_options.fetch_size(),
1621 " versus number of pruned fetch endpoints = ",
1622 client_graph->fetch_types.size());
1623 }
1624
1625 auto current_stateful_placements = execution_state->GetStatefulPlacements();
1626 // Update our current state based on the execution_state's
1627 // placements. If there are any mismatches for a node,
1628 // we should fail, as this should never happen.
1629 for (const auto& placement_pair : current_stateful_placements) {
1630 const string& node_name = placement_pair.first;
1631 const string& placement = placement_pair.second;
1632 auto iter = stateful_placements_.find(node_name);
1633 if (iter == stateful_placements_.end()) {
1634 stateful_placements_.insert(std::make_pair(node_name, placement));
1635 } else if (iter->second != placement) {
1636 return errors::Internal(
1637 "Stateful placement mismatch. "
1638 "Current assignment of ",
1639 node_name, " to ", iter->second, " does not match ", placement);
1640 }
1641 }
1642
1643 stateful_placements_ = execution_state->GetStatefulPlacements();
1644
1645 // Remember the graph in run state if this is a partial run.
1646 if (run_state_args->is_partial_run) {
1647 run_state_args->graph.reset(new Graph(flib_def_.get()));
1648 CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
1649 }
1650
1651 // Partition the graph across devices.
1652 PartitionOptions popts;
1653 popts.node_to_loc = [](const Node* node) {
1654 return node->assigned_device_name();
1655 };
1656 popts.new_name = [this](const string& prefix) {
1657 return strings::StrCat(prefix, "/_", edge_name_counter_.fetch_add(1));
1658 };
1659 popts.get_incarnation = [](const string& name) {
1660 // The direct session does not have changing incarnation numbers.
1661 // Just return '1'.
1662 return 1;
1663 };
1664 popts.flib_def = flib_def->get();
1665 popts.control_flow_added = false;
1666
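  // Partition() splits the client graph by assigned device name, inserting
  // paired _Send/_Recv nodes on every edge that crosses a device boundary;
  // those nodes are named via popts.new_name above.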
1667 std::unordered_map<string, GraphDef> partitions;
1668 TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions));
1669
1670 std::vector<string> device_names;
1671 for (auto device : devices_) {
1672 // Extract the LocalName from the device.
1673 device_names.push_back(DeviceNameUtils::LocalName(device->name()));
1674 }
1675
1676 // Check for valid partitions.
1677 for (const auto& partition : partitions) {
1678 const string local_partition_name =
1679 DeviceNameUtils::LocalName(partition.first);
1680 if (std::count(device_names.begin(), device_names.end(),
1681 local_partition_name) == 0) {
1682 return errors::InvalidArgument(
1683 "Creating a partition for ", local_partition_name,
1684 " which doesn't exist in the list of available devices. Available "
1685 "devices: ",
1686 absl::StrJoin(device_names, ","));
1687 }
1688 }
1689
1690 for (auto& partition : partitions) {
1691 std::unique_ptr<Graph> device_graph(
1692 new Graph(client_graph->flib_def.get()));
1693 device_graph->SetConstructionContext(ConstructionContext::kDirectSession);
1694 GraphConstructorOptions device_opts;
1695 // There are internal operations (e.g., send/recv) that we now allow.
1696 device_opts.allow_internal_ops = true;
1697 device_opts.expect_device_spec = true;
1698 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
1699 device_opts, std::move(partition.second), device_graph.get()));
1700 outputs->emplace(partition.first, std::move(device_graph));
1701 }
1702
1703 GraphOptimizationPassOptions optimization_options;
1704 optimization_options.session_options = &options_;
1705 optimization_options.flib_def = client_graph->flib_def.get();
1706 optimization_options.partition_graphs = outputs;
1707 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
1708 OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
1709
1710 Status s;
1711 for (auto& partition : *outputs) {
1712 const string& partition_name = partition.first;
1713 std::unique_ptr<Graph>* graph = &partition.second;
1714
1715 VLOG(2) << "Created " << DebugString(graph->get()) << " for "
1716 << partition_name;
1717
1718 // Give the device an opportunity to rewrite its subgraph.
1719 Device* d;
1720 s = device_mgr_->LookupDevice(partition_name, &d);
1721 if (!s.ok()) break;
1722 s = d->MaybeRewriteGraph(graph);
1723 if (!s.ok()) {
1724 break;
1725 }
1726 }
1727 *flib_def = std::move(client_graph->flib_def);
1728 std::swap(*input_types, client_graph->feed_types);
1729 std::swap(*output_types, client_graph->fetch_types);
1730 return s;
1731 }
1732
1733 ::tensorflow::Status DirectSession::ListDevices(
1734 std::vector<DeviceAttributes>* response) {
1735 response->clear();
1736 response->reserve(devices_.size());
1737 for (Device* d : devices_) {
1738 const DeviceAttributes& attrs = d->attributes();
1739 response->emplace_back(attrs);
1740 }
1741 return ::tensorflow::Status::OK();
1742 }
1743
1744 ::tensorflow::Status DirectSession::Reset(
1745 const std::vector<string>& containers) {
1746 device_mgr_->ClearContainers(containers);
1747 return ::tensorflow::Status::OK();
1748 }
1749
1750 ::tensorflow::Status DirectSession::Close() {
1751 cancellation_manager_->StartCancel();
1752 {
1753 mutex_lock l(closed_lock_);
1754 if (closed_) return ::tensorflow::Status::OK();
1755 closed_ = true;
1756 }
1757 if (factory_ != nullptr) factory_->Deregister(this);
1758 return ::tensorflow::Status::OK();
1759 }
1760
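// The ScopedStepContainer cleanup function below runs when the step
// finishes: it clears each device's per-step resource container and any
// scoped-allocator state for the step.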
1761 DirectSession::RunState::RunState(int64_t step_id,
1762 const std::vector<Device*>* devices)
1763 : step_container(step_id, [devices, step_id](const string& name) {
1764 for (auto d : *devices) {
1765 if (!d->resource_manager()->Cleanup(name).ok()) {
1766 // Do nothing...
1767 }
1768 ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
1769 if (sam) sam->Cleanup(step_id);
1770 }
1771 }) {}
1772
1773 DirectSession::PartialRunState::PartialRunState(
1774 const std::vector<string>& pending_input_names,
1775 const std::vector<string>& pending_output_names, int64_t step_id,
1776 const std::vector<Device*>* devices)
1777 : RunState(step_id, devices) {
1778 // Initially all the feeds and fetches are pending.
1779 for (auto& name : pending_input_names) {
1780 pending_inputs[name] = false;
1781 }
1782 for (auto& name : pending_output_names) {
1783 pending_outputs[name] = false;
1784 }
1785 }
1786
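// Aborting the rendezvous unblocks any Send/Recv operations that are still
// pending for this partial run, so the executors can finish before the
// per-run state is destroyed.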
1787 DirectSession::PartialRunState::~PartialRunState() {
1788 if (rendez != nullptr) {
1789 rendez->StartAbort(errors::Cancelled("PRun cancellation"));
1790 executors_done.WaitForNotification();
1791 }
1792 }
1793
1794 bool DirectSession::PartialRunState::PendingDone() const {
1795 for (const auto& it : pending_inputs) {
1796 if (!it.second) return false;
1797 }
1798 for (const auto& it : pending_outputs) {
1799 if (!it.second) return false;
1800 }
1801 return true;
1802 }
1803
1804 void DirectSession::WaitForNotification(Notification* n, RunState* run_state,
1805 CancellationManager* cm,
1806 int64_t timeout_in_ms) {
1807 const Status status = WaitForNotification(n, timeout_in_ms);
1808 if (!status.ok()) {
1809 {
1810 mutex_lock l(run_state->mu);
1811 run_state->status.Update(status);
1812 }
1813 cm->StartCancel();
1814 // We must wait for the executors to complete, because they have borrowed
1815 // references to `cm` and other per-step state. After this notification, it
1816 // is safe to clean up the step.
1817 n->WaitForNotification();
1818 }
1819 }
1820
1821 ::tensorflow::Status DirectSession::WaitForNotification(
1822 Notification* notification, int64_t timeout_in_ms) {
1823 if (timeout_in_ms > 0) {
1824 const int64_t timeout_in_us = timeout_in_ms * 1000;
1825 const bool notified =
1826 WaitForNotificationWithTimeout(notification, timeout_in_us);
1827 if (!notified) {
1828 return Status(error::DEADLINE_EXCEEDED,
1829 "Timed out waiting for notification");
1830 }
1831 } else {
1832 notification->WaitForNotification();
1833 }
1834 return Status::OK();
1835 }
1836
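// A typical client-side use of the callable API looks roughly like the
// following sketch (`opts`, `x_tensor`, and `session` are placeholders):
//
//   Session::CallableHandle handle;
//   TF_RETURN_IF_ERROR(session->MakeCallable(opts, &handle));
//   std::vector<Tensor> fetches;
//   TF_RETURN_IF_ERROR(session->RunCallable(handle, {x_tensor}, &fetches,
//                                           /*run_metadata=*/nullptr));
//   TF_RETURN_IF_ERROR(session->ReleaseCallable(handle));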
1837 Status DirectSession::MakeCallable(const CallableOptions& callable_options,
1838 CallableHandle* out_handle) {
1839 TF_RETURN_IF_ERROR(CheckNotClosed());
1840 TF_RETURN_IF_ERROR(CheckGraphCreated("MakeCallable()"));
1841
1842 std::unique_ptr<ExecutorsAndKeys> ek;
1843 std::unique_ptr<FunctionInfo> func_info;
1844 RunStateArgs run_state_args(callable_options.run_options().debug_options());
1845 TF_RETURN_IF_ERROR(
1846 CreateExecutors(callable_options, &ek, &func_info, &run_state_args));
1847 {
1848 mutex_lock l(callables_lock_);
1849 *out_handle = next_callable_handle_++;
1850 callables_[*out_handle] = {std::move(ek), std::move(func_info)};
1851 }
1852 return Status::OK();
1853 }
1854
1855 class DirectSession::RunCallableCallFrame : public CallFrameInterface {
1856 public:
1857   RunCallableCallFrame(DirectSession* session,
1858 ExecutorsAndKeys* executors_and_keys,
1859 const std::vector<Tensor>* feed_tensors,
1860 std::vector<Tensor>* fetch_tensors)
1861 : session_(session),
1862 executors_and_keys_(executors_and_keys),
1863 feed_tensors_(feed_tensors),
1864 fetch_tensors_(fetch_tensors) {}
1865
1866   size_t num_args() const override {
1867 return executors_and_keys_->input_types.size();
1868 }
1869   size_t num_retvals() const override {
1870 return executors_and_keys_->output_types.size();
1871 }
1872
1873   Status GetArg(int index, const Tensor** val) override {
1874     if (TF_PREDICT_FALSE(index >= feed_tensors_->size())) {
1875 return errors::Internal("Args index out of bounds: ", index);
1876 } else {
1877 *val = &(*feed_tensors_)[index];
1878 }
1879 return Status::OK();
1880 }
1881
1882   Status SetRetval(int index, const Tensor& val) override {
1883     if (index >= fetch_tensors_->size()) {
1884 return errors::Internal("RetVal index out of bounds: ", index);
1885 }
1886 (*fetch_tensors_)[index] = val;
1887 return Status::OK();
1888 }
1889
1890 private:
1891 DirectSession* const session_; // Not owned.
1892 ExecutorsAndKeys* const executors_and_keys_; // Not owned.
1893 const std::vector<Tensor>* const feed_tensors_; // Not owned.
1894 std::vector<Tensor>* const fetch_tensors_; // Not owned.
1895 };
1896
1897 ::tensorflow::Status DirectSession::RunCallable(
1898 CallableHandle handle, const std::vector<Tensor>& feed_tensors,
1899 std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata) {
1900 return RunCallable(handle, feed_tensors, fetch_tensors, run_metadata,
1901 thread::ThreadPoolOptions());
1902 }
1903
1904 ::tensorflow::Status DirectSession::RunCallable(
1905 CallableHandle handle, const std::vector<Tensor>& feed_tensors,
1906 std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata,
1907 const thread::ThreadPoolOptions& threadpool_options) {
1908 TF_RETURN_IF_ERROR(CheckNotClosed());
1909 TF_RETURN_IF_ERROR(CheckGraphCreated("RunCallable()"));
1910 direct_session_runs->GetCell()->IncrementBy(1);
1911
1912 // Check if we already have an executor for these arguments.
1913 std::shared_ptr<ExecutorsAndKeys> executors_and_keys;
1914 const int64_t step_id = step_id_counter_.fetch_add(1);
1915
1916 {
1917 tf_shared_lock l(callables_lock_);
1918 if (handle >= next_callable_handle_) {
1919 return errors::InvalidArgument("No such callable handle: ", handle);
1920 }
1921 executors_and_keys = callables_[handle].executors_and_keys;
1922 }
1923
1924 if (!executors_and_keys) {
1925 return errors::InvalidArgument(
1926 "Attempted to run callable after handle was released: ", handle);
1927 }
1928
1929 // NOTE(mrry): Debug options are not currently supported in the
1930 // callable interface.
1931 DebugOptions debug_options;
1932 RunStateArgs run_state_args(debug_options);
1933
1934 // Configure a call frame for the step, which we use to feed and
1935 // fetch values to and from the executors.
1936 if (feed_tensors.size() != executors_and_keys->input_types.size()) {
1937 return errors::InvalidArgument(
1938 "Expected ", executors_and_keys->input_types.size(),
1939 " feed tensors, but got ", feed_tensors.size());
1940 }
1941 if (fetch_tensors != nullptr) {
1942 fetch_tensors->resize(executors_and_keys->output_types.size());
1943 } else if (!executors_and_keys->output_types.empty()) {
1944 return errors::InvalidArgument(
1945 "`fetch_tensors` must be provided when the callable has one or more "
1946 "outputs.");
1947 }
1948
1949 size_t input_size = 0;
1950 bool any_resource_feeds = false;
1951 for (auto& tensor : feed_tensors) {
1952 input_size += tensor.AllocatedBytes();
1953 any_resource_feeds = any_resource_feeds || tensor.dtype() == DT_RESOURCE;
1954 }
1955 metrics::RecordGraphInputTensors(input_size);
1956
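  // If any feed is a DT_RESOURCE handle, convert it back into the tensor it
  // refers to (via ResourceHandleToInputTensor) before building the call
  // frame; otherwise the feeds are passed through without copying.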
1957 std::unique_ptr<std::vector<Tensor>> converted_feed_tensors;
1958 const std::vector<Tensor>* actual_feed_tensors;
1959
1960 if (TF_PREDICT_FALSE(any_resource_feeds)) {
1961 converted_feed_tensors = absl::make_unique<std::vector<Tensor>>();
1962 converted_feed_tensors->reserve(feed_tensors.size());
1963 for (const Tensor& t : feed_tensors) {
1964 if (t.dtype() == DT_RESOURCE) {
1965 converted_feed_tensors->emplace_back();
1966 Tensor* tensor_from_handle = &converted_feed_tensors->back();
1967 TF_RETURN_IF_ERROR(ResourceHandleToInputTensor(t, tensor_from_handle));
1968 } else {
1969 converted_feed_tensors->emplace_back(t);
1970 }
1971 }
1972 actual_feed_tensors = converted_feed_tensors.get();
1973 } else {
1974 actual_feed_tensors = &feed_tensors;
1975 }
1976
1977 // A specialized CallFrame implementation that takes advantage of the
1978 // optimized RunCallable interface.
1979 RunCallableCallFrame call_frame(this, executors_and_keys.get(),
1980 actual_feed_tensors, fetch_tensors);
1981
1982 if (LogMemory::IsEnabled()) {
1983 LogMemory::RecordStep(step_id, run_state_args.handle);
1984 }
1985
1986 TF_RETURN_IF_ERROR(RunInternal(
1987 step_id, executors_and_keys->callable_options.run_options(), &call_frame,
1988 executors_and_keys.get(), run_metadata, threadpool_options));
1989
1990 if (fetch_tensors != nullptr) {
1991 size_t output_size = 0;
1992 for (auto& tensor : *fetch_tensors) {
1993 output_size += tensor.AllocatedBytes();
1994 }
1995 metrics::RecordGraphOutputTensors(output_size);
1996 }
1997
1998 return Status::OK();
1999 }
2000
2001 ::tensorflow::Status DirectSession::ReleaseCallable(CallableHandle handle) {
2002 mutex_lock l(callables_lock_);
2003 if (handle >= next_callable_handle_) {
2004 return errors::InvalidArgument("No such callable handle: ", handle);
2005 }
2006 callables_.erase(handle);
2007 return Status::OK();
2008 }
2009
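// Finalize() releases the state that is only needed to build new graphs: the
// GraphExecutionState and the session-owned FunctionLibraryDefinition.
// Executors that were already created keep working; attempts to create new
// ones fail afterwards.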
2010 Status DirectSession::Finalize() {
2011 mutex_lock l(graph_state_lock_);
2012 if (finalized_) {
2013 return errors::FailedPrecondition("Session already finalized.");
2014 }
2015 if (!graph_created_) {
2016 return errors::FailedPrecondition("Session not yet created.");
2017 }
2018 execution_state_.reset();
2019 flib_def_.reset();
2020 finalized_ = true;
2021 return Status::OK();
2022 }
2023
2024 DirectSession::Callable::~Callable() {
2025 // We must delete the fields in this order, because the destructor
2026 // of `executors_and_keys` will call into an object owned by
2027 // `function_info` (in particular, when deleting a kernel, it relies
2028 // on the `FunctionLibraryRuntime` to know if the kernel is stateful
2029 // or not).
2030 executors_and_keys.reset();
2031 function_info.reset();
2032 }
2033
2034 } // namespace tensorflow
2035