/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/direct_session.h"

#include <algorithm>
#include <atomic>
#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_resolver_local.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/logging.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/run_handler.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/core/threadpool_options.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/nccl/collective_communicator.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/profiler_session.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {

namespace {

auto* direct_session_runs = monitoring::Counter<0>::New(
    "/tensorflow/core/direct_session_runs",
    "The number of times DirectSession::Run() has been called.");

Status NewThreadPoolFromThreadPoolOptions(
    const SessionOptions& options,
    const ThreadPoolOptionProto& thread_pool_options, int pool_number,
    thread::ThreadPool** pool, bool* owned) {
  int32 num_threads = thread_pool_options.num_threads();
  if (num_threads == 0) {
    num_threads = NumInterOpThreadsFromSessionOptions(options);
  }
  const string& name = thread_pool_options.global_name();
  if (name.empty()) {
    // Session-local threadpool.
    VLOG(1) << "Direct session inter op parallelism threads for pool "
            << pool_number << ": " << num_threads;
    *pool = new thread::ThreadPool(
        options.env, ThreadOptions(), strings::StrCat("Compute", pool_number),
        num_threads, !options.config.experimental().disable_thread_spinning(),
        /*allocator=*/nullptr);
    *owned = true;
    return Status::OK();
  }

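  // Pools with a non-empty global_name are created once and then shared by
  // every session that refers to the same name; asking for a different
  // num_threads on a later configuration is rejected below.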
  // Global, named threadpool.
  typedef std::pair<int32, thread::ThreadPool*> MapValue;
  static std::map<string, MapValue>* global_pool_map =
      new std::map<string, MapValue>;
  static mutex* mu = new mutex();
  mutex_lock l(*mu);
  MapValue* mvalue = &(*global_pool_map)[name];
  if (mvalue->second == nullptr) {
    mvalue->first = thread_pool_options.num_threads();
    mvalue->second = new thread::ThreadPool(
        options.env, ThreadOptions(), strings::StrCat("Compute", pool_number),
        num_threads, !options.config.experimental().disable_thread_spinning(),
        /*allocator=*/nullptr);
  } else {
    if (mvalue->first != thread_pool_options.num_threads()) {
      return errors::InvalidArgument(
          "Pool ", name,
          " configured previously with num_threads=", mvalue->first,
          "; cannot re-configure with num_threads=",
          thread_pool_options.num_threads());
    }
  }
  *owned = false;
  *pool = mvalue->second;
  return Status::OK();
}

thread::ThreadPool* GlobalThreadPool(const SessionOptions& options) {
  static thread::ThreadPool* const thread_pool =
      NewThreadPoolFromSessionOptions(options);
  return thread_pool;
}

// TODO(vrv): Figure out how to unify the many different functions
// that generate RendezvousKey, since many of them have to be
// consistent with each other.
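// The key produced below has the form:
//   "<device name>;<incarnation>;<device name>;<tensor name>;<frame_id>:<iter_id>"
// e.g. (values illustrative only)
//   "/job:localhost/replica:0/task:0/device:CPU:0;1;"
//   "/job:localhost/replica:0/task:0/device:CPU:0;x:0;0:0".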
string GetRendezvousKey(const string& tensor_name,
                        const DeviceAttributes& device_info,
                        const FrameAndIter& frame_iter) {
  return strings::StrCat(device_info.name(), ";",
                         strings::FpToString(device_info.incarnation()), ";",
                         device_info.name(), ";", tensor_name, ";",
                         frame_iter.frame_id, ":", frame_iter.iter_id);
}

}  // namespace

class DirectSessionFactory : public SessionFactory {
 public:
  DirectSessionFactory() {}

  bool AcceptsOptions(const SessionOptions& options) override {
    return options.target.empty();
  }

  Status NewSession(const SessionOptions& options,
                    Session** out_session) override {
    const auto& experimental_config = options.config.experimental();
    if (experimental_config.has_session_metadata()) {
      if (experimental_config.session_metadata().version() < 0) {
        return errors::InvalidArgument(
            "Session version shouldn't be negative: ",
            experimental_config.session_metadata().DebugString());
      }
      const string key = GetMetadataKey(experimental_config.session_metadata());
      mutex_lock l(sessions_lock_);
      if (!session_metadata_keys_.insert(key).second) {
        return errors::InvalidArgument(
            "A session with the same name and version has already been "
            "created: ",
            experimental_config.session_metadata().DebugString());
      }
    }

    // Must do this before the CPU allocator is created.
    if (options.config.graph_options().build_cost_model() > 0) {
      EnableCPUAllocatorFullStats();
    }
    std::vector<std::unique_ptr<Device>> devices;
    TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
        options, "/job:localhost/replica:0/task:0", &devices));

    DirectSession* session = new DirectSession(
        options, new StaticDeviceMgr(std::move(devices)), this);
    {
      mutex_lock l(sessions_lock_);
      sessions_.push_back(session);
    }
    *out_session = session;
    return Status::OK();
  }

  Status Reset(const SessionOptions& options,
               const std::vector<string>& containers) override {
    std::vector<DirectSession*> sessions_to_reset;
    {
      mutex_lock l(sessions_lock_);
      // We create a copy to ensure that we don't have a deadlock when
      // session->Close calls DirectSessionFactory::Deregister, which
      // acquires sessions_lock_.
      std::swap(sessions_to_reset, sessions_);
    }
    Status s;
    for (auto session : sessions_to_reset) {
      s.Update(session->Reset(containers));
    }
    // TODO(suharshs): Change the Reset behavior of all SessionFactories so
    // that it doesn't close the sessions?
    for (auto session : sessions_to_reset) {
      s.Update(session->Close());
    }
    return s;
  }

  void Deregister(const DirectSession* session) {
    mutex_lock l(sessions_lock_);
    sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session),
                    sessions_.end());
    if (session->options().config.experimental().has_session_metadata()) {
      session_metadata_keys_.erase(GetMetadataKey(
          session->options().config.experimental().session_metadata()));
    }
  }

 private:
  static string GetMetadataKey(const SessionMetadata& metadata) {
    return absl::StrCat(metadata.name(), "/", metadata.version());
  }

  mutex sessions_lock_;
  std::vector<DirectSession*> sessions_ TF_GUARDED_BY(sessions_lock_);
  absl::flat_hash_set<string> session_metadata_keys_
      TF_GUARDED_BY(sessions_lock_);
};

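// Registered at static-initialization time under "DIRECT_SESSION"; any
// SessionOptions whose `target` is empty is served by this factory. A typical
// client looks like (illustrative only; graph_def, feeds, fetches, targets,
// and outputs are caller-provided):
//
//   SessionOptions opts;                        // empty target
//   std::unique_ptr<Session> sess(NewSession(opts));
//   TF_CHECK_OK(sess->Create(graph_def));
//   TF_CHECK_OK(sess->Run(feeds, fetches, targets, &outputs));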
class DirectSessionRegistrar {
 public:
  DirectSessionRegistrar() {
    SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
  }
};
static DirectSessionRegistrar registrar;

std::atomic_int_fast64_t DirectSession::step_id_counter_(1);

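// The RunHandlerPool below is a lazily created, process-wide singleton: the
// thread counts computed on the first call are fixed for the lifetime of the
// process, regardless of the options passed by later sessions.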
static RunHandlerPool* GetOrCreateRunHandlerPool(
    const SessionOptions& options) {
  int num_inter_threads = 0;
  int num_intra_threads = 0;
  static const int env_num_inter_threads = NumInterOpThreadsFromEnvironment();
  static const int env_num_intra_threads = NumIntraOpThreadsFromEnvironment();
  if (env_num_inter_threads > 0) {
    num_inter_threads = env_num_inter_threads;
  }
  if (env_num_intra_threads > 0) {
    num_intra_threads = env_num_intra_threads;
  }

  if (num_inter_threads == 0) {
    if (options.config.session_inter_op_thread_pool_size() > 0) {
      // Note due to ShouldUseRunHandler we are guaranteed that
      // run_options.inter_op_thread_pool() == 0.
      num_inter_threads =
          options.config.session_inter_op_thread_pool(0).num_threads();
    }
    if (num_inter_threads == 0) {
      num_inter_threads = NumInterOpThreadsFromSessionOptions(options);
    }
  }

  if (num_intra_threads == 0) {
    num_intra_threads = options.config.intra_op_parallelism_threads();
    if (num_intra_threads == 0) {
      num_intra_threads = port::MaxParallelism();
    }
  }

  static RunHandlerPool* pool =
      new RunHandlerPool(num_inter_threads, num_intra_threads);
  return pool;
}

bool DirectSession::ShouldUseRunHandlerPool(
    const RunOptions& run_options) const {
  if (options_.config.use_per_session_threads()) return false;
  if (options_.config.session_inter_op_thread_pool_size() > 0 &&
      run_options.inter_op_thread_pool() > 0)
    return false;
  // Only use RunHandlerPool when:
  // a. A single global thread pool is used for inter-op parallelism.
  // b. When multiple inter_op_thread_pool(s) are created, use it only while
  //    running sessions on the default inter_op_thread_pool=0. Typically,
  //    the servo team uses inter_op_thread_pool > 0 for model loading.
  // TODO(crk): Revisit whether we'd want to create one (static) RunHandlerPool
  // per entry in session_inter_op_thread_pool() in the future.
  return true;
}

DirectSession::DirectSession(const SessionOptions& options,
                             const DeviceMgr* device_mgr,
                             DirectSessionFactory* const factory)
    : options_(options),
      device_mgr_(device_mgr),
      factory_(factory),
      cancellation_manager_(new CancellationManager()),
      operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
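  // Inter-op thread-pool selection: explicit session_inter_op_thread_pool
  // entries take precedence, then use_per_session_threads, and otherwise the
  // process-wide GlobalThreadPool is shared.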
  const int thread_pool_size =
      options_.config.session_inter_op_thread_pool_size();
  if (thread_pool_size > 0) {
    for (int i = 0; i < thread_pool_size; ++i) {
      thread::ThreadPool* pool = nullptr;
      bool owned = false;
      init_error_.Update(NewThreadPoolFromThreadPoolOptions(
          options_, options_.config.session_inter_op_thread_pool(i), i, &pool,
          &owned));
      thread_pools_.emplace_back(pool, owned);
    }
  } else if (options_.config.use_per_session_threads()) {
    thread_pools_.emplace_back(NewThreadPoolFromSessionOptions(options_),
                               true /* owned */);
  } else {
    thread_pools_.emplace_back(GlobalThreadPool(options), false /* owned */);
    // Run locally if environment value of TF_NUM_INTEROP_THREADS is negative
    // and config.inter_op_parallelism_threads is unspecified or negative.
    static const int env_num_threads = NumInterOpThreadsFromEnvironment();
    if (options_.config.inter_op_parallelism_threads() < 0 ||
        (options_.config.inter_op_parallelism_threads() == 0 &&
         env_num_threads < 0)) {
      run_in_caller_thread_ = true;
    }
  }
  // The default value of sync_on_finish will be flipped soon and this
  // environment variable will be removed as well.
  const Status status =
      ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
  if (!status.ok()) {
    LOG(ERROR) << status.error_message();
  }
  session_handle_ =
      strings::StrCat("direct", strings::FpToString(random::New64()));
  int devices_added = 0;
  if (options.config.log_device_placement()) {
    const string mapping_str = device_mgr_->DeviceMappingString();
    string msg;
    if (mapping_str.empty()) {
      msg = "Device mapping: no known devices.";
    } else {
      msg = strings::StrCat("Device mapping:\n", mapping_str);
    }
    if (!logging::LogToListeners(msg)) {
      LOG(INFO) << msg;
    }
  }
  for (auto d : device_mgr_->ListDevices()) {
    devices_.push_back(d);
    device_set_.AddDevice(d);
    d->op_segment()->AddHold(session_handle_);

    // The first device added is special: it is the 'client device' (a
    // CPU device) from which we feed and fetch Tensors.
    if (devices_added == 0) {
      device_set_.set_client_device(d);
    }
    ++devices_added;
  }
}

DirectSession::~DirectSession() {
  if (!closed_) Close().IgnoreError();
  for (auto& it : partial_runs_) {
    it.second.reset(nullptr);
  }
  for (auto& it : executors_) {
    it.second.reset();
  }
  callables_.clear();
  for (auto d : device_mgr_->ListDevices()) {
    d->op_segment()->RemoveHold(session_handle_);
  }
  functions_.clear();
  delete cancellation_manager_;
  for (const auto& p_and_owned : thread_pools_) {
    if (p_and_owned.second) delete p_and_owned.first;
  }

  execution_state_.reset(nullptr);
  flib_def_.reset(nullptr);
}

Status DirectSession::Create(const GraphDef& graph) {
  return Create(GraphDef(graph));
}

Status DirectSession::Create(GraphDef&& graph) {
  TF_RETURN_IF_ERROR(init_error_);
  if (graph.node_size() > 0) {
    mutex_lock l(graph_state_lock_);
    if (graph_created_) {
      return errors::AlreadyExists(
          "A Graph has already been created for this session.");
    }
    return ExtendLocked(std::move(graph));
  }
  return Status::OK();
}

Status DirectSession::Extend(const GraphDef& graph) {
  return Extend(GraphDef(graph));
}

Status DirectSession::Extend(GraphDef&& graph) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  mutex_lock l(graph_state_lock_);
  return ExtendLocked(std::move(graph));
}

Status DirectSession::ExtendLocked(GraphDef graph) {
  if (finalized_) {
    return errors::FailedPrecondition("Session has been finalized.");
  }
  if (!(flib_def_ && execution_state_)) {
    // If this is the first call, we can initialize the execution state
    // with `graph` and do not need to call `Extend()`.
    // NOTE(mrry): The function library created here will be used for
    // all subsequent extensions of the graph.
    flib_def_.reset(
        new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
    GraphExecutionStateOptions options;
    options.device_set = &device_set_;
    options.session_options = &options_;
    options.session_handle = session_handle_;
    TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
        std::move(graph), options, &execution_state_));
    graph_created_ = true;
  } else {
    TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));
    std::unique_ptr<GraphExecutionState> state;
    // TODO(mrry): Rewrite GraphExecutionState::Extend() to take `graph` by
    // value and move `graph` in here.
    TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
    execution_state_.swap(state);
  }
  return Status::OK();
}

Status DirectSession::Run(const NamedTensorList& inputs,
                          const std::vector<string>& output_names,
                          const std::vector<string>& target_nodes,
                          std::vector<Tensor>* outputs) {
  RunMetadata run_metadata;
  return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
             &run_metadata);
}

Status DirectSession::CreateDebuggerState(
    const CallableOptions& callable_options, int64 global_step,
    int64 session_run_index, int64 executor_step_index,
    std::unique_ptr<DebuggerStateInterface>* debugger_state) {
  TF_RETURN_IF_ERROR(DebuggerStateRegistry::CreateState(
      callable_options.run_options().debug_options(), debugger_state));
  std::vector<string> input_names(callable_options.feed().begin(),
                                  callable_options.feed().end());
  std::vector<string> output_names(callable_options.fetch().begin(),
                                   callable_options.fetch().end());
  std::vector<string> target_names(callable_options.target().begin(),
                                   callable_options.target().end());

  TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
      global_step, session_run_index, executor_step_index, input_names,
      output_names, target_names));
  return Status::OK();
}

Status DirectSession::DecorateAndPublishGraphForDebug(
    const DebugOptions& debug_options, Graph* graph, Device* device) {
  std::unique_ptr<DebugGraphDecoratorInterface> decorator;
  TF_RETURN_IF_ERROR(
      DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));

  TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
  TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
  return Status::OK();
}

Status DirectSession::RunInternal(
    int64 step_id, const RunOptions& run_options,
    CallFrameInterface* call_frame, ExecutorsAndKeys* executors_and_keys,
    RunMetadata* run_metadata,
    const thread::ThreadPoolOptions& threadpool_options) {
  const uint64 start_time_usecs = options_.env->NowMicros();
  const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1);
  RunState run_state(step_id, &devices_);
  const size_t num_executors = executors_and_keys->items.size();

  profiler::TraceMeProducer activity(
      // To TraceMeConsumers in ExecutorState::Process/Finish.
      [&] {
        if (options_.config.experimental().has_session_metadata()) {
          const auto& model_metadata =
              options_.config.experimental().session_metadata();
          string model_id = strings::StrCat(model_metadata.name(), ":",
                                            model_metadata.version());
          return profiler::TraceMeEncode("SessionRun",
                                         {{"id", step_id},
                                          {"_r", 1} /*root_event*/,
                                          {"model_id", model_id}});
        } else {
          return profiler::TraceMeEncode(
              "SessionRun", {{"id", step_id}, {"_r", 1} /*root_event*/});
        }
      },
      profiler::ContextType::kTfExecutor, step_id,
      profiler::TraceMeLevel::kInfo);

  std::unique_ptr<DebuggerStateInterface> debugger_state;
  if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
    TF_RETURN_IF_ERROR(
        CreateDebuggerState(executors_and_keys->callable_options,
                            run_options.debug_options().global_step(), step_id,
                            executor_step_count, &debugger_state));
  }

#ifndef __ANDROID__
  // Set up for collectives if ExecutorsAndKeys declares a key.
  if (executors_and_keys->collective_graph_key !=
      BuildGraphOptions::kNoCollectiveGraphKey) {
    if (run_options.experimental().collective_graph_key() !=
        BuildGraphOptions::kNoCollectiveGraphKey) {
      // If a collective_graph_key was specified in run_options, ensure that it
      // matches what came out of GraphExecutionState::BuildGraph().
      if (run_options.experimental().collective_graph_key() !=
          executors_and_keys->collective_graph_key) {
        return errors::Internal(
            "collective_graph_key in RunOptions ",
            run_options.experimental().collective_graph_key(),
            " should match collective_graph_key from optimized graph ",
            executors_and_keys->collective_graph_key);
      }
    }
    if (!collective_executor_mgr_) {
      std::unique_ptr<DeviceResolverInterface> drl(
          new DeviceResolverLocal(device_mgr_.get()));
      std::unique_ptr<ParamResolverInterface> cprl(
          new CollectiveParamResolverLocal(options_.config, device_mgr_.get(),
                                           drl.get(),
                                           "/job:localhost/replica:0/task:0"));
      collective_executor_mgr_.reset(new CollectiveExecutorMgr(
          options_.config, device_mgr_.get(), std::move(drl), std::move(cprl),
          MaybeCreateNcclCommunicator()));
    }
    run_state.collective_executor.reset(new CollectiveExecutor::Handle(
        collective_executor_mgr_->FindOrCreate(step_id), true /*inherit_ref*/));
  }
#endif

  thread::ThreadPool* pool;
  // Use std::unique_ptr so the wrapper around an externally provided
  // threadpool is destroyed automatically when this function returns.
  std::unique_ptr<thread::ThreadPool> threadpool_wrapper;

  const bool inline_execution_requested =
      run_in_caller_thread_ || run_options.inter_op_thread_pool() == -1;

  if (inline_execution_requested) {
    // We allow using the caller thread only when having a single executor
    // specified.
    if (executors_and_keys->items.size() > 1) {
      pool = thread_pools_[0].first;
    } else {
      VLOG(1) << "Executing Session::Run() synchronously!";
      pool = nullptr;
    }
  } else if (threadpool_options.inter_op_threadpool != nullptr) {
    threadpool_wrapper = absl::make_unique<thread::ThreadPool>(
        threadpool_options.inter_op_threadpool);
    pool = threadpool_wrapper.get();
  } else {
    if (run_options.inter_op_thread_pool() < -1 ||
        run_options.inter_op_thread_pool() >=
            static_cast<int32>(thread_pools_.size())) {
      return errors::InvalidArgument("Invalid inter_op_thread_pool: ",
                                     run_options.inter_op_thread_pool());
    }

    pool = thread_pools_[run_options.inter_op_thread_pool()].first;
  }

  const int64 call_timeout = run_options.timeout_in_ms() > 0
                                 ? run_options.timeout_in_ms()
                                 : operation_timeout_in_ms_;

  std::unique_ptr<RunHandler> handler;
  if (ShouldUseRunHandlerPool(run_options) &&
      run_options.experimental().use_run_handler_pool()) {
    VLOG(1) << "Using RunHandler to schedule inter-op closures.";
    handler = GetOrCreateRunHandlerPool(options_)->Get(
        step_id, call_timeout,
        run_options.experimental().run_handler_pool_options());
    if (!handler) {
      return errors::DeadlineExceeded(
          "Could not obtain RunHandler for request after waiting for ",
          call_timeout, "ms.");
    }
  }
  auto* handler_ptr = handler.get();

  Executor::Args::Runner default_runner = nullptr;

  if (pool == nullptr) {
    default_runner = [](const Executor::Args::Closure& c) { c(); };
  } else if (handler_ptr != nullptr) {
    default_runner = [handler_ptr](Executor::Args::Closure c) {
      handler_ptr->ScheduleInterOpClosure(std::move(c));
    };
  } else {
    default_runner = [pool](Executor::Args::Closure c) {
      pool->Schedule(std::move(c));
    };
  }

  // Start parallel Executors.

  // We can execute this step synchronously on the calling thread whenever
  // there is a single device and the timeout mechanism is not used.
  //
  // When timeouts are used, we must execute the graph(s) asynchronously, in
  // order to invoke the cancellation manager on the calling thread if the
  // timeout expires.
  const bool can_execute_synchronously =
      executors_and_keys->items.size() == 1 && call_timeout == 0;

  Executor::Args args;
  args.step_id = step_id;
  args.call_frame = call_frame;
  args.collective_executor =
      (run_state.collective_executor ? run_state.collective_executor->get()
                                     : nullptr);
  args.session_state = &session_state_;
  args.session_handle = session_handle_;
  args.tensor_store = &run_state.tensor_store;
  args.step_container = &run_state.step_container;
  args.sync_on_finish = sync_on_finish_;
  args.user_intra_op_threadpool = threadpool_options.intra_op_threadpool;
  args.run_all_kernels_inline = pool == nullptr;

  const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);

  bool update_cost_model = false;
  if (options_.config.graph_options().build_cost_model() > 0) {
    const int64 build_cost_model_every =
        options_.config.graph_options().build_cost_model();
    const int64 build_cost_model_after =
        options_.config.graph_options().build_cost_model_after();
    int64 measure_step_count = executor_step_count - build_cost_model_after;
    if (measure_step_count >= 0) {
      update_cost_model =
          ((measure_step_count + 1) % build_cost_model_every == 0);
    }
  }
  if (do_trace || update_cost_model ||
      run_options.report_tensor_allocations_upon_oom()) {
    run_state.collector.reset(
        new StepStatsCollector(run_metadata->mutable_step_stats()));
    args.stats_collector = run_state.collector.get();
  }

  std::unique_ptr<ProfilerSession> profiler_session;
  if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
    ProfileOptions options = ProfilerSession::DefaultOptions();
    options.set_host_tracer_level(0);
    profiler_session = ProfilerSession::Create(options);
  }

  // Register this step with session's cancellation manager, so that
  // `Session::Close()` will cancel the step.
  CancellationManager step_cancellation_manager(cancellation_manager_);
  if (step_cancellation_manager.IsCancelled()) {
    return errors::Cancelled("Run call was cancelled");
  }
  args.cancellation_manager = &step_cancellation_manager;

  Status run_status;

  auto set_threadpool_args_for_item =
      [&default_runner, &handler](const PerPartitionExecutorsAndLib& item,
                                  Executor::Args* args) {
        // TODO(azaks): support partial run.
        // TODO(azaks): if the device picks its own threadpool, we need to
        // assign fewer threads to the main compute pool by default.
        thread::ThreadPool* device_thread_pool =
            item.device->tensorflow_device_thread_pool();
        // TODO(crk): Investigate usage of RunHandlerPool when using device
        // specific thread pool(s).
        if (!device_thread_pool) {
          args->runner = default_runner;
        } else {
          args->runner = [device_thread_pool](Executor::Args::Closure c) {
            device_thread_pool->Schedule(std::move(c));
          };
        }
        if (handler != nullptr) {
          args->user_intra_op_threadpool =
              handler->AsIntraThreadPoolInterface();
        }
      };

  if (can_execute_synchronously) {
    PrivateIntraProcessRendezvous rendezvous(device_mgr_.get());
    args.rendezvous = &rendezvous;

    const auto& item = executors_and_keys->items[0];
    set_threadpool_args_for_item(item, &args);
    run_status = item.executor->Run(args);
  } else {
    core::RefCountPtr<RefCountedIntraProcessRendezvous> rendezvous(
        new RefCountedIntraProcessRendezvous(device_mgr_.get()));
    args.rendezvous = rendezvous.get();

    // `barrier` will delete itself after the final executor finishes.
    Notification executors_done;
    ExecutorBarrier* barrier =
        new ExecutorBarrier(num_executors, rendezvous.get(),
                            [&run_state, &executors_done](const Status& ret) {
                              {
                                mutex_lock l(run_state.mu);
                                run_state.status.Update(ret);
                              }
                              executors_done.Notify();
                            });

    for (const auto& item : executors_and_keys->items) {
      set_threadpool_args_for_item(item, &args);
      item.executor->RunAsync(args, barrier->Get());
    }

    WaitForNotification(&executors_done, &run_state, &step_cancellation_manager,
                        call_timeout);
    {
      tf_shared_lock l(run_state.mu);
      run_status = run_state.status;
    }
  }

  if (step_cancellation_manager.IsCancelled()) {
    run_status.Update(errors::Cancelled("Run call was cancelled"));
  }

  if (profiler_session) {
    TF_RETURN_IF_ERROR(profiler_session->CollectData(run_metadata));
  }

  TF_RETURN_IF_ERROR(run_status);

  // Save the output tensors of this run we choose to keep.
  if (!run_state.tensor_store.empty()) {
    TF_RETURN_IF_ERROR(run_state.tensor_store.SaveTensors(
        {executors_and_keys->callable_options.fetch().begin(),
         executors_and_keys->callable_options.fetch().end()},
        &session_state_));
  }

  if (run_state.collector) {
    run_state.collector->Finalize();
  }

  // Build and return the cost model as instructed.
  if (update_cost_model) {
    // Build the cost model.
    std::unordered_map<string, const Graph*> device_to_graph;
    for (const PerPartitionExecutorsAndLib& partition :
         executors_and_keys->items) {
      const Graph* graph = partition.graph.get();
      const string& device = partition.flib->device()->name();
      device_to_graph[device] = graph;
    }

    mutex_lock l(executor_lock_);
    run_state.collector->BuildCostModel(&cost_model_manager_, device_to_graph);

    // Annotate stats onto the cost graph.
    CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
    for (const auto& item : executors_and_keys->items) {
      TF_RETURN_IF_ERROR(
          cost_model_manager_.AddToCostGraphDef(item.graph.get(), cost_graph));
    }
  }

  // If requested via RunOptions, output the partition graphs.
  if (run_options.output_partition_graphs()) {
    if (options_.config.experimental().disable_output_partition_graphs()) {
      return errors::InvalidArgument(
          "RunOptions.output_partition_graphs() is not supported when "
          "disable_output_partition_graphs is true.");
    } else {
      protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
          run_metadata->mutable_partition_graphs();
      for (const PerPartitionExecutorsAndLib& exec_and_lib :
           executors_and_keys->items) {
        GraphDef* partition_graph_def = partition_graph_defs->Add();
        exec_and_lib.graph->ToGraphDef(partition_graph_def);
      }
    }
  }
  metrics::UpdateGraphExecTime(options_.env->NowMicros() - start_time_usecs);

  return Status::OK();
}

Status DirectSession::Run(const RunOptions& run_options,
                          const NamedTensorList& inputs,
                          const std::vector<string>& output_names,
                          const std::vector<string>& target_nodes,
                          std::vector<Tensor>* outputs,
                          RunMetadata* run_metadata) {
  return Run(run_options, inputs, output_names, target_nodes, outputs,
             run_metadata, thread::ThreadPoolOptions());
}

Status DirectSession::Run(const RunOptions& run_options,
                          const NamedTensorList& inputs,
                          const std::vector<string>& output_names,
                          const std::vector<string>& target_nodes,
                          std::vector<Tensor>* outputs,
                          RunMetadata* run_metadata,
                          const thread::ThreadPoolOptions& threadpool_options) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  TF_RETURN_IF_ERROR(CheckGraphCreated("Run()"));
  direct_session_runs->GetCell()->IncrementBy(1);

  // Extract the inputs names for this run of the session.
  std::vector<string> input_tensor_names;
  input_tensor_names.reserve(inputs.size());
  size_t input_size = 0;
  for (const auto& it : inputs) {
    input_tensor_names.push_back(it.first);
    input_size += it.second.AllocatedBytes();
  }
  metrics::RecordGraphInputTensors(input_size);

  // Check if we already have an executor for these arguments.
  ExecutorsAndKeys* executors_and_keys;
  RunStateArgs run_state_args(run_options.debug_options());
  run_state_args.collective_graph_key =
      run_options.experimental().collective_graph_key();

  TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
                                          target_nodes, &executors_and_keys,
                                          &run_state_args));
  {
    mutex_lock l(collective_graph_key_lock_);
    collective_graph_key_ = executors_and_keys->collective_graph_key;
  }

  // Configure a call frame for the step, which we use to feed and
  // fetch values to and from the executors.
  FunctionCallFrame call_frame(executors_and_keys->input_types,
                               executors_and_keys->output_types);
  gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
  for (const auto& it : inputs) {
    if (it.second.dtype() == DT_RESOURCE) {
      Tensor tensor_from_handle;
      TF_RETURN_IF_ERROR(
          ResourceHandleToInputTensor(it.second, &tensor_from_handle));
      feed_args[executors_and_keys->input_name_to_index[it.first]] =
          tensor_from_handle;
    } else {
      feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
    }
  }
  const Status s = call_frame.SetArgs(feed_args);
  if (errors::IsInternal(s)) {
    return errors::InvalidArgument(s.error_message());
  } else if (!s.ok()) {
    return s;
  }

  const int64 step_id = step_id_counter_.fetch_add(1);

  if (LogMemory::IsEnabled()) {
    LogMemory::RecordStep(step_id, run_state_args.handle);
  }

  TF_RETURN_IF_ERROR(RunInternal(step_id, run_options, &call_frame,
                                 executors_and_keys, run_metadata,
                                 threadpool_options));

  // Receive outputs.
  if (outputs) {
    std::vector<Tensor> sorted_outputs;
    const Status s = call_frame.ConsumeRetvals(
        &sorted_outputs, /* allow_dead_tensors = */ false);
    if (errors::IsInternal(s)) {
      return errors::InvalidArgument(s.error_message());
    } else if (!s.ok()) {
      return s;
    }
    const bool unique_outputs =
        output_names.size() == executors_and_keys->output_name_to_index.size();
    // first_indices[i] = j implies that j is the smallest value for which
    // output_names[i] == output_names[j].
    std::vector<int> first_indices;
    if (!unique_outputs) {
      first_indices.reserve(output_names.size());
      for (const auto& name : output_names) {
        first_indices.push_back(
            std::find(output_names.begin(), output_names.end(), name) -
            output_names.begin());
      }
    }
    outputs->clear();
    size_t output_size = 0;
    outputs->reserve(sorted_outputs.size());
    for (int i = 0; i < output_names.size(); ++i) {
      const string& output_name = output_names[i];
      if (first_indices.empty() || first_indices[i] == i) {
        outputs->emplace_back(
            std::move(sorted_outputs[executors_and_keys
                                         ->output_name_to_index[output_name]]));
      } else {
        outputs->push_back((*outputs)[first_indices[i]]);
      }
      output_size += outputs->back().AllocatedBytes();
    }
    metrics::RecordGraphOutputTensors(output_size);
  }

  return Status::OK();
}

Status DirectSession::PRunSetup(const std::vector<string>& input_names,
                                const std::vector<string>& output_names,
                                const std::vector<string>& target_nodes,
                                string* handle) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  TF_RETURN_IF_ERROR(CheckGraphCreated("PRunSetup()"));

  // RunOptions is not available in PRunSetup, so use thread pool 0.
  thread::ThreadPool* pool = thread_pools_[0].first;

  // Check if we already have an executor for these arguments.
  ExecutorsAndKeys* executors_and_keys;
  // TODO(cais): TFDBG support for partial runs.
  DebugOptions debug_options;
  RunStateArgs run_state_args(debug_options);
  run_state_args.is_partial_run = true;
  TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_names, output_names,
                                          target_nodes, &executors_and_keys,
                                          &run_state_args));

  // Create the run state and save it for future PRun calls.
  Executor::Args args;
  args.step_id = step_id_counter_.fetch_add(1);
  PartialRunState* run_state =
      new PartialRunState(input_names, output_names, args.step_id, &devices_);
  run_state->rendez.reset(new IntraProcessRendezvous(device_mgr_.get()));
  {
    mutex_lock l(executor_lock_);
    if (!partial_runs_
             .emplace(run_state_args.handle,
                      std::unique_ptr<PartialRunState>(run_state))
             .second) {
      return errors::Internal("The handle '", run_state_args.handle,
                              "' created for this partial run is not unique.");
    }
  }

  // Start parallel Executors.
  const size_t num_executors = executors_and_keys->items.size();
  ExecutorBarrier* barrier = new ExecutorBarrier(
      num_executors, run_state->rendez.get(), [run_state](const Status& ret) {
        if (!ret.ok()) {
          mutex_lock l(run_state->mu);
          run_state->status.Update(ret);
        }
        run_state->executors_done.Notify();
      });

  args.rendezvous = run_state->rendez.get();
  args.cancellation_manager = cancellation_manager_;
  // Note that Collectives are not supported in partial runs because
  // RunOptions is not passed in, so we can't know whether their use is
  // intended.
  args.collective_executor = nullptr;
  args.runner = [this, pool](Executor::Args::Closure c) {
    pool->Schedule(std::move(c));
  };
  args.session_state = &session_state_;
  args.session_handle = session_handle_;
  args.tensor_store = &run_state->tensor_store;
  args.step_container = &run_state->step_container;
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordStep(args.step_id, run_state_args.handle);
  }
  args.sync_on_finish = sync_on_finish_;

  if (options_.config.graph_options().build_cost_model()) {
    run_state->collector.reset(new StepStatsCollector(nullptr));
    args.stats_collector = run_state->collector.get();
  }

  for (auto& item : executors_and_keys->items) {
    item.executor->RunAsync(args, barrier->Get());
  }

  *handle = run_state_args.handle;
  return Status::OK();
}

Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
                           const std::vector<string>& output_names,
                           std::vector<Tensor>* outputs) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
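  // The handle produced by PRunSetup() has the form "<executor key>;<counter>"
  // (see GetOrCreateExecutors), so the first component recovers the key under
  // which the executors for this partial run were cached.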
  std::vector<string> parts = str_util::Split(handle, ';');
  const string& key = parts[0];
  // Get the executors for this partial run.
  ExecutorsAndKeys* executors_and_keys;
  PartialRunState* run_state;
  {
    mutex_lock l(executor_lock_);  // could use reader lock
    auto exc_it = executors_.find(key);
    if (exc_it == executors_.end()) {
      return errors::InvalidArgument(
          "Must run 'setup' before performing partial runs!");
    }
    executors_and_keys = exc_it->second.get();

    auto prun_it = partial_runs_.find(handle);
    if (prun_it == partial_runs_.end()) {
      return errors::InvalidArgument(
          "Must run 'setup' before performing partial runs!");
    }
    run_state = prun_it->second.get();

    // Make sure that this is a new set of feeds that are still pending.
    for (const auto& input : inputs) {
      auto it = run_state->pending_inputs.find(input.first);
      if (it == run_state->pending_inputs.end()) {
        return errors::InvalidArgument(
            "The feed ", input.first,
            " was not specified in partial_run_setup.");
      } else if (it->second) {
        return errors::InvalidArgument("The feed ", input.first,
                                       " has already been fed.");
      }
    }
    // Check that this is a new set of fetches that are still pending.
    for (const auto& output : output_names) {
      auto it = run_state->pending_outputs.find(output);
      if (it == run_state->pending_outputs.end()) {
        return errors::InvalidArgument(
            "The fetch ", output, " was not specified in partial_run_setup.");
      } else if (it->second) {
        return errors::InvalidArgument("The fetch ", output,
                                       " has already been fetched.");
      }
    }
  }

  // Check that this new set of fetches can be computed from all the
  // feeds we have supplied.
  TF_RETURN_IF_ERROR(
      CheckFetch(inputs, output_names, executors_and_keys, run_state));

  // Send inputs.
  Status s =
      SendPRunInputs(inputs, executors_and_keys, run_state->rendez.get());

  // Receive outputs.
  if (s.ok()) {
    s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs);
  }

  // Save the output tensors of this run we choose to keep.
  if (s.ok()) {
    s = run_state->tensor_store.SaveTensors(output_names, &session_state_);
  }

  {
    mutex_lock l(executor_lock_);
    // Delete the run state if there is an error or all fetches are done.
    bool done = true;
    if (s.ok()) {
      {
        mutex_lock l(run_state->mu);
        if (!run_state->status.ok()) {
          LOG(WARNING) << "An error unrelated to this prun has been detected. "
                       << run_state->status;
        }
      }
      for (const auto& input : inputs) {
        auto it = run_state->pending_inputs.find(input.first);
        it->second = true;
      }
      for (const auto& name : output_names) {
        auto it = run_state->pending_outputs.find(name);
        it->second = true;
      }
      done = run_state->PendingDone();
    }
    if (done) {
      WaitForNotification(&run_state->executors_done, run_state,
                          cancellation_manager_, operation_timeout_in_ms_);
      partial_runs_.erase(handle);
    }
  }

  return s;
}

Status DirectSession::ResourceHandleToInputTensor(const Tensor& resource_tensor,
                                                  Tensor* retrieved_tensor) {
  if (resource_tensor.dtype() != DT_RESOURCE) {
    return errors::InvalidArgument(strings::StrCat(
        "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ",
        resource_tensor.dtype()));
  }

  const ResourceHandle& resource_handle =
      resource_tensor.scalar<ResourceHandle>()();

  if (resource_handle.container() ==
      SessionState::kTensorHandleResourceTypeName) {
    return session_state_.GetTensor(resource_handle.name(), retrieved_tensor);
  } else {
    return errors::InvalidArgument(strings::StrCat(
        "Invalid resource type hash code: ", resource_handle.hash_code(),
        "(name: ", resource_handle.name(),
        " type: ", resource_handle.maybe_type_name(),
        "). Perhaps a resource tensor was being provided as a feed? That is "
        "not currently allowed. Please file an issue at "
        "https://github.com/tensorflow/tensorflow/issues/new, ideally with a "
        "short code snippet that leads to this error message."));
  }
}

Status DirectSession::SendPRunInputs(const NamedTensorList& inputs,
                                     const ExecutorsAndKeys* executors_and_keys,
                                     IntraProcessRendezvous* rendez) {
  Status s;
  Rendezvous::ParsedKey parsed;
  // Insert the input tensors into the local rendezvous by their
  // rendezvous key.
  for (const auto& input : inputs) {
    auto it =
        executors_and_keys->input_name_to_rendezvous_key.find(input.first);
    if (it == executors_and_keys->input_name_to_rendezvous_key.end()) {
      return errors::Internal("'", input.first, "' is not a pre-defined feed.");
    }
    const string& input_key = it->second;

    s = Rendezvous::ParseKey(input_key, &parsed);
    if (!s.ok()) {
      rendez->StartAbort(s);
      return s;
    }

    if (input.second.dtype() == DT_RESOURCE) {
      Tensor tensor_from_handle;
      s = ResourceHandleToInputTensor(input.second, &tensor_from_handle);
      if (s.ok()) {
        s = rendez->Send(parsed, Rendezvous::Args(), tensor_from_handle, false);
      }
    } else {
      s = rendez->Send(parsed, Rendezvous::Args(), input.second, false);
    }

    if (!s.ok()) {
      rendez->StartAbort(s);
      return s;
    }
  }
  return Status::OK();
}

Status DirectSession::RecvPRunOutputs(
    const std::vector<string>& output_names,
    const ExecutorsAndKeys* executors_and_keys, PartialRunState* run_state,
    std::vector<Tensor>* outputs) {
  Status s;
  if (!output_names.empty()) {
    outputs->resize(output_names.size());
  }

  Rendezvous::ParsedKey parsed;
  // Get the outputs from the rendezvous.
  for (size_t output_offset = 0; output_offset < output_names.size();
       ++output_offset) {
    const string& output_name = output_names[output_offset];
    auto it =
        executors_and_keys->output_name_to_rendezvous_key.find(output_name);
    if (it == executors_and_keys->output_name_to_rendezvous_key.end()) {
      return errors::Internal("'", output_name,
                              "' is not a pre-defined fetch.");
    }
    const string& output_key = it->second;
    Tensor output_tensor;
    bool is_dead;

    s = Rendezvous::ParseKey(output_key, &parsed);
    if (s.ok()) {
      // Fetch data from the Rendezvous.
      s = run_state->rendez->Recv(parsed, Rendezvous::Args(), &output_tensor,
                                  &is_dead, operation_timeout_in_ms_);
      if (is_dead && s.ok()) {
        s = errors::InvalidArgument("The tensor returned for ", output_name,
                                    " was not valid.");
      }
    }
    if (!s.ok()) {
      run_state->rendez->StartAbort(s);
      outputs->clear();
      return s;
    }

    (*outputs)[output_offset] = output_tensor;
  }
  return Status::OK();
}

Status DirectSession::CheckFetch(const NamedTensorList& feeds,
                                 const std::vector<string>& fetches,
                                 const ExecutorsAndKeys* executors_and_keys,
                                 const PartialRunState* run_state) {
  const Graph* graph = executors_and_keys->graph.get();
  const NameNodeMap* name_to_node = &executors_and_keys->name_to_node;

  // Build the set of pending feeds that we haven't seen.
  std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
  {
    mutex_lock l(executor_lock_);
    for (const auto& input : run_state->pending_inputs) {
      // Skip if the feed has already been fed.
      if (input.second) continue;
      TensorId id(ParseTensorName(input.first));
      auto it = name_to_node->find(id.first);
      if (it == name_to_node->end()) {
        return errors::NotFound("Feed ", input.first, ": not found");
      }
      pending_feeds.insert(id);
    }
  }
  for (const auto& it : feeds) {
    TensorId id(ParseTensorName(it.first));
    pending_feeds.erase(id);
  }

  // Initialize the stack with the fetch nodes.
  std::vector<const Node*> stack;
  for (const string& fetch : fetches) {
    TensorId id(ParseTensorName(fetch));
    auto it = name_to_node->find(id.first);
    if (it == name_to_node->end()) {
      return errors::NotFound("Fetch ", fetch, ": not found");
    }
    stack.push_back(it->second);
  }

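  // Walk the graph backwards (a DFS over in-edges) from every fetch node: if
  // we reach a node whose output is still an unfed pending feed, this fetch
  // cannot yet be computed from the feeds supplied so far.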
  // Any tensor needed for fetches can't be in pending_feeds.
  std::vector<bool> visited(graph->num_node_ids(), false);
  while (!stack.empty()) {
    const Node* n = stack.back();
    stack.pop_back();

    for (const Edge* in_edge : n->in_edges()) {
      const Node* in_node = in_edge->src();
      if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
        return errors::InvalidArgument("Fetch ", in_node->name(), ":",
                                       in_edge->src_output(),
                                       " can't be computed from the feeds"
                                       " that have been fed so far.");
      }
      if (!visited[in_node->id()]) {
        visited[in_node->id()] = true;
        stack.push_back(in_node);
      }
    }
  }
  return Status::OK();
}

Status DirectSession::CreateExecutors(
    const CallableOptions& callable_options,
    std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys,
    std::unique_ptr<FunctionInfo>* out_func_info,
    RunStateArgs* run_state_args) {
  BuildGraphOptions options;
  options.callable_options = callable_options;
  options.use_function_convention = !run_state_args->is_partial_run;
  options.collective_graph_key =
      callable_options.run_options().experimental().collective_graph_key();
  if (options_.config.experimental()
          .collective_deterministic_sequential_execution()) {
    options.collective_order = GraphCollectiveOrder::kEdges;
  } else if (options_.config.experimental().collective_nccl()) {
    options.collective_order = GraphCollectiveOrder::kAttrs;
  }

  std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
  std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);

  ek->callable_options = callable_options;

  std::unordered_map<string, std::unique_ptr<Graph>> graphs;
  TF_RETURN_IF_ERROR(CreateGraphs(
      options, &graphs, &func_info->flib_def, run_state_args, &ek->input_types,
      &ek->output_types, &ek->collective_graph_key));

  if (run_state_args->is_partial_run) {
    ek->graph = std::move(run_state_args->graph);
    std::unordered_set<StringPiece, StringPieceHasher> names;
    for (const string& input : callable_options.feed()) {
      TensorId id(ParseTensorName(input));
      names.emplace(id.first);
    }
    for (const string& output : callable_options.fetch()) {
      TensorId id(ParseTensorName(output));
      names.emplace(id.first);
    }
    for (Node* n : ek->graph->nodes()) {
      if (names.count(n->name()) > 0) {
        ek->name_to_node.insert({n->name(), n});
      }
    }
  }
  ek->items.reserve(graphs.size());
  const auto& optimizer_opts =
      options_.config.graph_options().optimizer_options();

  int graph_def_version = graphs.begin()->second->versions().producer();

  const auto* session_metadata =
      options_.config.experimental().has_session_metadata()
          ? &options_.config.experimental().session_metadata()
          : nullptr;
  func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
      device_mgr_.get(), options_.env, &options_.config, graph_def_version,
      func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first,
      /*parent=*/nullptr, session_metadata,
      Rendezvous::Factory{
          [](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
            *r = new IntraProcessRendezvous(device_mgr);
            return Status::OK();
          }}));

  GraphOptimizer optimizer(optimizer_opts);
  for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
    const string& partition_name = iter->first;
    std::unique_ptr<Graph>& partition_graph = iter->second;

    Device* device;
    TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device));

    ek->items.resize(ek->items.size() + 1);
    auto* item = &(ek->items.back());
    auto lib = func_info->proc_flr->GetFLR(partition_name);
    if (lib == nullptr) {
      return errors::Internal("Could not find device: ", partition_name);
    }
    item->flib = lib;

    LocalExecutorParams params;
    params.device = device;
    params.session_metadata = session_metadata;
    params.function_library = lib;
    auto opseg = device->op_segment();
    params.create_kernel =
        [this, lib, opseg](const std::shared_ptr<const NodeProperties>& props,
                           OpKernel** kernel) {
          // NOTE(mrry): We must not share function kernels (implemented
          // using `CallOp`) between subgraphs, because `CallOp::handle_`
          // is tied to a particular subgraph. Even if the function itself
          // is stateful, the `CallOp` that invokes it is not.
          if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) {
            return lib->CreateKernel(props, kernel);
          }
          auto create_fn = [lib, &props](OpKernel** kernel) {
            return lib->CreateKernel(props, kernel);
          };
          // Kernels created for subgraph nodes need to be cached. On
          // cache miss, create_fn() is invoked to create a kernel based
          // on the function library here + global op registry.
          return opseg->FindOrCreate(session_handle_, props->node_def.name(),
                                     kernel, create_fn);
        };
    params.delete_kernel = [lib](OpKernel* kernel) {
      if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string()))
        delete kernel;
    };

    optimizer.Optimize(lib, options_.env, device, &partition_graph,
                       /*shape_map=*/nullptr);

    // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
    const DebugOptions& debug_options =
        options.callable_options.run_options().debug_options();
    if (!debug_options.debug_tensor_watch_opts().empty()) {
      TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
          debug_options, partition_graph.get(), params.device));
    }

    TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
                                         device->name(),
                                         partition_graph.get()));

    item->executor = nullptr;
    item->device = device;
    auto executor_type = options_.config.experimental().executor_type();
    TF_RETURN_IF_ERROR(
        NewExecutor(executor_type, params, *partition_graph, &item->executor));
    if (!options_.config.experimental().disable_output_partition_graphs() ||
        options_.config.graph_options().build_cost_model() > 0) {
      item->graph = std::move(partition_graph);
    }
  }

  // Cache the mapping from input/output names to graph elements to
  // avoid recomputing it every time.
  if (!run_state_args->is_partial_run) {
    // For regular `Run()`, we use the function calling convention, and so
    // maintain a mapping from input/output names to
    // argument/return-value ordinal index.
    for (int i = 0; i < callable_options.feed().size(); ++i) {
      const string& input = callable_options.feed(i);
      ek->input_name_to_index[input] = i;
    }
    for (int i = 0; i < callable_options.fetch().size(); ++i) {
      const string& output = callable_options.fetch(i);
      ek->output_name_to_index[output] = i;
    }
  } else {
    // For `PRun()`, we use the rendezvous calling convention, and so
    // maintain a mapping from input/output names to rendezvous keys.
    //
    // We always use the first device as the device name portion of the
    // key, even if we're feeding another graph.
    for (int i = 0; i < callable_options.feed().size(); ++i) {
      const string& input = callable_options.feed(i);
      ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
          input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
    }
    for (int i = 0; i < callable_options.fetch().size(); ++i) {
      const string& output = callable_options.fetch(i);
      ek->output_name_to_rendezvous_key[output] =
          GetRendezvousKey(output, device_set_.client_device()->attributes(),
                           FrameAndIter(0, 0));
    }
  }

  *out_executors_and_keys = std::move(ek);
  *out_func_info = std::move(func_info);
  return Status::OK();
}
1463
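// Returns (via `executors_and_keys`) the cached executors for the given
// feeds/fetches/targets, creating them with CreateExecutors() on a cache
// miss. The cache is consulted twice: first with a key built from the
// caller-supplied ordering (fast path), then with a key built from sorted
// names, so that different orderings of the same endpoints share executors.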
Status DirectSession::GetOrCreateExecutors(
    gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
    gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys,
    RunStateArgs* run_state_args) {
  int64 handle_name_counter_value = -1;
  if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
    handle_name_counter_value = handle_name_counter_.fetch_add(1);
  }

  string debug_tensor_watches_summary;
  if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
    debug_tensor_watches_summary = SummarizeDebugTensorWatches(
        run_state_args->debug_options.debug_tensor_watch_opts());
  }

  // Fast lookup path, no sorting.
  const string key = strings::StrCat(
      absl::StrJoin(inputs, ","), "->", absl::StrJoin(outputs, ","), "/",
      absl::StrJoin(target_nodes, ","), "/", run_state_args->is_partial_run,
      "/", debug_tensor_watches_summary);
  // Set the handle, if it's needed to log memory or for partial run.
  if (handle_name_counter_value >= 0) {
    run_state_args->handle =
        strings::StrCat(key, ";", handle_name_counter_value);
  }

  // See if we already have the executors for this run.
  {
    mutex_lock l(executor_lock_);  // could use reader lock
    auto it = executors_.find(key);
    if (it != executors_.end()) {
      *executors_and_keys = it->second.get();
      return Status::OK();
    }
  }

  // Slow lookup path, the unsorted key missed the cache.
  // Sort the inputs and outputs, and look up with the sorted key in case an
  // earlier call used a different order of inputs and outputs.
  //
  // We could consider some other signature instead of sorting that
  // preserves the same property to avoid the sort in the future.
  std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
  std::sort(inputs_sorted.begin(), inputs_sorted.end());
  std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
  std::sort(outputs_sorted.begin(), outputs_sorted.end());
  std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
  std::sort(tn_sorted.begin(), tn_sorted.end());

  const string sorted_key = strings::StrCat(
      absl::StrJoin(inputs_sorted, ","), "->",
      absl::StrJoin(outputs_sorted, ","), "/", absl::StrJoin(tn_sorted, ","),
      "/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
  // Set the handle, if it's needed to log memory or for partial run.
  if (handle_name_counter_value >= 0) {
    run_state_args->handle =
        strings::StrCat(sorted_key, ";", handle_name_counter_value);
  }

  // See if we already have the executors for this run.
  {
    mutex_lock l(executor_lock_);
    auto it = executors_.find(sorted_key);
    if (it != executors_.end()) {
      *executors_and_keys = it->second.get();
      return Status::OK();
    }
  }

  // Nothing found, so create the executors and store in the cache.
  // The executor_lock_ is intentionally released while executors are
  // being created.
  CallableOptions callable_options;
  callable_options.mutable_feed()->Reserve(inputs_sorted.size());
  for (const string& input : inputs_sorted) {
    callable_options.add_feed(input);
  }
  callable_options.mutable_fetch()->Reserve(outputs_sorted.size());
  for (const string& output : outputs_sorted) {
    callable_options.add_fetch(output);
  }
  callable_options.mutable_target()->Reserve(tn_sorted.size());
  for (const string& target : tn_sorted) {
    callable_options.add_target(target);
  }
  *callable_options.mutable_run_options()->mutable_debug_options() =
      run_state_args->debug_options;
  callable_options.mutable_run_options()
      ->mutable_experimental()
      ->set_collective_graph_key(run_state_args->collective_graph_key);
  std::unique_ptr<ExecutorsAndKeys> ek;
  std::unique_ptr<FunctionInfo> func_info;
  TF_RETURN_IF_ERROR(
      CreateExecutors(callable_options, &ek, &func_info, run_state_args));

  // Reacquire the lock, try to insert into the map.
  mutex_lock l(executor_lock_);

  // Another thread may have created the entry before us, in which case we will
  // reuse the already created one.
  auto insert_result = executors_.emplace(
      sorted_key, std::shared_ptr<ExecutorsAndKeys>(std::move(ek)));
  if (insert_result.second) {
    functions_.push_back(std::move(func_info));
  }

  // Insert the value under the original key, so the fast path lookup will work
  // if the user uses the same order of inputs, outputs, and targets again.
  executors_.emplace(key, insert_result.first->second);
  *executors_and_keys = insert_result.first->second.get();

  return Status::OK();
}

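// Builds the client graph for the requested feeds/fetches/targets (pruning a
// fresh GraphExecutionState when `place_pruned_graph` is set), checks that
// pruning preserved the requested endpoints and that stateful nodes keep
// their previously recorded placements, then partitions the result into one
// subgraph per device and runs the POST_PARTITIONING optimization passes and
// any device-specific graph rewrites.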
Status DirectSession::CreateGraphs(
    const BuildGraphOptions& subgraph_options,
    std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
    std::unique_ptr<FunctionLibraryDefinition>* flib_def,
    RunStateArgs* run_state_args, DataTypeVector* input_types,
    DataTypeVector* output_types, int64* collective_graph_key) {
  mutex_lock l(graph_state_lock_);
  if (finalized_) {
    return errors::FailedPrecondition("Session has been finalized.");
  }

  std::unique_ptr<ClientGraph> client_graph;

  std::unique_ptr<GraphExecutionState> temp_exec_state_holder;
  GraphExecutionState* execution_state = nullptr;
  if (options_.config.graph_options().place_pruned_graph()) {
    // Because we are placing pruned graphs, we need to create a
    // new GraphExecutionState for every new unseen graph,
    // and then place it.
    GraphExecutionStateOptions prune_options;
    prune_options.device_set = &device_set_;
    prune_options.session_options = &options_;
    prune_options.stateful_placements = stateful_placements_;
    prune_options.session_handle = session_handle_;
    TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph(
        *execution_state_, prune_options, subgraph_options,
        &temp_exec_state_holder, &client_graph));
    execution_state = temp_exec_state_holder.get();
  } else {
    execution_state = execution_state_.get();
    TF_RETURN_IF_ERROR(
        execution_state->BuildGraph(subgraph_options, &client_graph));
  }
  *collective_graph_key = client_graph->collective_graph_key;

  if (subgraph_options.callable_options.feed_size() !=
      client_graph->feed_types.size()) {
    return errors::Internal(
        "Graph pruning failed: requested number of feed endpoints = ",
        subgraph_options.callable_options.feed_size(),
        " versus number of pruned feed endpoints = ",
        client_graph->feed_types.size());
  }
  if (subgraph_options.callable_options.fetch_size() !=
      client_graph->fetch_types.size()) {
    return errors::Internal(
        "Graph pruning failed: requested number of fetch endpoints = ",
        subgraph_options.callable_options.fetch_size(),
        " versus number of pruned fetch endpoints = ",
        client_graph->fetch_types.size());
  }

  auto current_stateful_placements = execution_state->GetStatefulPlacements();
  // Update our current state based on the execution_state's
  // placements. If there are any mismatches for a node,
  // we should fail, as this should never happen.
  for (const auto& placement_pair : current_stateful_placements) {
    const string& node_name = placement_pair.first;
    const string& placement = placement_pair.second;
    auto iter = stateful_placements_.find(node_name);
    if (iter == stateful_placements_.end()) {
      stateful_placements_.insert(std::make_pair(node_name, placement));
    } else if (iter->second != placement) {
      return errors::Internal(
          "Stateful placement mismatch. "
          "Current assignment of ",
          node_name, " to ", iter->second, " does not match ", placement);
    }
  }

  stateful_placements_ = execution_state->GetStatefulPlacements();

  // Remember the graph in run state if this is a partial run.
  if (run_state_args->is_partial_run) {
    run_state_args->graph.reset(new Graph(flib_def_.get()));
    CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
  }

  // Partition the graph across devices.
  PartitionOptions popts;
  popts.node_to_loc = [](const Node* node) {
    return node->assigned_device_name();
  };
  popts.new_name = [this](const string& prefix) {
    return strings::StrCat(prefix, "/_", edge_name_counter_.fetch_add(1));
  };
  popts.get_incarnation = [](const string& name) {
    // The direct session does not have changing incarnation numbers.
    // Just return '1'.
    return 1;
  };
  popts.flib_def = &client_graph->graph.flib_def();
  popts.control_flow_added = false;

  std::unordered_map<string, GraphDef> partitions;
  TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions));

  std::vector<string> device_names;
  for (auto device : devices_) {
    // Extract the LocalName from the device.
    device_names.push_back(DeviceNameUtils::LocalName(device->name()));
  }

  // Check for valid partitions.
  for (const auto& partition : partitions) {
    const string local_partition_name =
        DeviceNameUtils::LocalName(partition.first);
    if (std::count(device_names.begin(), device_names.end(),
                   local_partition_name) == 0) {
      return errors::InvalidArgument(
          "Creating a partition for ", local_partition_name,
          " which doesn't exist in the list of available devices. Available "
          "devices: ",
          absl::StrJoin(device_names, ","));
    }
  }

  for (auto& partition : partitions) {
    std::unique_ptr<Graph> device_graph(
        new Graph(client_graph->flib_def.get()));
    device_graph->SetConstructionContext(ConstructionContext::kDirectSession);
    GraphConstructorOptions device_opts;
    // There are internal operations (e.g., send/recv) that we now allow.
    device_opts.allow_internal_ops = true;
    device_opts.expect_device_spec = true;
    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
        device_opts, std::move(partition.second), device_graph.get()));
    outputs->emplace(partition.first, std::move(device_graph));
  }

  GraphOptimizationPassOptions optimization_options;
  optimization_options.session_options = &options_;
  optimization_options.flib_def = client_graph->flib_def.get();
  optimization_options.partition_graphs = outputs;
  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_PARTITIONING, optimization_options));

  Status s;
  for (auto& partition : *outputs) {
    const string& partition_name = partition.first;
    std::unique_ptr<Graph>* graph = &partition.second;

    VLOG(2) << "Created " << DebugString(graph->get()) << " for "
            << partition_name;

    // Give the device an opportunity to rewrite its subgraph.
    Device* d;
    s = device_mgr_->LookupDevice(partition_name, &d);
    if (!s.ok()) break;
    s = d->MaybeRewriteGraph(graph);
    if (!s.ok()) {
      break;
    }
  }
  *flib_def = std::move(client_graph->flib_def);
  std::swap(*input_types, client_graph->feed_types);
  std::swap(*output_types, client_graph->fetch_types);
  return s;
}

::tensorflow::Status DirectSession::ListDevices(
    std::vector<DeviceAttributes>* response) {
  response->clear();
  response->reserve(devices_.size());
  for (Device* d : devices_) {
    const DeviceAttributes& attrs = d->attributes();
    response->emplace_back(attrs);
  }
  return ::tensorflow::Status::OK();
}

::tensorflow::Status DirectSession::Reset(
    const std::vector<string>& containers) {
  device_mgr_->ClearContainers(containers);
  return ::tensorflow::Status::OK();
}

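// Closing the session cancels any outstanding work, marks the session as
// closed (making a second Close() a no-op), and deregisters it from the
// owning session factory.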
::tensorflow::Status DirectSession::Close() {
  cancellation_manager_->StartCancel();
  {
    mutex_lock l(closed_lock_);
    if (closed_) return ::tensorflow::Status::OK();
    closed_ = true;
  }
  if (factory_ != nullptr) factory_->Deregister(this);
  return ::tensorflow::Status::OK();
}

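// The per-step container's cleanup function releases step-local resources
// (and any scoped allocators) on every device once the step has finished.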
DirectSession::RunState::RunState(int64 step_id,
                                  const std::vector<Device*>* devices)
    : step_container(step_id, [devices, step_id](const string& name) {
        for (auto d : *devices) {
          if (!d->resource_manager()->Cleanup(name).ok()) {
            // Do nothing...
          }
          ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
          if (sam) sam->Cleanup(step_id);
        }
      }) {}

DirectSession::PartialRunState::PartialRunState(
    const std::vector<string>& pending_input_names,
    const std::vector<string>& pending_output_names, int64 step_id,
    const std::vector<Device*>* devices)
    : RunState(step_id, devices) {
  // Initially all the feeds and fetches are pending.
  for (auto& name : pending_input_names) {
    pending_inputs[name] = false;
  }
  for (auto& name : pending_output_names) {
    pending_outputs[name] = false;
  }
}

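// Aborts the partial run's rendezvous, if one exists, and waits for its
// executors to finish before the state is destroyed.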
DirectSession::PartialRunState::~PartialRunState() {
  if (rendez != nullptr) {
    rendez->StartAbort(errors::Cancelled("PRun cancellation"));
    executors_done.WaitForNotification();
  }
}

bool DirectSession::PartialRunState::PendingDone() const {
  for (const auto& it : pending_inputs) {
    if (!it.second) return false;
  }
  for (const auto& it : pending_outputs) {
    if (!it.second) return false;
  }
  return true;
}

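// Waits for `n`, subject to the optional timeout. If the wait fails (for
// example, it times out), the error is recorded in `run_state`, cancellation
// is triggered through `cm`, and we still block until the executors signal
// `n`, since they hold borrowed references to per-step state.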
void DirectSession::WaitForNotification(Notification* n, RunState* run_state,
                                        CancellationManager* cm,
                                        int64 timeout_in_ms) {
  const Status status = WaitForNotification(n, timeout_in_ms);
  if (!status.ok()) {
    {
      mutex_lock l(run_state->mu);
      run_state->status.Update(status);
    }
    cm->StartCancel();
    // We must wait for the executors to complete, because they have borrowed
    // references to `cm` and other per-step state. After this notification, it
    // is safe to clean up the step.
    n->WaitForNotification();
  }
}

::tensorflow::Status DirectSession::WaitForNotification(
    Notification* notification, int64 timeout_in_ms) {
  if (timeout_in_ms > 0) {
    const int64 timeout_in_us = timeout_in_ms * 1000;
    const bool notified =
        WaitForNotificationWithTimeout(notification, timeout_in_us);
    if (!notified) {
      return Status(error::DEADLINE_EXCEEDED,
                    "Timed out waiting for notification");
    }
  } else {
    notification->WaitForNotification();
  }
  return Status::OK();
}

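// Builds the executors for `callable_options` eagerly and registers them
// under a fresh handle, so that RunCallable() can look them up directly
// instead of going through the string-keyed executor cache.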
Status DirectSession::MakeCallable(const CallableOptions& callable_options,
                                   CallableHandle* out_handle) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  TF_RETURN_IF_ERROR(CheckGraphCreated("MakeCallable()"));

  std::unique_ptr<ExecutorsAndKeys> ek;
  std::unique_ptr<FunctionInfo> func_info;
  RunStateArgs run_state_args(callable_options.run_options().debug_options());
  TF_RETURN_IF_ERROR(
      CreateExecutors(callable_options, &ek, &func_info, &run_state_args));
  {
    mutex_lock l(callables_lock_);
    *out_handle = next_callable_handle_++;
    callables_[*out_handle] = {std::move(ek), std::move(func_info)};
  }
  return Status::OK();
}

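// A CallFrameInterface implementation backed directly by the caller's feed
// and fetch vectors: arguments are read from `feed_tensors` and return values
// are written into `fetch_tensors` by ordinal index, with the expected counts
// taken from the callable's ExecutorsAndKeys.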
class DirectSession::RunCallableCallFrame : public CallFrameInterface {
 public:
  RunCallableCallFrame(DirectSession* session,
                       ExecutorsAndKeys* executors_and_keys,
                       const std::vector<Tensor>* feed_tensors,
                       std::vector<Tensor>* fetch_tensors)
      : session_(session),
        executors_and_keys_(executors_and_keys),
        feed_tensors_(feed_tensors),
        fetch_tensors_(fetch_tensors) {}

  size_t num_args() const override {
    return executors_and_keys_->input_types.size();
  }
  size_t num_retvals() const override {
    return executors_and_keys_->output_types.size();
  }

  Status GetArg(int index, const Tensor** val) override {
    if (TF_PREDICT_FALSE(index >= feed_tensors_->size())) {
      return errors::Internal("Args index out of bounds: ", index);
    } else {
      *val = &(*feed_tensors_)[index];
    }
    return Status::OK();
  }

  Status SetRetval(int index, const Tensor& val) override {
    if (index >= fetch_tensors_->size()) {
      return errors::Internal("RetVal index out of bounds: ", index);
    }
    (*fetch_tensors_)[index] = val;
    return Status::OK();
  }

 private:
  DirectSession* const session_;                   // Not owned.
  ExecutorsAndKeys* const executors_and_keys_;     // Not owned.
  const std::vector<Tensor>* const feed_tensors_;  // Not owned.
  std::vector<Tensor>* const fetch_tensors_;       // Not owned.
};

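// Runs the callable identified by `handle`. The first overload below forwards
// to the second with default thread pool options; the second validates the
// handle and the number of feeds, converts any DT_RESOURCE feeds back into
// the tensors they reference, and drives the cached executors through
// RunInternal() using a RunCallableCallFrame.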
::tensorflow::Status DirectSession::RunCallable(
    CallableHandle handle, const std::vector<Tensor>& feed_tensors,
    std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata) {
  return RunCallable(handle, feed_tensors, fetch_tensors, run_metadata,
                     thread::ThreadPoolOptions());
}

::tensorflow::Status DirectSession::RunCallable(
    CallableHandle handle, const std::vector<Tensor>& feed_tensors,
    std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata,
    const thread::ThreadPoolOptions& threadpool_options) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  TF_RETURN_IF_ERROR(CheckGraphCreated("RunCallable()"));
  direct_session_runs->GetCell()->IncrementBy(1);

  // Check if we already have an executor for these arguments.
  std::shared_ptr<ExecutorsAndKeys> executors_and_keys;
  const int64 step_id = step_id_counter_.fetch_add(1);

  {
    tf_shared_lock l(callables_lock_);
    if (handle >= next_callable_handle_) {
      return errors::InvalidArgument("No such callable handle: ", handle);
    }
    executors_and_keys = callables_[handle].executors_and_keys;
  }

  if (!executors_and_keys) {
    return errors::InvalidArgument(
        "Attempted to run callable after handle was released: ", handle);
  }

  // NOTE(mrry): Debug options are not currently supported in the
  // callable interface.
  DebugOptions debug_options;
  RunStateArgs run_state_args(debug_options);

  // Configure a call frame for the step, which we use to feed and
  // fetch values to and from the executors.
  if (feed_tensors.size() != executors_and_keys->input_types.size()) {
    return errors::InvalidArgument(
        "Expected ", executors_and_keys->input_types.size(),
        " feed tensors, but got ", feed_tensors.size());
  }
  if (fetch_tensors != nullptr) {
    fetch_tensors->resize(executors_and_keys->output_types.size());
  } else if (!executors_and_keys->output_types.empty()) {
    return errors::InvalidArgument(
        "`fetch_tensors` must be provided when the callable has one or more "
        "outputs.");
  }

  size_t input_size = 0;
  bool any_resource_feeds = false;
  for (auto& tensor : feed_tensors) {
    input_size += tensor.AllocatedBytes();
    any_resource_feeds = any_resource_feeds || tensor.dtype() == DT_RESOURCE;
  }
  metrics::RecordGraphInputTensors(input_size);

  std::unique_ptr<std::vector<Tensor>> converted_feed_tensors;
  const std::vector<Tensor>* actual_feed_tensors;

  if (TF_PREDICT_FALSE(any_resource_feeds)) {
    converted_feed_tensors = absl::make_unique<std::vector<Tensor>>();
    converted_feed_tensors->reserve(feed_tensors.size());
    for (const Tensor& t : feed_tensors) {
      if (t.dtype() == DT_RESOURCE) {
        converted_feed_tensors->emplace_back();
        Tensor* tensor_from_handle = &converted_feed_tensors->back();
        TF_RETURN_IF_ERROR(ResourceHandleToInputTensor(t, tensor_from_handle));
      } else {
        converted_feed_tensors->emplace_back(t);
      }
    }
    actual_feed_tensors = converted_feed_tensors.get();
  } else {
    actual_feed_tensors = &feed_tensors;
  }

  // A specialized CallFrame implementation that takes advantage of the
  // optimized RunCallable interface.
  RunCallableCallFrame call_frame(this, executors_and_keys.get(),
                                  actual_feed_tensors, fetch_tensors);

  if (LogMemory::IsEnabled()) {
    LogMemory::RecordStep(step_id, run_state_args.handle);
  }

  TF_RETURN_IF_ERROR(RunInternal(
      step_id, executors_and_keys->callable_options.run_options(), &call_frame,
      executors_and_keys.get(), run_metadata, threadpool_options));

  if (fetch_tensors != nullptr) {
    size_t output_size = 0;
    for (auto& tensor : *fetch_tensors) {
      output_size += tensor.AllocatedBytes();
    }
    metrics::RecordGraphOutputTensors(output_size);
  }

  return Status::OK();
}

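// Removes the callable from the table. The ExecutorsAndKeys is held via a
// shared_ptr, so a RunCallable() already in flight with this handle can
// complete safely; later calls with the handle fail with InvalidArgument.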
::tensorflow::Status DirectSession::ReleaseCallable(CallableHandle handle) {
  mutex_lock l(callables_lock_);
  if (handle >= next_callable_handle_) {
    return errors::InvalidArgument("No such callable handle: ", handle);
  }
  callables_.erase(handle);
  return Status::OK();
}

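// Releases the graph-construction state (the GraphExecutionState and the base
// function library) to reclaim memory. Executors that already exist remain
// usable, but no new graphs can be created afterwards: CreateGraphs() and a
// second Finalize() both return FailedPrecondition.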
Status DirectSession::Finalize() {
  mutex_lock l(graph_state_lock_);
  if (finalized_) {
    return errors::FailedPrecondition("Session already finalized.");
  }
  if (!graph_created_) {
    return errors::FailedPrecondition("Session not yet created.");
  }
  execution_state_.reset();
  flib_def_.reset();
  finalized_ = true;
  return Status::OK();
}

DirectSession::Callable::~Callable() {
  // We must delete the fields in this order, because the destructor
  // of `executors_and_keys` will call into an object owned by
  // `function_info` (in particular, when deleting a kernel, it relies
  // on the `FunctionLibraryRuntime` to know if the kernel is stateful
  // or not).
  executors_and_keys.reset();
  function_info.reset();
}

}  // namespace tensorflow