1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/direct_session.h"
17
18 #include <atomic>
19 #include <string>
20 #include <vector>
21
22 #include "tensorflow/core/common_runtime/collective_executor_mgr.h"
23 #include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
24 #include "tensorflow/core/common_runtime/constant_folding.h"
25 #include "tensorflow/core/common_runtime/debugger_state_interface.h"
26 #include "tensorflow/core/common_runtime/device_factory.h"
27 #include "tensorflow/core/common_runtime/device_resolver_local.h"
28 #include "tensorflow/core/common_runtime/executor.h"
29 #include "tensorflow/core/common_runtime/executor_factory.h"
30 #include "tensorflow/core/common_runtime/function.h"
31 #include "tensorflow/core/common_runtime/graph_optimizer.h"
32 #include "tensorflow/core/common_runtime/memory_types.h"
33 #include "tensorflow/core/common_runtime/metrics.h"
34 #include "tensorflow/core/common_runtime/optimization_registry.h"
35 #include "tensorflow/core/common_runtime/process_util.h"
36 #include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
37 #include "tensorflow/core/common_runtime/step_stats_collector.h"
38 #include "tensorflow/core/framework/function.h"
39 #include "tensorflow/core/framework/graph.pb_text.h"
40 #include "tensorflow/core/framework/graph.pb.h"
41 #include "tensorflow/core/framework/graph_def_util.h"
42 #include "tensorflow/core/framework/log_memory.h"
43 #include "tensorflow/core/framework/node_def.pb.h"
44 #include "tensorflow/core/framework/run_handler.h"
45 #include "tensorflow/core/framework/tensor.h"
46 #include "tensorflow/core/framework/versions.pb.h"
47 #include "tensorflow/core/graph/algorithm.h"
48 #include "tensorflow/core/graph/graph.h"
49 #include "tensorflow/core/graph/graph_constructor.h"
50 #include "tensorflow/core/graph/graph_partition.h"
51 #include "tensorflow/core/graph/subgraph.h"
52 #include "tensorflow/core/graph/tensor_id.h"
53 #include "tensorflow/core/lib/core/errors.h"
54 #include "tensorflow/core/lib/core/notification.h"
55 #include "tensorflow/core/lib/core/refcount.h"
56 #include "tensorflow/core/lib/core/status.h"
57 #include "tensorflow/core/lib/core/threadpool.h"
58 #include "tensorflow/core/lib/gtl/array_slice.h"
59 #include "tensorflow/core/lib/gtl/stl_util.h"
60 #include "tensorflow/core/lib/monitoring/counter.h"
61 #include "tensorflow/core/lib/random/random.h"
62 #include "tensorflow/core/lib/strings/numbers.h"
63 #include "tensorflow/core/lib/strings/str_util.h"
64 #include "tensorflow/core/lib/strings/strcat.h"
65 #include "tensorflow/core/platform/byte_order.h"
66 #include "tensorflow/core/platform/device_tracer.h"
67 #include "tensorflow/core/platform/logging.h"
68 #include "tensorflow/core/platform/mutex.h"
69 #include "tensorflow/core/platform/tracing.h"
70 #include "tensorflow/core/platform/types.h"
71 #include "tensorflow/core/util/device_name_utils.h"
72 #include "tensorflow/core/util/env_var.h"
73
74 namespace tensorflow {
75
76 namespace {
77
78 auto* direct_session_runs = monitoring::Counter<0>::New(
79 "/tensorflow/core/direct_session_runs",
80 "The number of times DirectSession::Run() has been called.");
81
82 Status NewThreadPoolFromThreadPoolOptions(
83 const SessionOptions& options,
84 const ThreadPoolOptionProto& thread_pool_options, int pool_number,
85 thread::ThreadPool** pool, bool* owned) {
86 int32 num_threads = thread_pool_options.num_threads();
87 if (num_threads == 0) {
88 num_threads = NumInterOpThreadsFromSessionOptions(options);
89 }
90 const string& name = thread_pool_options.global_name();
91 if (name.empty()) {
92 // Session-local threadpool.
93 VLOG(1) << "Direct session inter op parallelism threads for pool "
94 << pool_number << ": " << num_threads;
95 *pool = new thread::ThreadPool(
96 options.env, strings::StrCat("Compute", pool_number), num_threads);
97 *owned = true;
98 return Status::OK();
99 }
100
101 // Global, named threadpool.
102 typedef std::pair<int32, thread::ThreadPool*> MapValue;
103 static std::map<string, MapValue>* global_pool_map =
104 new std::map<string, MapValue>;
105 static mutex* mu = new mutex();
106 mutex_lock l(*mu);
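  // The first request for a given global name creates the pool; later
  // requests must ask for the same num_threads or fail below.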
107 MapValue* mvalue = &(*global_pool_map)[name];
108 if (mvalue->second == nullptr) {
109 mvalue->first = thread_pool_options.num_threads();
110 mvalue->second = new thread::ThreadPool(
111 options.env, strings::StrCat("Compute", pool_number), num_threads);
112 } else {
113 if (mvalue->first != thread_pool_options.num_threads()) {
114 return errors::InvalidArgument(
115 "Pool ", name,
116 " configured previously with num_threads=", mvalue->first,
117 "; cannot re-configure with num_threads=",
118 thread_pool_options.num_threads());
119 }
120 }
121 *owned = false;
122 *pool = mvalue->second;
123 return Status::OK();
124 }
125
126 thread::ThreadPool* GlobalThreadPool(const SessionOptions& options) {
127 static thread::ThreadPool* const thread_pool =
128 NewThreadPoolFromSessionOptions(options);
129 return thread_pool;
130 }
131
132 // TODO(vrv): Figure out how to unify the many different functions
133 // that generate RendezvousKey, since many of them have to be
134 // consistent with each other.
135 string GetRendezvousKey(const string& tensor_name,
136 const DeviceAttributes& device_info,
137 const FrameAndIter& frame_iter) {
138 return strings::StrCat(device_info.name(), ";",
139 strings::FpToString(device_info.incarnation()), ";",
140 device_info.name(), ";", tensor_name, ";",
141 frame_iter.frame_id, ":", frame_iter.iter_id);
142 }
143
144 } // namespace
145
146 class DirectSessionFactory : public SessionFactory {
147 public:
148   DirectSessionFactory() {}
149
150   bool AcceptsOptions(const SessionOptions& options) override {
151 return options.target.empty();
152 }
153
154   Status NewSession(const SessionOptions& options,
155 Session** out_session) override {
156 // Must do this before the CPU allocator is created.
157 if (options.config.graph_options().build_cost_model() > 0) {
158 EnableCPUAllocatorFullStats(true);
159 }
160 std::vector<std::unique_ptr<Device>> devices;
161 TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
162 options, "/job:localhost/replica:0/task:0", &devices));
163
164 DirectSession* session =
165 new DirectSession(options, new DeviceMgr(std::move(devices)), this);
166 {
167 mutex_lock l(sessions_lock_);
168 sessions_.push_back(session);
169 }
170 *out_session = session;
171 return Status::OK();
172 }
173
174   Status Reset(const SessionOptions& options,
175 const std::vector<string>& containers) override {
176 std::vector<DirectSession*> sessions_to_reset;
177 {
178 mutex_lock l(sessions_lock_);
179 // We create a copy to ensure that we don't have a deadlock when
180 // session->Close calls the DirectSessionFactory.Deregister, which
181 // acquires sessions_lock_.
182 std::swap(sessions_to_reset, sessions_);
183 }
184 Status s;
185 for (auto session : sessions_to_reset) {
186 s.Update(session->Reset(containers));
187 }
188 // TODO(suharshs): Change the Reset behavior of all SessionFactories so that
189 // it doesn't close the sessions?
190 for (auto session : sessions_to_reset) {
191 s.Update(session->Close());
192 }
193 return s;
194 }
195
196   void Deregister(const DirectSession* session) {
197 mutex_lock l(sessions_lock_);
198 sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session),
199 sessions_.end());
200 }
201
202 private:
203 mutex sessions_lock_;
204 std::vector<DirectSession*> sessions_ GUARDED_BY(sessions_lock_);
205 };
206
207 class DirectSessionRegistrar {
208 public:
209   DirectSessionRegistrar() {
210 SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
211 }
212 };
213 static DirectSessionRegistrar registrar;
214
215 std::atomic_int_fast64_t DirectSession::step_id_counter_(1);
216
217 // NOTE: On Android with a single device, there is never
218 // a risk of an OpKernel blocking indefinitely:
219 //
220 // 1) No operations do I/O that depends on other simultaneous kernels,
221 //
222 // 2) Recv nodes always complete immediately: The inputs are sent into
223 // the local rendezvous before we start the executor, so the
224 // corresponding recvs will not block.
225 //
226 // Based on these assumptions, we can use the same thread pool for
227 // both "non-blocking" and "blocking" OpKernels on Android.
228 //
229 // This may change down the road when we add support for multiple
230 // devices that run concurrently, in which case we will need to
231 // revisit this decision.
232 void DirectSession::SchedClosure(thread::ThreadPool* pool,
233 std::function<void()> c) {
234 // TODO(sanjay): Get rid of __ANDROID__ path
235 #ifdef __ANDROID__
236 // On Android, there is no implementation of ThreadPool that takes
237 // std::function, only Closure, which we cannot easily convert.
238 //
239 // Instead, we just run the function in-line, which is currently
240 // safe given the reasoning above.
241 c();
242 #else
243 if (pool != nullptr) {
244 pool->Schedule(std::move(c));
245 } else {
246 c();
247 }
248 #endif // __ANDROID__
249 }
250
251 static RunHandlerPool* GetOrCreateRunHandlerPool(
252 const SessionOptions& options) {
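  // Function-local static: the pool is created on first call and shared by
  // every session in the process for the lifetime of the program.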
253 static RunHandlerPool* pool =
254 new RunHandlerPool(NumInterOpThreadsFromSessionOptions(options));
255 return pool;
256 }
257
258 bool DirectSession::ShouldUseRunHandlerPool(
259 const RunOptions& run_options) const {
260 if (options_.config.use_per_session_threads()) return false;
261 if (options_.config.session_inter_op_thread_pool_size() > 0 &&
262 run_options.inter_op_thread_pool() > 0)
263 return false;
264 // Only use RunHandlerPool when:
265 // a. Single global thread pool is used for inter-op parallelism.
266 // b. When multiple inter_op_thread_pool(s) are created, use it only while
267 // running sessions on the default inter_op_thread_pool=0. Typically,
268 // servo-team uses inter_op_thread_pool > 0 for model loading.
269 // TODO(crk): Revisit whether we'd want to create one (static) RunHandlerPool
270 // per entry in session_inter_op_thread_pool() in the future.
271 return true;
272 }
273
274 DirectSession::DirectSession(const SessionOptions& options,
275 const DeviceMgr* device_mgr,
276 DirectSessionFactory* const factory)
277 : options_(options),
278 device_mgr_(device_mgr),
279 factory_(factory),
280 cancellation_manager_(new CancellationManager()),
281 operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
282 const int thread_pool_size =
283 options_.config.session_inter_op_thread_pool_size();
284 if (thread_pool_size > 0) {
285 for (int i = 0; i < thread_pool_size; ++i) {
286 thread::ThreadPool* pool = nullptr;
287 bool owned = false;
288 init_error_.Update(NewThreadPoolFromThreadPoolOptions(
289 options_, options_.config.session_inter_op_thread_pool(i), i, &pool,
290 &owned));
291 thread_pools_.emplace_back(pool, owned);
292 }
293 } else if (options_.config.use_per_session_threads()) {
294 thread_pools_.emplace_back(NewThreadPoolFromSessionOptions(options_),
295 true /* owned */);
296 } else {
297 thread_pools_.emplace_back(GlobalThreadPool(options), false /* owned */);
298 }
299 // The default value of sync_on_finish will be flipped soon and this
300 // environment variable will be removed as well.
301 const Status status =
302 ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
303 if (!status.ok()) {
304 LOG(ERROR) << status.error_message();
305 }
306 session_handle_ =
307 strings::StrCat("direct", strings::FpToString(random::New64()));
308 int devices_added = 0;
309 if (options.config.log_device_placement()) {
310 const string mapping_str = device_mgr_->DeviceMappingString();
311 if (mapping_str.empty()) {
312 printf("Device mapping: no known devices.\n");
313 } else {
314 printf("Device mapping:\n%s", mapping_str.c_str());
315 }
316 LOG(INFO) << "Device mapping:\n" << mapping_str;
317 }
318 for (auto d : device_mgr_->ListDevices()) {
319 devices_.push_back(d);
320 device_set_.AddDevice(d);
321 d->op_segment()->AddHold(session_handle_);
322
323 // The first device added is special: it is the 'client device' (a
324 // CPU device) from which we feed and fetch Tensors.
325 if (devices_added == 0) {
326 device_set_.set_client_device(d);
327 }
328 ++devices_added;
329 }
330 }
331
332 DirectSession::~DirectSession() {
333 if (!closed_) Close().IgnoreError();
334 for (auto& it : partial_runs_) {
335 it.second.reset(nullptr);
336 }
337 for (auto& it : executors_) {
338 it.second.reset();
339 }
340 callables_.clear();
341 for (auto d : device_mgr_->ListDevices()) {
342 d->op_segment()->RemoveHold(session_handle_);
343 }
344 for (auto d : device_mgr_->ListDevices()) {
345 d->ClearResourceMgr();
346 }
347 functions_.clear();
348 delete cancellation_manager_;
349 for (const auto& p_and_owned : thread_pools_) {
350 if (p_and_owned.second) delete p_and_owned.first;
351 }
352
353 execution_state_.reset(nullptr);
354 flib_def_.reset(nullptr);
355 }
356
357 Status DirectSession::MaybeInitializeExecutionState(
358 const GraphDef& graph, bool* out_already_initialized) {
359 // If already initialized, do nothing.
360 if (flib_def_ && execution_state_) {
361 *out_already_initialized = true;
362 return Status::OK();
363 }
364 // Set up the per-session execution state.
365 // NOTE(mrry): The function library created here will be used for
366 // all subsequent extensions of the graph.
367 flib_def_.reset(
368 new FunctionLibraryDefinition(OpRegistry::Global(), graph.library()));
369 GraphExecutionStateOptions options;
370 options.device_set = &device_set_;
371 options.session_options = &options_;
372 options.session_handle = session_handle_;
373 // TODO(mrry,suharshs): We explicitly copy `graph` so that
374 // `MakeForBaseGraph()` can take ownership of its
375 // contents. Previously this happened implicitly in calls to the
376 // `GraphExecutionState`. Other sessions call
377 // `MakeForBaseGraph` in such a way that we can destructively read
378 // the passed-in `GraphDef`. In principle we could do the same here,
379 // with a wider refactoring; we might revise the direct session so
380 // that it copies the graph fewer times.
381 GraphDef temp(graph);
382 TF_RETURN_IF_ERROR(
383 GraphExecutionState::MakeForBaseGraph(&temp, options, &execution_state_));
384 graph_created_ = true;
385 *out_already_initialized = false;
386 return Status::OK();
387 }
388
389 Status DirectSession::Create(const GraphDef& graph) {
390 TF_RETURN_IF_ERROR(init_error_);
391 if (graph.node_size() > 0) {
392 mutex_lock l(graph_state_lock_);
393 if (graph_created_) {
394 return errors::AlreadyExists(
395 "A Graph has already been created for this session.");
396 }
397 return ExtendLocked(graph);
398 }
399 return Status::OK();
400 }
401
402 Status DirectSession::Extend(const GraphDef& graph) {
403 TF_RETURN_IF_ERROR(CheckNotClosed());
404 mutex_lock l(graph_state_lock_);
405 return ExtendLocked(graph);
406 }
407
408 Status DirectSession::ExtendLocked(const GraphDef& graph) {
409 bool already_initialized;
410 // If this is the first call, we can initialize the execution state
411 // with `graph` and do not need to call `Extend()`.
412 TF_RETURN_IF_ERROR(
413 MaybeInitializeExecutionState(graph, &already_initialized));
414 if (already_initialized) {
415 TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));
416 std::unique_ptr<GraphExecutionState> state;
417 TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
418 execution_state_.swap(state);
419 }
420 return Status::OK();
421 }
422
423 Status DirectSession::Run(const NamedTensorList& inputs,
424 const std::vector<string>& output_names,
425 const std::vector<string>& target_nodes,
426 std::vector<Tensor>* outputs) {
427 RunMetadata run_metadata;
428 return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
429 &run_metadata);
430 }
431
432 Status DirectSession::CreateDebuggerState(
433 const CallableOptions& callable_options, int64 global_step,
434 int64 session_run_index, int64 executor_step_index,
435 std::unique_ptr<DebuggerStateInterface>* debugger_state) {
436 TF_RETURN_IF_ERROR(DebuggerStateRegistry::CreateState(
437 callable_options.run_options().debug_options(), debugger_state));
438 std::vector<string> input_names(callable_options.feed().begin(),
439 callable_options.feed().end());
440 std::vector<string> output_names(callable_options.fetch().begin(),
441 callable_options.fetch().end());
442 std::vector<string> target_names(callable_options.target().begin(),
443 callable_options.target().end());
444
445 TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
446 global_step, session_run_index, executor_step_index, input_names,
447 output_names, target_names));
448 return Status::OK();
449 }
450
451 Status DirectSession::DecorateAndPublishGraphForDebug(
452 const DebugOptions& debug_options, Graph* graph, Device* device) {
453 std::unique_ptr<DebugGraphDecoratorInterface> decorator;
454 TF_RETURN_IF_ERROR(
455 DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
456
457 TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
458 TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
459 return Status::OK();
460 }
461
462 Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
463 CallFrameInterface* call_frame,
464 ExecutorsAndKeys* executors_and_keys,
465 RunMetadata* run_metadata) {
466 const uint64 start_time_usecs = Env::Default()->NowMicros();
467 string session_id_meta = strings::StrCat("SessionRun #id=", step_id, "#");
468 tracing::ScopedActivity activity(session_id_meta);
469
470 const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1);
471
472 std::unique_ptr<DebuggerStateInterface> debugger_state;
473 if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
474 TF_RETURN_IF_ERROR(
475 CreateDebuggerState(executors_and_keys->callable_options,
476 run_options.debug_options().global_step(), step_id,
477 executor_step_count, &debugger_state));
478 }
479
480 // Create a run state and start execution.
481 RunState run_state(step_id, &devices_);
482 run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
483 #ifndef __ANDROID__
484 // Set up for collectives if ExecutorsAndKeys declares a key.
485 if (executors_and_keys->collective_graph_key !=
486 BuildGraphOptions::kNoCollectiveGraphKey) {
487 if (run_options.experimental().collective_graph_key() !=
488 BuildGraphOptions::kNoCollectiveGraphKey) {
489 // If a collective_graph_key was specified in run_options, ensure that it
490 // matches what came out of GraphExecutionState::BuildGraph().
491 if (run_options.experimental().collective_graph_key() !=
492 executors_and_keys->collective_graph_key) {
493 return errors::Internal(
494 "collective_graph_key in RunOptions ",
495 run_options.experimental().collective_graph_key(),
496 " should match collective_graph_key from optimized graph ",
497 executors_and_keys->collective_graph_key);
498 }
499 }
500 if (!collective_executor_mgr_) {
501 std::unique_ptr<DeviceResolverInterface> drl(
502 new DeviceResolverLocal(device_mgr_.get()));
503 std::unique_ptr<ParamResolverInterface> cprl(
504 new CollectiveParamResolverLocal(options_.config, device_mgr_.get(),
505 drl.get(),
506 "/job:localhost/replica:0/task:0"));
507 collective_executor_mgr_.reset(new CollectiveExecutorMgr(
508 options_.config, device_mgr_.get(), std::move(drl), std::move(cprl)));
509 }
510 run_state.collective_executor.reset(new CollectiveExecutor::Handle(
511 collective_executor_mgr_->FindOrCreate(step_id), true /*inherit_ref*/));
512 }
513 #endif
514
515 // Start parallel Executors.
516 const size_t num_executors = executors_and_keys->items.size();
517 ExecutorBarrier* barrier = new ExecutorBarrier(
518 num_executors, run_state.rendez, [&run_state](const Status& ret) {
519 {
520 mutex_lock l(run_state.mu_);
521 run_state.status.Update(ret);
522 }
523 run_state.executors_done.Notify();
524 });
525
526 Executor::Args args;
527 args.step_id = step_id;
528 args.call_frame = call_frame;
529 args.rendezvous = run_state.rendez;
530 args.collective_executor =
531 (run_state.collective_executor ? run_state.collective_executor->get()
532 : nullptr);
533 CancellationManager step_cancellation_manager;
534 args.cancellation_manager = &step_cancellation_manager;
535 args.session_state = &session_state_;
536 args.session_handle = session_handle_;
537 args.tensor_store = &run_state.tensor_store;
538 args.step_container = &run_state.step_container;
539 args.sync_on_finish = sync_on_finish_;
540
541 const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);
542
543 bool update_cost_model = false;
544 if (options_.config.graph_options().build_cost_model() > 0) {
545 const int64 build_cost_model_every =
546 options_.config.graph_options().build_cost_model();
547 const int64 build_cost_model_after =
548 options_.config.graph_options().build_cost_model_after();
549 int64 measure_step_count = executor_step_count - build_cost_model_after;
550 if (measure_step_count >= 0) {
551 update_cost_model =
552 ((measure_step_count + 1) % build_cost_model_every == 0);
553 }
554 }
555 if (do_trace || update_cost_model ||
556 run_options.report_tensor_allocations_upon_oom()) {
557 run_state.collector.reset(
558 new StepStatsCollector(run_metadata->mutable_step_stats()));
559 args.stats_collector = run_state.collector.get();
560 }
561
562 std::unique_ptr<DeviceTracer> tracer;
563 if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
564 tracer = CreateDeviceTracer();
565 // tracer may be NULL on platforms without accelerators.
566 if (tracer) {
567 Status s = tracer->Start();
568 if (!s.ok()) {
569 run_state.executors_done.Notify();
570 delete barrier;
571 return s;
572 }
573 }
574 }
575
576 if (run_options.inter_op_thread_pool() < -1 ||
577 run_options.inter_op_thread_pool() >=
578 static_cast<int32>(thread_pools_.size())) {
579 run_state.executors_done.Notify();
580 delete barrier;
581 return errors::InvalidArgument("Invalid inter_op_thread_pool: ",
582 run_options.inter_op_thread_pool());
583 }
584
585 // Register this step with session's cancellation manager, so that
586 // `Session::Close()` will cancel the step.
587 const CancellationToken cancellation_token =
588 cancellation_manager_->get_cancellation_token();
589 const bool already_cancelled = !cancellation_manager_->RegisterCallback(
590 cancellation_token, [&step_cancellation_manager]() {
591 step_cancellation_manager.StartCancel();
592 });
593 if (already_cancelled) {
594 // NOTE(mrry): If we don't explicitly notify
595 // `run_state.executors_done`, the RunState destructor would
596 // block on this notification.
597 run_state.executors_done.Notify();
598 delete barrier;
599 return errors::Cancelled("Run call was cancelled");
600 }
601
602 thread::ThreadPool* pool =
603 run_options.inter_op_thread_pool() >= 0
604 ? thread_pools_[run_options.inter_op_thread_pool()].first
605 : nullptr;
606
607 if (pool == nullptr) {
608     // We allow using the caller thread only when a single executor is
609     // specified.
610 if (executors_and_keys->items.size() > 1) {
611 pool = thread_pools_[0].first;
612 } else {
613 VLOG(1) << "Executing Session::Run() synchronously!";
614 }
615 }
616
617 std::unique_ptr<RunHandler> handler;
618 if (ShouldUseRunHandlerPool(run_options) &&
619 run_options.experimental().use_run_handler_pool()) {
620     VLOG(1) << "Using RunHandler to schedule inter-op closures.";
621 handler = GetOrCreateRunHandlerPool(options_)->Get();
622 }
623 auto* handler_ptr = handler.get();
624
625 Executor::Args::Runner default_runner = nullptr;
626
627 if (pool == nullptr) {
628 default_runner = [](Executor::Args::Closure c) { c(); };
629 } else if (handler_ptr != nullptr) {
630 default_runner = [handler_ptr](Executor::Args::Closure c) {
631 handler_ptr->ScheduleInterOpClosure(std::move(c));
632 };
633 } else {
634 default_runner = [this, pool](Executor::Args::Closure c) {
635 SchedClosure(pool, std::move(c));
636 };
637 }
638
639 for (const auto& item : executors_and_keys->items) {
640 // TODO(azaks): support partial run.
641 // TODO(azaks): if the device picks its own threadpool, we need to assign
642     // fewer threads to the main compute pool by default.
643 thread::ThreadPool* device_thread_pool =
644 item.device->tensorflow_device_thread_pool();
645 // TODO(crk): Investigate usage of RunHandlerPool when using device specific
646 // thread pool(s).
647 if (!device_thread_pool) {
648 args.runner = default_runner;
649 } else {
650 args.runner = [this, device_thread_pool](Executor::Args::Closure c) {
651 SchedClosure(device_thread_pool, std::move(c));
652 };
653 }
654 item.executor->RunAsync(args, barrier->Get());
655 }
656
657 WaitForNotification(&run_state, &step_cancellation_manager,
658 run_options.timeout_in_ms() > 0
659 ? run_options.timeout_in_ms()
660 : operation_timeout_in_ms_);
661
662 if (!cancellation_manager_->DeregisterCallback(cancellation_token)) {
663 // The step has been cancelled: make sure we don't attempt to receive the
664 // outputs as this would make it block forever.
665 mutex_lock l(run_state.mu_);
666 run_state.status.Update(errors::Cancelled("Run call was cancelled"));
667 }
668
669 if (tracer) {
670 TF_RETURN_IF_ERROR(tracer->Stop());
671 TF_RETURN_IF_ERROR(tracer->Collect(run_state.collector.get()));
672 }
673
674 {
675 mutex_lock l(run_state.mu_);
676 TF_RETURN_IF_ERROR(run_state.status);
677 }
678
679 // Save the output tensors of this run we choose to keep.
680 if (!run_state.tensor_store.empty()) {
681 TF_RETURN_IF_ERROR(run_state.tensor_store.SaveTensors(
682 {executors_and_keys->callable_options.fetch().begin(),
683 executors_and_keys->callable_options.fetch().end()},
684 &session_state_));
685 }
686
687 if (run_state.collector) {
688 run_state.collector->Finalize();
689 }
690
691 // Build and return the cost model as instructed.
692 if (update_cost_model) {
693 // Build the cost model
694 std::unordered_map<string, const Graph*> device_to_graph;
695 for (const PerPartitionExecutorsAndLib& partition :
696 executors_and_keys->items) {
697 const Graph* graph = partition.graph;
698 const string device = partition.flib->device()->name();
699 device_to_graph[device] = graph;
700 }
701
702 mutex_lock l(executor_lock_);
703 run_state.collector->BuildCostModel(&cost_model_manager_, device_to_graph);
704
705     // Annotate stats onto the cost graph.
706 CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
707 for (const auto& item : executors_and_keys->items) {
708 TF_RETURN_IF_ERROR(
709 cost_model_manager_.AddToCostGraphDef(item.graph, cost_graph));
710 }
711 }
712
713 // If requested via RunOptions, output the partition graphs.
714 if (run_options.output_partition_graphs()) {
715 protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
716 run_metadata->mutable_partition_graphs();
717 for (const PerPartitionExecutorsAndLib& exec_and_lib :
718 executors_and_keys->items) {
719 GraphDef* partition_graph_def = partition_graph_defs->Add();
720 exec_and_lib.graph->ToGraphDef(partition_graph_def);
721 }
722 }
723 metrics::UpdateGraphExecTime(Env::Default()->NowMicros() - start_time_usecs);
724
725 return Status::OK();
726 }
727
728 Status DirectSession::Run(const RunOptions& run_options,
729 const NamedTensorList& inputs,
730 const std::vector<string>& output_names,
731 const std::vector<string>& target_nodes,
732 std::vector<Tensor>* outputs,
733 RunMetadata* run_metadata) {
734 TF_RETURN_IF_ERROR(CheckNotClosed());
735 TF_RETURN_IF_ERROR(CheckGraphCreated("Run()"));
736 direct_session_runs->GetCell()->IncrementBy(1);
737
738   // Extract the input names for this run of the session.
739 std::vector<string> input_tensor_names;
740 input_tensor_names.reserve(inputs.size());
741 for (const auto& it : inputs) {
742 input_tensor_names.push_back(it.first);
743 }
744
745 // Check if we already have an executor for these arguments.
746 ExecutorsAndKeys* executors_and_keys;
747 RunStateArgs run_state_args(run_options.debug_options());
748 run_state_args.collective_graph_key =
749 run_options.experimental().collective_graph_key();
750
751 TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
752 target_nodes, &executors_and_keys,
753 &run_state_args));
754 {
755 mutex_lock l(collective_graph_key_lock_);
756 collective_graph_key_ = executors_and_keys->collective_graph_key;
757 }
758
759 // Configure a call frame for the step, which we use to feed and
760 // fetch values to and from the executors.
761 FunctionCallFrame call_frame(executors_and_keys->input_types,
762 executors_and_keys->output_types);
763 gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
764 for (const auto& it : inputs) {
765 if (it.second.dtype() == DT_RESOURCE) {
766 Tensor tensor_from_handle;
767 TF_RETURN_IF_ERROR(
768 ResourceHandleToInputTensor(it.second, &tensor_from_handle));
769 feed_args[executors_and_keys->input_name_to_index[it.first]] =
770 tensor_from_handle;
771 } else {
772 feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
773 }
774 }
775 const Status s = call_frame.SetArgs(feed_args);
776 if (errors::IsInternal(s)) {
777 return errors::InvalidArgument(s.error_message());
778 } else if (!s.ok()) {
779 return s;
780 }
781
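  // step_id_counter_ is static, so step ids are unique across all
  // DirectSession instances in this process.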
782 const int64 step_id = step_id_counter_.fetch_add(1);
783
784 if (LogMemory::IsEnabled()) {
785 LogMemory::RecordStep(step_id, run_state_args.handle);
786 }
787
788 TF_RETURN_IF_ERROR(RunInternal(step_id, run_options, &call_frame,
789 executors_and_keys, run_metadata));
790
791 // Receive outputs.
792 if (outputs) {
793 std::vector<Tensor> sorted_outputs;
794 const Status s = call_frame.ConsumeRetvals(
795 &sorted_outputs, /* allow_dead_tensors = */ false);
796 if (errors::IsInternal(s)) {
797 return errors::InvalidArgument(s.error_message());
798 } else if (!s.ok()) {
799 return s;
800 }
801 const bool unique_outputs =
802 output_names.size() == executors_and_keys->output_name_to_index.size();
803 // first_indices[i] = j implies that j is the smallest value for which
804 // output_names[i] == output_names[j].
805 std::vector<int> first_indices;
806 if (!unique_outputs) {
807 first_indices.resize(output_names.size());
808 for (int i = 0; i < output_names.size(); ++i) {
809 for (int j = 0; j <= i; ++j) {
810 if (output_names[i] == output_names[j]) {
811 first_indices[i] = j;
812 break;
813 }
814 }
815 }
816 }
817 outputs->clear();
818 outputs->reserve(sorted_outputs.size());
819 for (int i = 0; i < output_names.size(); ++i) {
820 const string& output_name = output_names[i];
821 if (first_indices.empty() || first_indices[i] == i) {
822 outputs->emplace_back(
823 std::move(sorted_outputs[executors_and_keys
824 ->output_name_to_index[output_name]]));
825 } else {
826 outputs->push_back((*outputs)[first_indices[i]]);
827 }
828 }
829 }
830
831 return Status::OK();
832 }
833
834 Status DirectSession::PRunSetup(const std::vector<string>& input_names,
835 const std::vector<string>& output_names,
836 const std::vector<string>& target_nodes,
837 string* handle) {
838 TF_RETURN_IF_ERROR(CheckNotClosed());
839 TF_RETURN_IF_ERROR(CheckGraphCreated("PRunSetup()"));
840
841 // RunOptions is not available in PRunSetup, so use thread pool 0.
842 thread::ThreadPool* pool = thread_pools_[0].first;
843
844 // Check if we already have an executor for these arguments.
845 ExecutorsAndKeys* executors_and_keys;
846 // TODO(cais): TFDBG support for partial runs.
847 DebugOptions debug_options;
848 RunStateArgs run_state_args(debug_options);
849 run_state_args.is_partial_run = true;
850 TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_names, output_names,
851 target_nodes, &executors_and_keys,
852 &run_state_args));
853
854 // Create the run state and save it for future PRun calls.
855 Executor::Args args;
856 args.step_id = step_id_counter_.fetch_add(1);
857 RunState* run_state =
858 new RunState(input_names, output_names, args.step_id, &devices_);
859 run_state->rendez = new IntraProcessRendezvous(device_mgr_.get());
860 {
861 mutex_lock l(executor_lock_);
862 if (!partial_runs_
863 .emplace(run_state_args.handle,
864 std::unique_ptr<RunState>(run_state))
865 .second) {
866 return errors::Internal("The handle '", run_state_args.handle,
867 "' created for this partial run is not unique.");
868 }
869 }
870
871 // Start parallel Executors.
872 const size_t num_executors = executors_and_keys->items.size();
873 ExecutorBarrier* barrier = new ExecutorBarrier(
874 num_executors, run_state->rendez, [run_state](const Status& ret) {
875 if (!ret.ok()) {
876 mutex_lock l(run_state->mu_);
877 run_state->status.Update(ret);
878 }
879 run_state->executors_done.Notify();
880 });
881
882 args.rendezvous = run_state->rendez;
883 args.cancellation_manager = cancellation_manager_;
884 // Note that Collectives are not supported in partial runs
885   // because RunOptions is not passed in, so we can't know whether
886 // their use is intended.
887 args.collective_executor = nullptr;
888 args.runner = [this, pool](Executor::Args::Closure c) {
889 SchedClosure(pool, std::move(c));
890 };
891 args.session_state = &session_state_;
892 args.session_handle = session_handle_;
893 args.tensor_store = &run_state->tensor_store;
894 args.step_container = &run_state->step_container;
895 if (LogMemory::IsEnabled()) {
896 LogMemory::RecordStep(args.step_id, run_state_args.handle);
897 }
898 args.sync_on_finish = sync_on_finish_;
899
900 if (options_.config.graph_options().build_cost_model()) {
901 run_state->collector.reset(new StepStatsCollector(nullptr));
902 args.stats_collector = run_state->collector.get();
903 }
904
905 for (auto& item : executors_and_keys->items) {
906 item.executor->RunAsync(args, barrier->Get());
907 }
908
909 *handle = run_state_args.handle;
910 return Status::OK();
911 }
912
913 Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
914 const std::vector<string>& output_names,
915 std::vector<Tensor>* outputs) {
916 TF_RETURN_IF_ERROR(CheckNotClosed());
917 std::vector<string> parts = str_util::Split(handle, ';');
918 const string& key = parts[0];
919 // Get the executors for this partial run.
920 ExecutorsAndKeys* executors_and_keys;
921 RunState* run_state;
922 {
923 mutex_lock l(executor_lock_); // could use reader lock
924 auto exc_it = executors_.find(key);
925 if (exc_it == executors_.end()) {
926 return errors::InvalidArgument(
927 "Must run 'setup' before performing partial runs!");
928 }
929 executors_and_keys = exc_it->second.get();
930
931 auto prun_it = partial_runs_.find(handle);
932 if (prun_it == partial_runs_.end()) {
933 return errors::InvalidArgument(
934 "Must run 'setup' before performing partial runs!");
935 }
936 run_state = prun_it->second.get();
937
938 // Make sure that this is a new set of feeds that are still pending.
939 for (const auto& input : inputs) {
940 auto it = run_state->pending_inputs.find(input.first);
941 if (it == run_state->pending_inputs.end()) {
942 return errors::InvalidArgument(
943 "The feed ", input.first,
944 " was not specified in partial_run_setup.");
945 } else if (it->second) {
946 return errors::InvalidArgument("The feed ", input.first,
947 " has already been fed.");
948 }
949 }
950 // Check that this is a new set of fetches that are still pending.
951 for (const auto& output : output_names) {
952 auto it = run_state->pending_outputs.find(output);
953 if (it == run_state->pending_outputs.end()) {
954 return errors::InvalidArgument(
955 "The fetch ", output, " was not specified in partial_run_setup.");
956 } else if (it->second) {
957 return errors::InvalidArgument("The fetch ", output,
958 " has already been fetched.");
959 }
960 }
961 }
962
963 // Check that this new set of fetches can be computed from all the
964 // feeds we have supplied.
965 TF_RETURN_IF_ERROR(
966 CheckFetch(inputs, output_names, executors_and_keys, run_state));
967
968 // Send inputs.
969 Status s = SendPRunInputs(inputs, executors_and_keys, run_state->rendez);
970
971 // Receive outputs.
972 if (s.ok()) {
973 s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs);
974 }
975
976 // Save the output tensors of this run we choose to keep.
977 if (s.ok()) {
978 s = run_state->tensor_store.SaveTensors(output_names, &session_state_);
979 }
980
981 {
982 mutex_lock l(executor_lock_);
983 // Delete the run state if there is an error or all fetches are done.
984 bool done = true;
985 if (s.ok()) {
986 {
987 mutex_lock l(run_state->mu_);
988 if (!run_state->status.ok()) {
989 LOG(WARNING) << "An error unrelated to this prun has been detected. "
990 << run_state->status;
991 }
992 }
993 for (const auto& input : inputs) {
994 auto it = run_state->pending_inputs.find(input.first);
995 it->second = true;
996 }
997 for (const auto& name : output_names) {
998 auto it = run_state->pending_outputs.find(name);
999 it->second = true;
1000 }
1001 done = run_state->PendingDone();
1002 }
1003 if (done) {
1004 WaitForNotification(run_state, cancellation_manager_,
1005 operation_timeout_in_ms_);
1006 partial_runs_.erase(handle);
1007 }
1008 }
1009
1010 return s;
1011 }
1012
1013 Status DirectSession::ResourceHandleToInputTensor(const Tensor& resource_tensor,
1014 Tensor* retrieved_tensor) {
1015 if (resource_tensor.dtype() != DT_RESOURCE) {
1016 return errors::InvalidArgument(strings::StrCat(
1017 "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ",
1018 resource_tensor.dtype()));
1019 }
1020
1021 const ResourceHandle& resource_handle =
1022 resource_tensor.scalar<ResourceHandle>()();
1023
1024 if (resource_handle.container() ==
1025 SessionState::kTensorHandleResourceTypeName) {
1026 return session_state_.GetTensor(resource_handle.name(), retrieved_tensor);
1027 } else {
1028 return errors::InvalidArgument(strings::StrCat(
1029 "Invalid resource type hash code: ", resource_handle.hash_code(),
1030 "(name: ", resource_handle.name(),
1031 " type: ", resource_handle.maybe_type_name(),
1032 "). Perhaps a resource tensor was being provided as a feed? That is "
1033 "not currently allowed. Please file an issue at "
1034 "https://github.com/tensorflow/tensorflow/issues/new, ideally with a "
1035 "short code snippet that leads to this error message."));
1036 }
1037 }
1038
1039 Status DirectSession::SendPRunInputs(const NamedTensorList& inputs,
1040 const ExecutorsAndKeys* executors_and_keys,
1041 IntraProcessRendezvous* rendez) {
1042 Status s;
1043 Rendezvous::ParsedKey parsed;
1044 // Insert the input tensors into the local rendezvous by their
1045 // rendezvous key.
1046 for (const auto& input : inputs) {
1047 auto it =
1048 executors_and_keys->input_name_to_rendezvous_key.find(input.first);
1049 if (it == executors_and_keys->input_name_to_rendezvous_key.end()) {
1050 return errors::Internal("'", input.first, "' is not a pre-defined feed.");
1051 }
1052 const string& input_key = it->second;
1053
1054 s = Rendezvous::ParseKey(input_key, &parsed);
1055 if (!s.ok()) {
1056 rendez->StartAbort(s);
1057 return s;
1058 }
1059
1060 if (input.second.dtype() == DT_RESOURCE) {
1061 Tensor tensor_from_handle;
1062 s = ResourceHandleToInputTensor(input.second, &tensor_from_handle);
1063 if (s.ok()) {
1064 s = rendez->Send(parsed, Rendezvous::Args(), tensor_from_handle, false);
1065 }
1066 } else {
1067 s = rendez->Send(parsed, Rendezvous::Args(), input.second, false);
1068 }
1069
1070 if (!s.ok()) {
1071 rendez->StartAbort(s);
1072 return s;
1073 }
1074 }
1075 return Status::OK();
1076 }
1077
1078 Status DirectSession::RecvPRunOutputs(
1079 const std::vector<string>& output_names,
1080 const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
1081 std::vector<Tensor>* outputs) {
1082 Status s;
1083 if (!output_names.empty()) {
1084 outputs->resize(output_names.size());
1085 }
1086
1087 Rendezvous::ParsedKey parsed;
1088 // Get the outputs from the rendezvous
1089 for (size_t output_offset = 0; output_offset < output_names.size();
1090 ++output_offset) {
1091 const string& output_name = output_names[output_offset];
1092 auto it =
1093 executors_and_keys->output_name_to_rendezvous_key.find(output_name);
1094 if (it == executors_and_keys->output_name_to_rendezvous_key.end()) {
1095 return errors::Internal("'", output_name,
1096 "' is not a pre-defined fetch.");
1097 }
1098 const string& output_key = it->second;
1099 Tensor output_tensor;
1100 bool is_dead;
1101 IntraProcessRendezvous* rendez = run_state->rendez;
1102
1103 s = Rendezvous::ParseKey(output_key, &parsed);
1104 if (s.ok()) {
1105 // Fetch data from the Rendezvous.
1106 s = rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead,
1107 operation_timeout_in_ms_);
1108 if (is_dead && s.ok()) {
1109 s = errors::InvalidArgument("The tensor returned for ", output_name,
1110 " was not valid.");
1111 }
1112 }
1113 if (!s.ok()) {
1114 rendez->StartAbort(s);
1115 outputs->clear();
1116 return s;
1117 }
1118
1119 (*outputs)[output_offset] = output_tensor;
1120 }
1121 return Status::OK();
1122 }
1123
1124 Status DirectSession::CheckFetch(const NamedTensorList& feeds,
1125 const std::vector<string>& fetches,
1126 const ExecutorsAndKeys* executors_and_keys,
1127 const RunState* run_state) {
1128 const Graph* graph = executors_and_keys->graph.get();
1129 const NameNodeMap* name_to_node = &executors_and_keys->name_to_node;
1130
1131 // Build the set of pending feeds that we haven't seen.
1132 std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
1133 {
1134 mutex_lock l(executor_lock_);
1135 for (const auto& input : run_state->pending_inputs) {
1136 // Skip if the feed has already been fed.
1137 if (input.second) continue;
1138 TensorId id(ParseTensorName(input.first));
1139 auto it = name_to_node->find(id.first);
1140 if (it == name_to_node->end()) {
1141 return errors::NotFound("Feed ", input.first, ": not found");
1142 }
1143 pending_feeds.insert(id);
1144 }
1145 }
1146 for (const auto& it : feeds) {
1147 TensorId id(ParseTensorName(it.first));
1148 pending_feeds.erase(id);
1149 }
1150
1151 // Initialize the stack with the fetch nodes.
1152 std::vector<const Node*> stack;
1153 for (const string& fetch : fetches) {
1154 TensorId id(ParseTensorName(fetch));
1155 auto it = name_to_node->find(id.first);
1156 if (it == name_to_node->end()) {
1157 return errors::NotFound("Fetch ", fetch, ": not found");
1158 }
1159 stack.push_back(it->second);
1160 }
1161
1162 // Any tensor needed for fetches can't be in pending_feeds.
1163 std::vector<bool> visited(graph->num_node_ids(), false);
1164 while (!stack.empty()) {
1165 const Node* n = stack.back();
1166 stack.pop_back();
1167
1168 for (const Edge* in_edge : n->in_edges()) {
1169 const Node* in_node = in_edge->src();
1170 if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
1171 return errors::InvalidArgument("Fetch ", in_node->name(), ":",
1172 in_edge->src_output(),
1173 " can't be computed from the feeds"
1174 " that have been fed so far.");
1175 }
1176 if (!visited[in_node->id()]) {
1177 visited[in_node->id()] = true;
1178 stack.push_back(in_node);
1179 }
1180 }
1181 }
1182 return Status::OK();
1183 }
1184
1185 Status DirectSession::CreateExecutors(
1186 const CallableOptions& callable_options,
1187 std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys,
1188 std::unique_ptr<FunctionInfo>* out_func_info,
1189 RunStateArgs* run_state_args) {
1190 BuildGraphOptions options;
1191 options.callable_options = callable_options;
1192 options.use_function_convention = !run_state_args->is_partial_run;
1193 options.collective_graph_key =
1194 callable_options.run_options().experimental().collective_graph_key();
1195 if (options_.config.experimental()
1196 .collective_deterministic_sequential_execution()) {
1197 options.collective_order = GraphCollectiveOrder::kEdges;
1198 } else if (options_.config.experimental().collective_nccl()) {
1199 options.collective_order = GraphCollectiveOrder::kAttrs;
1200 }
1201
1202 std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
1203 std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
1204
1205 ek->callable_options = callable_options;
1206
1207 std::unordered_map<string, std::unique_ptr<Graph>> graphs;
1208 TF_RETURN_IF_ERROR(CreateGraphs(
1209 options, &graphs, &func_info->flib_def, run_state_args, &ek->input_types,
1210 &ek->output_types, &ek->collective_graph_key));
1211
1212 if (run_state_args->is_partial_run) {
1213 ek->graph = std::move(run_state_args->graph);
1214 std::unordered_set<StringPiece, StringPieceHasher> names;
1215 for (const string& input : callable_options.feed()) {
1216 TensorId id(ParseTensorName(input));
1217 names.emplace(id.first);
1218 }
1219 for (const string& output : callable_options.fetch()) {
1220 TensorId id(ParseTensorName(output));
1221 names.emplace(id.first);
1222 }
1223 for (Node* n : ek->graph->nodes()) {
1224 if (names.count(n->name()) > 0) {
1225 ek->name_to_node.insert({n->name(), n});
1226 }
1227 }
1228 }
1229 ek->items.reserve(graphs.size());
1230 const auto& optimizer_opts =
1231 options_.config.graph_options().optimizer_options();
1232
1233 int graph_def_version;
1234 {
1235 mutex_lock l(graph_state_lock_);
1236 graph_def_version =
1237 execution_state_->original_graph_def().versions().producer();
1238 }
1239 func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
1240 device_mgr_.get(), options_.env, graph_def_version,
1241 func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first));
1242
1243 GraphOptimizer optimizer(optimizer_opts);
1244 for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
1245 const string& partition_name = iter->first;
1246 std::unique_ptr<Graph>& partition_graph = iter->second;
1247
1248 Device* device;
1249 TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device));
1250
1251 ek->items.resize(ek->items.size() + 1);
1252 auto* item = &(ek->items.back());
1253 auto lib = func_info->proc_flr->GetFLR(partition_name);
1254 if (lib == nullptr) {
1255 return errors::Internal("Could not find device: ", partition_name);
1256 }
1257 item->flib = lib;
1258
1259 LocalExecutorParams params;
1260 params.device = device;
1261 params.function_library = lib;
1262 auto opseg = device->op_segment();
1263 params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
1264 OpKernel** kernel) {
1265 // NOTE(mrry): We must not share function kernels (implemented
1266 // using `CallOp`) between subgraphs, because `CallOp::handle_`
1267 // is tied to a particular subgraph. Even if the function itself
1268 // is stateful, the `CallOp` that invokes it is not.
1269 if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
1270 return lib->CreateKernel(ndef, kernel);
1271 }
1272 auto create_fn = [lib, &ndef](OpKernel** kernel) {
1273 return lib->CreateKernel(ndef, kernel);
1274 };
1275 // Kernels created for subgraph nodes need to be cached. On
1276 // cache miss, create_fn() is invoked to create a kernel based
1277 // on the function library here + global op registry.
1278 return opseg->FindOrCreate(session_handle_, ndef.name(), kernel,
1279 create_fn);
1280 };
1281 params.delete_kernel = [lib](OpKernel* kernel) {
1282 if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string()))
1283 delete kernel;
1284 };
1285
1286 optimizer.Optimize(lib, options_.env, device, &partition_graph,
1287 /*shape_map=*/nullptr);
1288
1289 // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
1290 const DebugOptions& debug_options =
1291 options.callable_options.run_options().debug_options();
1292 if (!debug_options.debug_tensor_watch_opts().empty()) {
1293 TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
1294 debug_options, partition_graph.get(), params.device));
1295 }
1296
1297 TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
1298 device->name(),
1299 partition_graph.get()));
1300 // NewLocalExecutor takes ownership of partition_graph.
1301 item->graph = partition_graph.get();
1302 item->executor = nullptr;
1303 item->device = device;
1304 auto executor_type = options_.config.experimental().executor_type();
1305 TF_RETURN_IF_ERROR(NewExecutor(
1306 executor_type, params, std::move(partition_graph), &item->executor));
1307 }
1308
1309 // Cache the mapping from input/output names to graph elements to
1310 // avoid recomputing it every time.
1311 if (!run_state_args->is_partial_run) {
1312 // For regular `Run()`, we use the function calling convention, and so
1313 // maintain a mapping from input/output names to
1314 // argument/return-value ordinal index.
1315 for (int i = 0; i < callable_options.feed().size(); ++i) {
1316 const string& input = callable_options.feed(i);
1317 ek->input_name_to_index[input] = i;
1318 }
1319 for (int i = 0; i < callable_options.fetch().size(); ++i) {
1320 const string& output = callable_options.fetch(i);
1321 ek->output_name_to_index[output] = i;
1322 }
1323 } else {
1324 // For `PRun()`, we use the rendezvous calling convention, and so
1325 // maintain a mapping from input/output names to rendezvous keys.
1326 //
1327 // We always use the first device as the device name portion of the
1328 // key, even if we're feeding another graph.
1329 for (int i = 0; i < callable_options.feed().size(); ++i) {
1330 const string& input = callable_options.feed(i);
1331 ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
1332 input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
1333 }
1334 for (int i = 0; i < callable_options.fetch().size(); ++i) {
1335 const string& output = callable_options.fetch(i);
1336 ek->output_name_to_rendezvous_key[output] =
1337 GetRendezvousKey(output, device_set_.client_device()->attributes(),
1338 FrameAndIter(0, 0));
1339 }
1340 }
1341
1342 *out_executors_and_keys = std::move(ek);
1343 *out_func_info = std::move(func_info);
1344 return Status::OK();
1345 }
1346
1347 Status DirectSession::GetOrCreateExecutors(
1348 gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
1349 gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys,
1350 RunStateArgs* run_state_args) {
1351 int64 handle_name_counter_value = -1;
1352 if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
1353 handle_name_counter_value = handle_name_counter_.fetch_add(1);
1354 }
1355
1356 string debug_tensor_watches_summary;
1357 if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
1358 debug_tensor_watches_summary = SummarizeDebugTensorWatches(
1359 run_state_args->debug_options.debug_tensor_watch_opts());
1360 }
1361
1362 // Fast lookup path, no sorting.
1363 const string key = strings::StrCat(
1364 str_util::Join(inputs, ","), "->", str_util::Join(outputs, ","), "/",
1365 str_util::Join(target_nodes, ","), "/", run_state_args->is_partial_run,
1366 "/", debug_tensor_watches_summary);
1367 // Set the handle, if it's needed to log memory or for partial run.
1368 if (handle_name_counter_value >= 0) {
1369 run_state_args->handle =
1370 strings::StrCat(key, ";", handle_name_counter_value);
1371 }
1372
1373 // See if we already have the executors for this run.
1374 {
1375 mutex_lock l(executor_lock_); // could use reader lock
1376 auto it = executors_.find(key);
1377 if (it != executors_.end()) {
1378 *executors_and_keys = it->second.get();
1379 return Status::OK();
1380 }
1381 }
1382
1383 // Slow lookup path, the unsorted key missed the cache.
1384 // Sort the inputs and outputs, and look up with the sorted key in case an
1385 // earlier call used a different order of inputs and outputs.
1386 //
1387 // We could consider some other signature instead of sorting that
1388 // preserves the same property to avoid the sort in the future.
1389 std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
1390 std::sort(inputs_sorted.begin(), inputs_sorted.end());
1391 std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
1392 std::sort(outputs_sorted.begin(), outputs_sorted.end());
1393 std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
1394 std::sort(tn_sorted.begin(), tn_sorted.end());
1395
1396 const string sorted_key = strings::StrCat(
1397 str_util::Join(inputs_sorted, ","), "->",
1398 str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","),
1399 "/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
1400   // Set the handle, if it's needed to log memory or for partial run.
1401 if (handle_name_counter_value >= 0) {
1402 run_state_args->handle =
1403 strings::StrCat(sorted_key, ";", handle_name_counter_value);
1404 }

  // See if we already have the executors for this run.
  {
    mutex_lock l(executor_lock_);
    auto it = executors_.find(sorted_key);
    if (it != executors_.end()) {
      *executors_and_keys = it->second.get();
      // Insert this under the original key.
      executors_.emplace(key, it->second);
      return Status::OK();
    }
  }

  // Nothing found, so create the executors and store in the cache.
  // The executor_lock_ is intentionally released while executors are
  // being created.
  CallableOptions callable_options;
  for (const string& input : inputs_sorted) {
    callable_options.add_feed(input);
  }
  for (const string& output : outputs_sorted) {
    callable_options.add_fetch(output);
  }
  for (const string& target : tn_sorted) {
    callable_options.add_target(target);
  }
  *callable_options.mutable_run_options()->mutable_debug_options() =
      run_state_args->debug_options;
  callable_options.mutable_run_options()
      ->mutable_experimental()
      ->set_collective_graph_key(run_state_args->collective_graph_key);
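  // At this point the Run() arguments have been lowered into CallableOptions,
  // so executor construction below goes through the same CreateExecutors()
  // path that MakeCallable() uses.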
  std::unique_ptr<ExecutorsAndKeys> ek;
  std::unique_ptr<FunctionInfo> func_info;
  TF_RETURN_IF_ERROR(
      CreateExecutors(callable_options, &ek, &func_info, run_state_args));

  // Reacquire the lock, try to insert into the map.
  mutex_lock l(executor_lock_);
  functions_.push_back(std::move(func_info));

  // Another thread may have created the entry before us, in which case we will
  // reuse the already created one.
  auto insert_result = executors_.emplace(
      sorted_key, std::shared_ptr<ExecutorsAndKeys>(std::move(ek)));
  // Insert the value under the original key, so the fast path lookup will work
  // if the user uses the same order of inputs, outputs, and targets again.
  executors_.emplace(key, insert_result.first->second);
  *executors_and_keys = insert_result.first->second.get();

  return Status::OK();
}

Status DirectSession::CreateGraphs(
    const BuildGraphOptions& subgraph_options,
    std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
    std::unique_ptr<FunctionLibraryDefinition>* flib_def,
    RunStateArgs* run_state_args, DataTypeVector* input_types,
    DataTypeVector* output_types, int64* collective_graph_key) {
  mutex_lock l(graph_state_lock_);
  std::unique_ptr<ClientGraph> client_graph;

  std::unique_ptr<GraphExecutionState> temp_exec_state_holder;
  GraphExecutionState* execution_state = nullptr;
  if (options_.config.graph_options().place_pruned_graph()) {
    // Because we are placing pruned graphs, we need to create a
    // new GraphExecutionState for every new unseen graph,
    // and then place it.
    GraphExecutionStateOptions prune_options;
    prune_options.device_set = &device_set_;
    prune_options.session_options = &options_;
    prune_options.stateful_placements = stateful_placements_;
    prune_options.session_handle = session_handle_;
    TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph(
        execution_state_->original_graph_def().library(), prune_options,
        execution_state_->original_graph_def(), subgraph_options,
        &temp_exec_state_holder, &client_graph));
    execution_state = temp_exec_state_holder.get();
  } else {
    execution_state = execution_state_.get();
    TF_RETURN_IF_ERROR(
        execution_state->BuildGraph(subgraph_options, &client_graph));
  }
  *collective_graph_key = client_graph->collective_graph_key;

  if (subgraph_options.callable_options.feed_size() !=
      client_graph->feed_types.size()) {
    return errors::Internal(
        "Graph pruning failed: requested number of feed endpoints = ",
        subgraph_options.callable_options.feed_size(),
        " versus number of pruned feed endpoints = ",
        client_graph->feed_types.size());
  }
  if (subgraph_options.callable_options.fetch_size() !=
      client_graph->fetch_types.size()) {
    return errors::Internal(
        "Graph pruning failed: requested number of fetch endpoints = ",
        subgraph_options.callable_options.fetch_size(),
        " versus number of pruned fetch endpoints = ",
        client_graph->fetch_types.size());
  }

  auto current_stateful_placements = execution_state->GetStatefulPlacements();
  // Update our current state based on the execution_state's
  // placements. If there are any mismatches for a node,
  // we should fail, as this should never happen.
  for (auto placement_pair : current_stateful_placements) {
    const string& node_name = placement_pair.first;
    const string& placement = placement_pair.second;
    auto iter = stateful_placements_.find(node_name);
    if (iter == stateful_placements_.end()) {
      stateful_placements_.insert(std::make_pair(node_name, placement));
    } else if (iter->second != placement) {
      return errors::Internal(
          "Stateful placement mismatch. "
          "Current assignment of ",
          node_name, " to ", iter->second, " does not match ", placement);
    }
  }

  stateful_placements_ = execution_state->GetStatefulPlacements();

  // Remember the graph in run state if this is a partial run.
  if (run_state_args->is_partial_run) {
    run_state_args->graph.reset(new Graph(flib_def_.get()));
    CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
  }

  // Partition the graph across devices.
  PartitionOptions popts;
  popts.node_to_loc = [](const Node* node) {
    return node->assigned_device_name();
  };
  popts.new_name = [this](const string& prefix) {
    return strings::StrCat(prefix, "/_", edge_name_counter_.fetch_add(1));
  };
  popts.get_incarnation = [](const string& name) {
    // The direct session does not have changing incarnation numbers.
    // Just return '1'.
    return 1;
  };
  popts.flib_def = &client_graph->graph.flib_def();
  popts.control_flow_added = false;

  std::unordered_map<string, GraphDef> partitions;
  TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions));
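  // `partitions` now maps each assigned device name to a GraphDef containing
  // only the nodes placed on that device; Partition() has already inserted the
  // Send/Recv node pairs needed to carry tensors across device boundaries.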

  std::vector<string> device_names;
  for (auto device : devices_) {
    // Extract the LocalName from the device.
    device_names.push_back(DeviceNameUtils::LocalName(device->name()));
  }

  // Check for valid partitions.
  for (const auto& partition : partitions) {
    const string local_partition_name =
        DeviceNameUtils::LocalName(partition.first);
    if (std::count(device_names.begin(), device_names.end(),
                   local_partition_name) == 0) {
      return errors::InvalidArgument(
          "Creating a partition for ", local_partition_name,
          " which doesn't exist in the list of available devices. Available "
          "devices: ",
          str_util::Join(device_names, ","));
    }
  }

  for (const auto& partition : partitions) {
    std::unique_ptr<Graph> device_graph(
        new Graph(client_graph->flib_def.get()));
    GraphConstructorOptions device_opts;
    // There are internal operations (e.g., send/recv) that we now allow.
    device_opts.allow_internal_ops = true;
    device_opts.expect_device_spec = true;
    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
                                              device_graph.get()));
    outputs->emplace(partition.first, std::move(device_graph));
  }

  GraphOptimizationPassOptions optimization_options;
  optimization_options.session_options = &options_;
  optimization_options.flib_def = client_graph->flib_def.get();
  optimization_options.partition_graphs = outputs;
  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_PARTITIONING, optimization_options));

  Status s;
  for (auto& partition : *outputs) {
    const string& partition_name = partition.first;
    std::unique_ptr<Graph>* graph = &partition.second;

    VLOG(2) << "Created " << DebugString(graph->get()) << " for "
            << partition_name;

    // Give the device an opportunity to rewrite its subgraph.
    Device* d;
    s = device_mgr_->LookupDevice(partition_name, &d);
    if (!s.ok()) break;
    s = d->MaybeRewriteGraph(graph);
    if (!s.ok()) {
      break;
    }
  }
  *flib_def = std::move(client_graph->flib_def);
  std::swap(*input_types, client_graph->feed_types);
  std::swap(*output_types, client_graph->fetch_types);
  return s;
}

::tensorflow::Status DirectSession::ListDevices(
    std::vector<DeviceAttributes>* response) {
  response->clear();
  response->reserve(devices_.size());
  for (Device* d : devices_) {
    const DeviceAttributes& attrs = d->attributes();
    response->emplace_back(attrs);
  }
  return ::tensorflow::Status::OK();
}

::tensorflow::Status DirectSession::Reset(
    const std::vector<string>& containers) {
  device_mgr_->ClearContainers(containers);
  return ::tensorflow::Status::OK();
}

::tensorflow::Status DirectSession::Close() {
  cancellation_manager_->StartCancel();
  {
    mutex_lock l(closed_lock_);
    if (closed_) return ::tensorflow::Status::OK();
    closed_ = true;
  }
  if (factory_ != nullptr) factory_->Deregister(this);
  return ::tensorflow::Status::OK();
}

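// The step_container cleanup callback below runs when the per-step resource
// container is torn down at the end of the step: it drops that container's
// resources on every device and cleans up any scoped allocators registered
// for the step.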
DirectSession::RunState::RunState(
    const std::vector<string>& pending_input_names,
    const std::vector<string>& pending_output_names, int64 step_id,
    const std::vector<Device*>* devices)
    : step_container(step_id, [devices, step_id](const string& name) {
        for (auto d : *devices) {
          if (!d->resource_manager()->Cleanup(name).ok()) {
            // Do nothing...
          }
          ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
          if (sam) sam->Cleanup(step_id);
        }
      }) {
  // Initially all the feeds and fetches are pending.
  for (auto& name : pending_input_names) {
    pending_inputs[name] = false;
  }
  for (auto& name : pending_output_names) {
    pending_outputs[name] = false;
  }
}

DirectSession::RunState::RunState(int64 step_id,
                                  const std::vector<Device*>* devices)
    : RunState({}, {}, step_id, devices) {}

DirectSession::RunState::~RunState() {
  if (rendez != nullptr) {
    if (!executors_done.HasBeenNotified()) {
      rendez->StartAbort(errors::Cancelled("PRun cancellation"));
      executors_done.WaitForNotification();
    }
    rendez->Unref();
  }
}

bool DirectSession::RunState::PendingDone() const {
  for (const auto& it : pending_inputs) {
    if (!it.second) return false;
  }
  for (const auto& it : pending_outputs) {
    if (!it.second) return false;
  }
  return true;
}

void DirectSession::WaitForNotification(RunState* run_state,
                                        CancellationManager* cm,
                                        int64 timeout_in_ms) {
  const Status status =
      WaitForNotification(&run_state->executors_done, timeout_in_ms);
  if (!status.ok()) {
    {
      mutex_lock l(run_state->mu_);
      run_state->status.Update(status);
    }
    cm->StartCancel();
    // We must wait for the executors to complete, because they have borrowed
    // references to `cm` and other per-step state. After this notification, it
    // is safe to clean up the step.
    run_state->executors_done.WaitForNotification();
  }
}

::tensorflow::Status DirectSession::WaitForNotification(
    Notification* notification, int64 timeout_in_ms) {
  if (timeout_in_ms > 0) {
    const int64 timeout_in_us = timeout_in_ms * 1000;
    const bool notified =
        WaitForNotificationWithTimeout(notification, timeout_in_us);
    if (!notified) {
      return Status(error::DEADLINE_EXCEEDED,
                    "Timed out waiting for notification");
    }
  } else {
    notification->WaitForNotification();
  }
  return Status::OK();
}

Status DirectSession::MakeCallable(const CallableOptions& callable_options,
                                   CallableHandle* out_handle) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  TF_RETURN_IF_ERROR(CheckGraphCreated("MakeCallable()"));

  std::unique_ptr<ExecutorsAndKeys> ek;
  std::unique_ptr<FunctionInfo> func_info;
  RunStateArgs run_state_args(callable_options.run_options().debug_options());
  TF_RETURN_IF_ERROR(
      CreateExecutors(callable_options, &ek, &func_info, &run_state_args));
  {
    mutex_lock l(callables_lock_);
    *out_handle = next_callable_handle_++;
    callables_[*out_handle] = {std::move(ek), std::move(func_info)};
  }
  return Status::OK();
}

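// Call frame used by RunCallable(): argument and return-value slots map by
// index directly onto the caller-provided feed and fetch vectors, so tensors
// are handed to the executors without an intermediate FunctionCallFrame.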
class DirectSession::RunCallableCallFrame : public CallFrameInterface {
 public:
  RunCallableCallFrame(DirectSession* session,
                       ExecutorsAndKeys* executors_and_keys,
                       const std::vector<Tensor>* feed_tensors,
                       std::vector<Tensor>* fetch_tensors)
      : session_(session),
        executors_and_keys_(executors_and_keys),
        feed_tensors_(feed_tensors),
        fetch_tensors_(fetch_tensors) {}

  size_t num_args() const override {
    return executors_and_keys_->input_types.size();
  }
  size_t num_retvals() const override {
    return executors_and_keys_->output_types.size();
  }

  Status GetArg(int index, Tensor* val) const override {
    if (index >= feed_tensors_->size()) {
      return errors::Internal("Args index out of bounds: ", index);
    } else if (executors_and_keys_->input_types[index] == DT_RESOURCE) {
      TF_RETURN_IF_ERROR(
          session_->ResourceHandleToInputTensor((*feed_tensors_)[index], val));
    } else {
      *val = (*feed_tensors_)[index];
    }
    return Status::OK();
  }

  Status SetRetval(int index, const Tensor& val) override {
    if (index >= fetch_tensors_->size()) {
      return errors::Internal("RetVal index out of bounds: ", index);
    }
    (*fetch_tensors_)[index] = val;
    return Status::OK();
  }

 private:
  DirectSession* const session_;                   // Not owned.
  ExecutorsAndKeys* const executors_and_keys_;     // Not owned.
  const std::vector<Tensor>* const feed_tensors_;  // Not owned.
  std::vector<Tensor>* const fetch_tensors_;       // Not owned.
};

::tensorflow::Status DirectSession::RunCallable(
    CallableHandle handle, const std::vector<Tensor>& feed_tensors,
    std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  TF_RETURN_IF_ERROR(CheckGraphCreated("RunCallable()"));
  direct_session_runs->GetCell()->IncrementBy(1);

  // Check if we already have an executor for these arguments.
  std::shared_ptr<ExecutorsAndKeys> executors_and_keys;
  const int64 step_id = step_id_counter_.fetch_add(1);

  {
    tf_shared_lock l(callables_lock_);
    if (handle >= next_callable_handle_) {
      return errors::InvalidArgument("No such callable handle: ", handle);
    }
    executors_and_keys = callables_[handle].executors_and_keys;
  }
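  // NOTE: `executors_and_keys` above is a shared_ptr copied out of the
  // callables_ map, so the executors stay alive for the duration of this call
  // even if another thread releases the same handle concurrently.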

  if (!executors_and_keys) {
    return errors::InvalidArgument(
        "Attempted to run callable after handle was released: ", handle);
  }

  // NOTE(mrry): Debug options are not currently supported in the
  // callable interface.
  DebugOptions debug_options;
  RunStateArgs run_state_args(debug_options);

  // Configure a call frame for the step, which we use to feed and
  // fetch values to and from the executors.
  if (feed_tensors.size() != executors_and_keys->input_types.size()) {
    return errors::InvalidArgument(
        "Expected ", executors_and_keys->input_types.size(),
        " feed tensors, but got ", feed_tensors.size());
  }
  if (fetch_tensors != nullptr) {
    fetch_tensors->resize(executors_and_keys->output_types.size());
  } else if (!executors_and_keys->output_types.empty()) {
    return errors::InvalidArgument(
        "`fetch_tensors` must be provided when the callable has one or more "
        "outputs.");
  }

  // A specialized CallFrame implementation that takes advantage of the
  // optimized RunCallable interface.
  RunCallableCallFrame call_frame(this, executors_and_keys.get(), &feed_tensors,
                                  fetch_tensors);

  if (LogMemory::IsEnabled()) {
    LogMemory::RecordStep(step_id, run_state_args.handle);
  }

  TF_RETURN_IF_ERROR(
      RunInternal(step_id, executors_and_keys->callable_options.run_options(),
                  &call_frame, executors_and_keys.get(), run_metadata));

  return Status::OK();
}

::tensorflow::Status DirectSession::ReleaseCallable(CallableHandle handle) {
  mutex_lock l(callables_lock_);
  if (handle >= next_callable_handle_) {
    return errors::InvalidArgument("No such callable handle: ", handle);
  }
  callables_.erase(handle);
  return Status::OK();
}

DirectSession::Callable::~Callable() {
  // We must delete the fields in this order, because the destructor
  // of `executors_and_keys` will call into an object owned by
  // `function_info` (in particular, when deleting a kernel, it relies
  // on the `FunctionLibraryRuntime` to know if the kernel is stateful
  // or not).
  executors_and_keys.reset();
  function_info.reset();
}

}  // namespace tensorflow