1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/executor.h"
17
18 #include <atomic>
19 #include <memory>
20 #include <vector>
21
22 #include "absl/memory/memory.h"
23 #include "tensorflow/core/common_runtime/costmodel_manager.h"
24 #include "tensorflow/core/common_runtime/entry.h"
25 #include "tensorflow/core/common_runtime/executor_factory.h"
26 #include "tensorflow/core/common_runtime/graph_view.h"
27 #include "tensorflow/core/common_runtime/immutable_executor_state.h"
28 #include "tensorflow/core/common_runtime/pending_counts.h"
29 #include "tensorflow/core/common_runtime/propagator_state.h"
30 #include "tensorflow/core/common_runtime/renamed_device.h"
31 #include "tensorflow/core/common_runtime/simple_propagator_state.h"
32 #include "tensorflow/core/common_runtime/step_stats_collector.h"
33 #include "tensorflow/core/framework/allocator.h"
34 #include "tensorflow/core/framework/cancellation.h"
35 #include "tensorflow/core/framework/collective.h"
36 #include "tensorflow/core/framework/control_flow.h"
37 #include "tensorflow/core/framework/device_attributes.pb.h"
38 #include "tensorflow/core/framework/log_memory.h"
39 #include "tensorflow/core/framework/metrics.h"
40 #include "tensorflow/core/framework/node_def_util.h"
41 #include "tensorflow/core/framework/op_kernel.h"
42 #include "tensorflow/core/framework/op_segment.h"
43 #include "tensorflow/core/framework/tensor.h"
44 #include "tensorflow/core/framework/tensor_reference.h"
45 #include "tensorflow/core/framework/types.h"
46 #include "tensorflow/core/framework/types.pb.h"
47 #include "tensorflow/core/graph/edgeset.h"
48 #include "tensorflow/core/graph/graph.h"
49 #include "tensorflow/core/graph/graph_node_util.h"
50 #include "tensorflow/core/lib/core/errors.h"
51 #include "tensorflow/core/lib/core/notification.h"
52 #include "tensorflow/core/lib/core/status.h"
53 #include "tensorflow/core/lib/core/threadpool.h"
54 #include "tensorflow/core/lib/gtl/flatmap.h"
55 #include "tensorflow/core/lib/gtl/inlined_vector.h"
56 #include "tensorflow/core/lib/gtl/manual_constructor.h"
57 #include "tensorflow/core/lib/hash/hash.h"
58 #include "tensorflow/core/platform/context.h"
59 #include "tensorflow/core/platform/env.h"
60 #include "tensorflow/core/platform/errors.h"
61 #include "tensorflow/core/platform/logging.h"
62 #include "tensorflow/core/platform/macros.h"
63 #include "tensorflow/core/platform/mutex.h"
64 #include "tensorflow/core/platform/profile_utils/cpu_utils.h"
65 #include "tensorflow/core/platform/thread_annotations.h"
66 #include "tensorflow/core/platform/tracing.h"
67 #include "tensorflow/core/platform/types.h"
68 #include "tensorflow/core/profiler/lib/annotated_traceme.h"
69 #include "tensorflow/core/profiler/lib/connected_traceme.h"
70 #include "tensorflow/core/profiler/lib/scoped_annotation.h"
71 #include "tensorflow/core/profiler/lib/traceme_encode.h"
72 #include "tensorflow/core/protobuf/error_codes.pb.h"
73 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
74
75 namespace tensorflow {
76
77 namespace {
78
79 // 1-D, 0 element tensor.
80 static const Tensor* const kEmptyTensor = new Tensor;
81
82 // Helper routines for collecting step stats.
83 namespace nodestats {
84 inline int64 NowInNsec() { return EnvTime::NowNanos(); }
85
86 void SetScheduled(NodeExecStatsInterface* stats, int64_t micros) {
87 if (!stats) return;
88 stats->SetScheduled(micros * EnvTime::kMicrosToNanos);
89 }
90
91 void SetAllStart(NodeExecStatsInterface* stats) {
92 if (!stats) return;
93 stats->RecordExecutorStarted();
94 }
95
96 void SetOpStart(NodeExecStatsInterface* stats) {
97 if (!stats) return;
98 stats->RecordComputeStarted();
99 }
100
101 void SetOpEnd(NodeExecStatsInterface* stats) {
102 if (!stats) return;
103 stats->RecordComputeEnded();
104 }
105
106 void SetAllEnd(NodeExecStatsInterface* stats) {
107 if (!stats) return;
108 stats->RecordExecutorEnded();
109 }
110
111 void SetOutput(NodeExecStatsInterface* stats, int slot, const Tensor* v) {
112 if (!stats) return;
113 stats->SetOutput(slot, v);
114 }
115
116 void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) {
117 if (!stats) return;
118 stats->SetMemory(ctx);
119 }
120
121 } // namespace nodestats
122
123 // Time the execution of kernels (in CPU cycles). Used to dynamically identify
124 // inexpensive kernels which can be dispatched inline.
125 struct KernelTimer {
126 uint64 start_cycles = profile_utils::CpuUtils::GetCurrentClockCycle();
127
128 uint64 ElapsedCycles() {
129 return profile_utils::CpuUtils::GetCurrentClockCycle() - start_cycles;
130 }
131 };
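// A minimal usage sketch (mirroring how ProcessSync() below feeds measured
// cycles back into KernelStats; this is illustrative only, not additional
// executor logic):
//
//   KernelTimer timer;
//   device->Compute(op_kernel, &ctx);
//   kernel_stats_->UpdateCostEstimate(item, timer.ElapsedCycles());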
132
133 // TODO(b/152925936): Re-evaluate these constants with current usage patterns.
134 typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
135 typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
136
137 class ExecutorImpl : public Executor {
138 public:
139 explicit ExecutorImpl(const LocalExecutorParams& p) : immutable_state_(p) {}
140
141 Status Initialize(const Graph& graph) {
142 TF_RETURN_IF_ERROR(immutable_state_.Initialize(graph));
143 kernel_stats_.Initialize(immutable_state_.graph_view());
144 return Status::OK();
145 }
146
147 void RunAsync(const Args& args, DoneCallback done) override;
148
149 private:
150 template <class PropagatorStateType>
151 friend class ExecutorState;
152
153 // Stores execution time information about the kernels in an executor's graph.
154 class KernelStats {
155 public:
156 KernelStats() = default;
157
158 void Initialize(const GraphView& gview) {
159 is_expensive_.resize(gview.num_nodes());
160 cost_estimates_ =
161 absl::make_unique<std::atomic_uint_fast64_t[]>(gview.num_nodes());
162 for (int32_t i = 0; i < gview.num_nodes(); ++i) {
163 if (gview.node(i)) {
164 is_expensive_[i] =
165 gview.node(i)->kernel && gview.node(i)->kernel->IsExpensive();
166 cost_estimates_[i] = kInitialCostEstimateCycles;
167 }
168 }
169 }
170
171 // Returns true iff the given node is considered "expensive". The
172 // executor uses this flag to optimize graph execution, for example
173 // by "inlining" inexpensive kernels.
174 bool IsExpensive(const NodeItem& node) const {
175 return is_expensive_[node.node_id] &&
176 (cost_estimates_[node.node_id].load(std::memory_order_relaxed) >
177 kOpIsExpensiveThresholdCycles);
178 }
179
180 // Returns the value of kernel->IsExpensive().
181 bool HasExpensiveMarker(const NodeItem& node) const {
182 return is_expensive_[node.node_id];
183 }
184
185 // Updates the dynamic cost estimate, which is used to determine whether the
186 // given node is expensive. The new cost estimate is a weighted average of
187 // the old cost estimate and the latest cost. We only update cost estimates
188 // for kernels for which IsExpensive() returns true.
189 void UpdateCostEstimate(const NodeItem& node, uint64 elapsed_cycles) {
190 // N.B. Updates to `cost_estimate` are atomic but unlocked. Simultaneous
191 // updates may result in one or more updates being ignored. This does not
192 // affect correctness but may slow down the update frequency.
193 std::atomic_uint_fast64_t& cost_estimate = cost_estimates_[node.node_id];
194 auto prev_estimate = cost_estimate.load(std::memory_order_relaxed);
195
196 uint64 new_estimate =
197 ((kCostDecay - 1) * prev_estimate + elapsed_cycles) / kCostDecay;
198
199 cost_estimate.store(new_estimate, std::memory_order_relaxed);
200 }
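// Worked example: with kCostDecay = 10, a previous estimate of 9000 cycles
// and a measured 19000 cycles give ((10 - 1) * 9000 + 19000) / 10 = 10000
// cycles, i.e. each sample moves the estimate by ~1/10 of the difference
// (an exponential moving average).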
201
202 private:
203 // Initial time (in CPU cycles) we expect an operation to take. Used to
204 // determine whether an operation should be placed in a threadpool.
205 // Operations start out "expensive".
206 static constexpr uint64 kInitialCostEstimateCycles = 100 * 1000 * 1000;
207 static constexpr uint64 kOpIsExpensiveThresholdCycles = 8000;
208 static constexpr uint64 kCostDecay = 10;
209
210 std::vector<bool> is_expensive_;
211 // std::unique_ptr<std::atomic<bool>[]> is_expensive_;
212 std::unique_ptr<std::atomic_uint_fast64_t[]> cost_estimates_;
213 };
214
215 ImmutableExecutorState immutable_state_;
216 KernelStats kernel_stats_;
217
218 TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl);
219 };
220
221 // The state associated with one invocation of ExecutorImpl::Run.
222 //
223 // ExecutorState dispatches nodes when they become ready, and delegates to an
224 // instance of `PropagatorStateType` to keep track of how many predecessors of a
225 // node are still pending.
226 //
227 // The template argument `class PropagatorStateType` must define the following
228 // public members:
229 // * A type `TaggedNode`, representing a node to be processed, with public
230 // members:
231 // * `const NodeItem& get_node_item() const`
232 // * `bool get_is_dead() const`
233 // * A type `TaggedNodeReadyQueue`, representing a queue of nodes to be
234 // processed, with public members (having the same meanings as in an
235 // `std::vector<TaggedNode>`):
236 // * `void push_back(const TaggedNode& node)`
237 // * `TaggedNode front() const`
238 // * `void pop_front()`
239 // * `bool empty() const`
240 // * A type `TaggedNodeSeq`, representing a list of nodes to be scheduled, with
241 // public members (having the same meanings as in an
242 // `std::vector<TaggedNode>`):
243 // * `size_t size() const`
244 // * `bool empty() const`
245 // * `void clear()`
246 // * `const_iterator begin() const`
247 // * `const_iterator end() const`
248 // * A public constructor, `PropagatorStateType(const ImmutableExecutorState&
249 // immutable_state, int64 step_id)`.
250 // * The following public methods:
251 // * `void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
252 // TaggedNodeSeq* ready)`, which creates `TaggedNode` instances for the
253 // nodes in `roots` and adds them to `*ready`
254 // * `void PropagateOutputs(const TaggedNode& tagged_node, EntryVector*
255 // outputs, TaggedNodeSeq* ready)`, which propagates `outputs` from the
256 // given `tagged_node` to the destinations of its output edges, and adds
257 // any newly runnable nodes to `*ready`
258 // * `Entry* GetInputTensors(const TaggedNode& tagged_node) const`, which
259 // returns a pointer to the input tensors for the given `tagged_node`
260 // * `FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const`,
261 // which creates a `FrameAndIter` for the given `tagged_node`
262 // * `void DumpState()`, which dumps the dynamic state of the executing graph
263 // * `void MaybeMarkStarted(const TaggedNode& tagged_node)`, which records
264 // that a node has started
265 // * `void MaybeMarkCompleted(const TaggedNode& tagged_node)`, which records
266 // that a node has completed
267 //
268 // See `PropagatorState` in "./propagator_state.h" for an example of a type that
269 // can be used to instantiate `PropagatorStateType`.
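//
// At a high level, `ExecutorState` drives its `PropagatorStateType` roughly as
// follows (a simplified sketch; see `RunAsync()` and `Process()` below for the
// actual control flow, error handling, and scheduling):
//
//   TaggedNodeSeq ready;
//   propagator_.ActivateRoots(immutable_state_.root_nodes(), &ready);
//   while (/* nodes remain in `ready` */) {
//     TaggedNode tagged_node = /* next node from `ready` */;
//     Entry* first_input = propagator_.GetInputTensors(tagged_node);
//     // ... run the kernel, producing `outputs` ...
//     propagator_.PropagateOutputs(tagged_node, &outputs, &ready);
//   }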
270 template <class PropagatorStateType>
271 class ExecutorState {
272 public:
273 ExecutorState(const Executor::Args& args,
274 const ImmutableExecutorState& immutable_state_,
275 ExecutorImpl::KernelStats* kernel_stats_);
276 ~ExecutorState();
277
278 void RunAsync(Executor::DoneCallback done);
279
280 private:
281 // Use `TaggedNode` types defined by `PropagatorStateType`.
282 typedef typename PropagatorStateType::TaggedNode TaggedNode;
283 typedef
284 typename PropagatorStateType::TaggedNodeReadyQueue TaggedNodeReadyQueue;
285 typedef typename PropagatorStateType::TaggedNodeSeq TaggedNodeSeq;
286
287 struct AsyncState;
288
289 // Process a ready node in current thread.
290 void Process(TaggedNode node, int64_t scheduled_nsec);
291
292 Status ProcessSync(const NodeItem& item, OpKernelContext::Params* params,
293 EntryVector* outputs, NodeExecStatsInterface* stats);
294 void ProcessAsync(const NodeItem& item, const OpKernelContext::Params& params,
295 const TaggedNode& tagged_node, Entry* first_input,
296 NodeExecStatsInterface* stats);
297 void ProcessNoop(NodeExecStatsInterface* stats);
298 void ProcessConstTensor(const NodeItem& item, EntryVector* outputs,
299 NodeExecStatsInterface* stats);
300
301 // Before invoking item->kernel, fills in its "inputs".
302 Status PrepareInputs(const NodeItem& item, Entry* first_input,
303 TensorValueVec* inputs,
304 AllocatorAttributeVec* input_alloc_attrs,
305 bool* is_input_dead);
306
307 // After item->kernel computation is done, processes its outputs.
308 Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
309 Entry* outputs, NodeExecStatsInterface* stats);
310
311 // Called after each node finishes. Takes ownership of "stats". Returns true
312 // if execution has completed.
313 //
314 // This method will clear `*ready` before returning.
315 bool NodeDone(const Status& s, TaggedNodeSeq* ready,
316 NodeExecStatsInterface* stats,
317 TaggedNodeReadyQueue* inline_ready);
318
319 // Schedule all the expensive nodes in '*ready', and put all the inexpensive
320 // nodes in 'ready' into 'inline_ready'.
321 //
322 // This method will clear `*ready` before returning.
323 //
324 // REQUIRES: `!ready->empty()`.
325 void ScheduleReady(TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready);
326
327 // A wrapper for runner_ to keep track of the pending queue length. Op
328 // execution should dispatch work using this function instead of using runner_
329 // directly.
330 template <typename Closure>
331 void RunTask(Closure&& c);
332
333 // Clean up when this executor is done.
334 void Finish();
335 void ScheduleFinish();
336
337 // Contains the device context assigned by the device at the beginning of a
338 // step.
339 DeviceContext* device_context_ = nullptr;
340
341 const bool vlog_; // true if VLOG_IS_ON(1). Used to check vlog cheaply.
342
343 // true if LogMemory::IsEnabled(). Used to check memory enabled cheaply.
344 const bool log_memory_;
345
346 int64 step_id_;
347 int64 start_time_usecs_ = 0;
348
349 // Not owned.
350 RendezvousInterface* rendezvous_;
351 CollectiveExecutor* collective_executor_ = nullptr;
352 SessionState* session_state_;
353 string session_handle_;
354 const SessionMetadata* session_metadata_ = nullptr;
355 TensorStore* tensor_store_;
356 // Step-local container.
357 ScopedStepContainer* step_container_;
358 StepStatsCollectorInterface* const stats_collector_;
359 const tracing::EventCollector* const event_collector_;
360 Context context_;
361
362 // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
363 // instead of a pointer? (avoids having to delete).
364 checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
365 CallFrameInterface* call_frame_;
366 const ImmutableExecutorState& immutable_state_;
367 ExecutorImpl::KernelStats* const kernel_stats_;
368 CancellationManager* cancellation_manager_;
369 CoordinationServiceAgent* coordination_service_agent_;
370 // If not null, use this device to schedule intra-op operations.
371 std::unique_ptr<DeviceBase> user_device_;
372 Executor::Args::Runner runner_;
373 bool sync_on_finish_;
374 const bool run_all_kernels_inline_;
375
376 PropagatorStateType propagator_;
377
378 // Invoked when the execution finishes.
379 Executor::DoneCallback done_cb_;
380
381 std::atomic_int_fast32_t num_outstanding_ops_;
382
383 // Available via OpKernelContext to every OpKernel invocation.
384 mutex num_deferred_ops_mu_;
385 int64 num_deferred_ops_ TF_GUARDED_BY(num_deferred_ops_mu_) = 0;
386 bool finish_when_deferred_ops_done_ TF_GUARDED_BY(num_deferred_ops_mu_) =
387 false;
388
389 mutex mu_;
390 Status status_ TF_GUARDED_BY(mu_);
391 };
392
393 template <class PropagatorStateType>
394 ExecutorState<PropagatorStateType>::ExecutorState(
395 const Executor::Args& args, const ImmutableExecutorState& immutable_state,
396 ExecutorImpl::KernelStats* kernel_stats)
397 : vlog_(VLOG_IS_ON(1)),
398 log_memory_(LogMemory::IsEnabled()),
399 step_id_(args.step_id),
400 start_time_usecs_(args.start_time_usecs),
401 rendezvous_(args.rendezvous),
402 collective_executor_(args.collective_executor),
403 session_state_(args.session_state),
404 session_handle_(args.session_handle),
405 session_metadata_(immutable_state.params().session_metadata),
406 tensor_store_(args.tensor_store),
407 step_container_(args.step_container),
408 stats_collector_(args.stats_collector),
409 event_collector_(
410 tracing::GetEventCollector(tracing::EventCategory::kCompute)),
411 context_(ContextKind::kThread),
412 slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
413 call_frame_(args.call_frame),
414 immutable_state_(immutable_state),
415 kernel_stats_(kernel_stats),
416 cancellation_manager_(args.cancellation_manager),
417 coordination_service_agent_(args.coordination_service_agent),
418 runner_(args.runner),
419 sync_on_finish_(args.sync_on_finish),
420 run_all_kernels_inline_(args.run_all_kernels_inline),
421 propagator_(immutable_state, step_id_, vlog_),
422 num_outstanding_ops_(0) {
423 if (args.user_intra_op_threadpool != nullptr) {
424 Device* device = immutable_state_.params().device;
425 user_device_ = RenamedDevice::NewRenamedDevice(
426 device->name(), device, false, false, args.user_intra_op_threadpool);
427 }
428 }
429
430 template <class PropagatorStateType>
431 ExecutorState<PropagatorStateType>::~ExecutorState() {
432 if (device_context_) {
433 device_context_->Unref();
434 }
435 delete slice_reader_cache_;
436 }
437
438 template <class PropagatorStateType>
439 template <typename Closure>
440 void ExecutorState<PropagatorStateType>::RunTask(Closure&& c) {
441 // Align the atomic variables at 64 bytes to avoid false-sharing, assuming the
442 // cacheline size is 64 bytes or smaller.
443 alignas(64) static std::atomic<int64_t> num_enqueue_ops{0};
444 alignas(64) static std::atomic<int64_t> num_dequeue_ops{0};
445
446 auto n_enqueues = num_enqueue_ops.fetch_add(1, std::memory_order_relaxed);
447 // Sample the queue length once every 16 enqueue operations. This amortizes the
448 // cost of metric updates across 16 operations.
449 if (n_enqueues % 16 == 0) {
450 auto n_dequeues = num_dequeue_ops.load(std::memory_order_relaxed);
451 metrics::UpdateGraphPendingQueueLength(n_enqueues - n_dequeues);
452 }
453
454 // mutable is needed because std::forward<Closure> in the lambda body may move
455 // the Closure `c`.
456 runner_([c = std::forward<Closure>(c)]() mutable {
457 num_dequeue_ops.fetch_add(1, std::memory_order_relaxed);
458 std::forward<Closure>(c)();
459 });
460 }
461
462 template <class PropagatorStateType>
463 void ExecutorState<PropagatorStateType>::RunAsync(Executor::DoneCallback done) {
464 TaggedNodeSeq ready;
465
466 // Ask the device to fill in the device context map.
467 Device* device = immutable_state_.params().device;
468 const Status get_context_status =
469 device->TryGetDeviceContext(&device_context_);
470 if (!get_context_status.ok()) {
471 delete this;
472 done(get_context_status);
473 return;
474 }
475
476 // Initialize the ready queue.
477 ready.reserve(immutable_state_.root_nodes().size());
478 propagator_.ActivateRoots(immutable_state_.root_nodes(), &ready);
479 num_outstanding_ops_ = ready.size();
480 if (ready.empty()) {
481 delete this;
482 done(Status::OK());
483 } else {
484 done_cb_ = std::move(done);
485 // Schedule to run all the ready ops in thread pool.
486 ScheduleReady(&ready, nullptr);
487 }
488 }
489
490 // State kept alive for executing an asynchronous node in another
491 // thread. NOTE: We need to make a copy of p.inputs and p.input_alloc_attrs for
492 // asynchronous kernels because OpKernelContext methods like input_type(i) need
493 // the params to point to valid input type vectors. It's not an issue for
494 // sync kernels because these vectors are kept on the stack.
495 template <class PropagatorStateType>
496 struct ExecutorState<PropagatorStateType>::AsyncState {
497 AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node,
498 const NodeItem* _item, Entry* _first_input,
499 NodeExecStatsInterface* _stats)
500 : saved_inputs(*p.inputs),
501 saved_input_alloc_attrs(*p.input_alloc_attrs),
502 params(p),
503 tagged_node(_tagged_node),
504 item(_item),
505 first_input(_first_input),
506 // ParamsButClearingEigenGPUDevice does the equivalent of
507 // params.eigen_gpu_device = nullptr;
508 ctx(ParamsButClearingEigenGPUDevice(&params), item->num_outputs),
509 stats(_stats) {
510 params.inputs = &saved_inputs;
511 params.input_alloc_attrs = &saved_input_alloc_attrs;
512 }
513
514 TensorValueVec saved_inputs;
515 AllocatorAttributeVec saved_input_alloc_attrs;
516 OpKernelContext::Params params;
517 TaggedNode tagged_node;
518 const NodeItem* item;
519 Entry* first_input;
520 OpKernelContext ctx;
521 NodeExecStatsInterface* stats;
522
523 private:
524 OpKernelContext::Params* ParamsButClearingEigenGPUDevice(
525 OpKernelContext::Params* p) {
526 // Ensure OpKernelContext constructor will make a new eigen GPU device if
527 // necessary.
528 p->eigen_gpu_device = nullptr; // Force allocation
529 return p;
530 }
531 };
532
533 // Returns true if `item` might be traced by the given trace and event
534 // collectors. Returns false only if `item` definitely will not be traced.
535 bool MightTrace(const tracing::EventCollector* event_collector,
536 bool is_expensive) {
537 // Tracing will only be enabled if either `event_collector` is non-null,
538 // or `trace_collector` is non-null and enabled for this particular kernel.
539 // Although `profiler::TraceMe`, `profiler::ScopedAnnotation`, and
540 // `tracing::ScopedRegion` check subsets of these properties internally in
541 // their constructors, the cost of passing the necessary arguments to them can
542 // be significant, so we avoid constructing them in the common case (when we
543 // know they will not be used).
544 if (event_collector != nullptr) {
545 return true;
546 }
547
548 if (profiler::ScopedAnnotation::IsEnabled()) return true;
549
550 return profiler::TraceMe::Active(profiler::GetTFTraceMeLevel(is_expensive));
551 }
552
553 template <class PropagatorStateType>
554 Status ExecutorState<PropagatorStateType>::ProcessSync(
555 const NodeItem& item, OpKernelContext::Params* params, EntryVector* outputs,
556 NodeExecStatsInterface* stats) {
557 Status s;
558 OpKernelContext ctx(params, item.num_outputs);
559 nodestats::SetOpStart(stats);
560
561 OpKernel* op_kernel = item.kernel;
562 Device* device = immutable_state_.params().device;
563 const bool is_expensive = kernel_stats_->IsExpensive(item);
564
565 if (TF_PREDICT_FALSE(MightTrace(event_collector_, is_expensive))) {
566 tracing::ScopedRegion region(tracing::EventCategory::kCompute,
567 op_kernel->name_view());
568 profiler::AnnotatedTraceMe activity(
569 [op_kernel, &ctx] {
570 return op_kernel->TraceString(
571 ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
572 },
573 profiler::GetTFTraceMeLevel(is_expensive));
574 device->Compute(op_kernel, &ctx);
575 } else if (kernel_stats_->HasExpensiveMarker(item)) {
576 KernelTimer timer;
577 device->Compute(op_kernel, &ctx);
578 // For expensive kernels, always update the cost estimate. For inexpensive
579 // kernels, update the cost estimate with ~1/16 probability. This assumes
580 // that the last 4 bits of the CPU cycle count are uniformly distributed.
581 constexpr int kKernelExecutionTrackingInvocationSkipCount = 16;
582 if (is_expensive ||
583 timer.start_cycles % kKernelExecutionTrackingInvocationSkipCount == 0) {
584 kernel_stats_->UpdateCostEstimate(item, timer.ElapsedCycles());
585 }
586 } else {
587 device->Compute(op_kernel, &ctx);
588 }
589 nodestats::SetOpEnd(stats);
590 if (outputs->size() < item.num_outputs) outputs->resize(item.num_outputs);
591 s = ProcessOutputs(item, &ctx, outputs->data(), stats);
592 nodestats::SetMemory(stats, &ctx);
593 return s;
594 }
595
596 template <class PropagatorStateType>
597 void ExecutorState<PropagatorStateType>::ProcessAsync(
598 const NodeItem& item, const OpKernelContext::Params& params,
599 const TaggedNode& tagged_node, Entry* first_input,
600 NodeExecStatsInterface* stats) {
601 AsyncOpKernel* async_kernel = item.kernel->AsAsync();
602 DCHECK(async_kernel != nullptr);
603 AsyncState* state =
604 new AsyncState(params, tagged_node, &item, first_input, stats);
605
606 auto done = [this, state]() {
607 Device* device = immutable_state_.params().device;
608 NodeExecStatsInterface* stats = state->stats; // Shorthand
609 Entry* first_input = state->first_input; // Shorthand
610
611 nodestats::SetOpEnd(stats);
612 EntryVector outputs(state->item->num_outputs);
613 Status s = ProcessOutputs(*state->item, &state->ctx, outputs.data(), stats);
614 nodestats::SetMemory(stats, &state->ctx);
615 if (vlog_) {
616 VLOG(2) << "Async kernel done: " << state->item->node_id << " step "
617 << step_id_ << " " << SummarizeNodeDef(state->item->kernel->def())
618 << (state->tagged_node.get_is_dead() ? " is dead" : "")
619 << " device: " << device->name();
620 }
621
622 // Clears inputs.
623 const int num_inputs = state->item->num_inputs;
624 for (int i = 0; i < num_inputs; ++i) {
625 (first_input + i)->ClearVal();
626 }
627 propagator_.MaybeMarkCompleted(state->tagged_node);
628 TaggedNodeSeq ready;
629 if (s.ok()) {
630 propagator_.PropagateOutputs(state->tagged_node, &outputs, &ready);
631 }
632 outputs.clear();
633 const bool completed = NodeDone(s, &ready, stats, nullptr);
634 delete state;
635 if (completed) ScheduleFinish();
636 };
637 nodestats::SetOpStart(stats);
638 {
639 profiler::AnnotatedTraceMe activity(
640 [async_kernel, state] {
641 return async_kernel->TraceString(
642 state->ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
643 },
644 profiler::GetTFTraceMeLevel(kernel_stats_->IsExpensive(item)));
645 immutable_state_.params().device->ComputeAsync(async_kernel, &state->ctx,
646 std::move(done));
647 }
648 }
649
650 template <class PropagatorStateType>
651 void ExecutorState<PropagatorStateType>::ProcessNoop(
652 NodeExecStatsInterface* stats) {
653 nodestats::SetOpStart(stats);
654 nodestats::SetOpEnd(stats);
655 }
656
657 template <class PropagatorStateType>
658 void ExecutorState<PropagatorStateType>::ProcessConstTensor(
659 const NodeItem& item, EntryVector* outputs, NodeExecStatsInterface* stats) {
660 nodestats::SetOpStart(stats);
661 nodestats::SetOpEnd(stats);
662 Entry& output = (*outputs)[0];
663 output.state = Entry::State::HAS_CONST_TENSOR;
664 output.const_tensor = item.const_tensor;
665 output.alloc_attr = item.output_attrs()[0];
666 }
667
668 template <class PropagatorStateType>
669 void ExecutorState<PropagatorStateType>::Process(TaggedNode tagged_node,
670 int64_t scheduled_nsec) {
671 profiler::TraceMeConsumer activity(
672 // From TraceMeProducer in DirectSession::RunInternal,
673 // GraphMgr::ExecuteAsync, or FunctionLibraryRuntime::Run.
674 [&] {
675 // NOTE: This tracing uses the iteration number from the first tagged
676 // node that executes during this call to `Process()`. In principle,
677 // subsequent nodes could have different values of `iter_num` that
678 // will not be traced.
679 return profiler::TraceMeEncode(
680 "ExecutorState::Process",
681 {{"id", step_id_}, {"iter_num", tagged_node.get_iter_num()}});
682 },
683 profiler::ContextType::kTfExecutor, step_id_,
684 profiler::TraceMeLevel::kInfo);
685 WithContext wc(context_);
686 TaggedNodeSeq ready;
687 TaggedNodeReadyQueue inline_ready;
688
689 // Parameters passed to OpKernel::Compute.
690 TensorValueVec inputs;
691 AllocatorAttributeVec input_alloc_attrs;
692
693 OpKernelContext::Params params;
694 params.step_id = step_id_;
695 // Override device's threadpool if user provides an intra_op_threadpool
696 Device* device = immutable_state_.params().device;
697 if (user_device_) {
698 params.device = user_device_.get();
699 } else {
700 params.device = device;
701 }
702 params.start_time_usecs = start_time_usecs_;
703 params.log_memory = log_memory_;
704 params.rendezvous = rendezvous_;
705 params.collective_executor = collective_executor_;
706 params.session_state = session_state_;
707 params.session_handle = session_handle_;
708 params.session_metadata = session_metadata_;
709 params.tensor_store = tensor_store_;
710 params.cancellation_manager = cancellation_manager_;
711 params.coordination_service_agent = coordination_service_agent_;
712 params.call_frame = call_frame_;
713 params.function_library = immutable_state_.params().function_library;
714 params.resource_manager = device->resource_manager();
715 params.step_container = step_container_;
716 params.slice_reader_cache = slice_reader_cache_;
717 params.inputs = &inputs;
718 params.input_alloc_attrs = &input_alloc_attrs;
719 params.runner = &runner_;
720 params.run_all_kernels_inline = run_all_kernels_inline_;
721 params.stats_collector = stats_collector_;
722 params.inc_num_deferred_ops_function = [this]() {
723 mutex_lock lock(num_deferred_ops_mu_);
724 num_deferred_ops_++;
725 };
726 params.dec_num_deferred_ops_function = [this]() {
727 bool finish_when_deferred_ops_done = false;
728 {
729 mutex_lock lock(num_deferred_ops_mu_);
730 num_deferred_ops_--;
731 if (num_deferred_ops_ == 0) {
732 finish_when_deferred_ops_done = finish_when_deferred_ops_done_;
733 }
734 }
735 // Invoke Finish if the graph processing has completed. Finish is always
736 // called exactly once per ExecutorState, either here if there are any
737 // deferred ops, or in ScheduleFinish if there aren't any deferred ops.
738 if (finish_when_deferred_ops_done) Finish();
739 };
740
741 // Set the device_context for this device, if it exists.
742 params.op_device_context = device_context_;
743
744 Status s;
745 NodeExecStatsInterface* stats = nullptr;
746
747 EntryVector outputs(1);
748
749 bool completed = false;
750 inline_ready.push_back(tagged_node);
751 while (!inline_ready.empty()) {
752 tagged_node = inline_ready.front();
753 inline_ready.pop_front();
754 const NodeItem& item = tagged_node.get_node_item();
755 const int id = item.node_id;
756
757 propagator_.MaybeMarkStarted(tagged_node);
758
759 params.track_allocations = false;
760 stats = nullptr;
761 if (stats_collector_ && !tagged_node.get_is_dead()) {
762 stats = stats_collector_->CreateNodeExecStats(&item.kernel->def());
763 // Track allocations if and only if we are collecting statistics, and
764 // the `stats` object is expecting allocations to be tracked.
765 params.track_allocations = stats ? stats->TrackAllocations() : false;
766 nodestats::SetScheduled(stats, scheduled_nsec);
767 nodestats::SetAllStart(stats);
768 }
769
770 if (vlog_) {
771 VLOG(1) << "Process node: " << id << " step " << params.step_id << " "
772 << SummarizeNodeDef(item.kernel->def())
773 << (tagged_node.get_is_dead() ? " is dead" : "")
774 << " device: " << device->name();
775 }
776
777 Entry* first_input = propagator_.GetInputTensors(tagged_node);
778
779 // Only execute this node if it is not dead or it is a send/recv
780 // transfer node. For transfer nodes, we need to propagate the "dead"
781 // bit even when the node is dead.
782 bool launched_asynchronously = false;
783 if (tagged_node.get_is_dead() && !item.is_transfer_node) {
784 if (outputs.size() < item.num_outputs) outputs.resize(item.num_outputs);
785 } else if (TF_PREDICT_FALSE(item.is_noop)) {
786 ProcessNoop(stats);
787 } else if (item.const_tensor != nullptr && !params.track_allocations) {
788 ProcessConstTensor(item, &outputs, stats);
789 } else {
790 // Prepares inputs.
791 bool is_input_dead = false;
792 s = PrepareInputs(item, first_input, &inputs, &input_alloc_attrs,
793 &is_input_dead);
794 if (!s.ok()) {
795 // Clear inputs.
796 const int num_inputs = item.num_inputs;
797 for (int i = 0; i < num_inputs; ++i) {
798 (first_input + i)->ClearVal();
799 }
800 propagator_.MaybeMarkCompleted(tagged_node);
801 // Continue to process the nodes in 'inline_ready'.
802 completed = NodeDone(s, &ready, stats, &inline_ready);
803 continue;
804 }
805
806 // Set up compute params.
807 params.op_kernel = item.kernel;
808 params.frame_iter = propagator_.GetFrameAndIter(tagged_node);
809 params.is_input_dead = is_input_dead;
810 params.output_attr_array = item.output_attrs();
811 params.forward_from_array = item.forward_from();
812 params.outputs_required_array = item.outputs_required.get();
813
814 if (item.kernel_is_async) {
815 ProcessAsync(item, params, tagged_node, first_input, stats);
816 launched_asynchronously = true;
817 } else {
818 s = ProcessSync(item, &params, &outputs, stats);
819 }
820 }
821
822 if (!launched_asynchronously) {
823 if (vlog_) {
824 VLOG(2) << "Synchronous kernel done: " << id << " step "
825 << params.step_id << " " << SummarizeNodeDef(item.kernel->def())
826 << (tagged_node.get_is_dead() ? " is dead: " : "")
827 << " device: " << device->name();
828 }
829
830 // Clears inputs.
831 const int num_inputs = item.num_inputs;
832 for (int i = 0; i < num_inputs; ++i) {
833 (first_input + i)->ClearVal();
834 }
835 propagator_.MaybeMarkCompleted(tagged_node);
836 // Propagates outputs.
837 if (s.ok()) {
838 propagator_.PropagateOutputs(tagged_node, &outputs, &ready);
839 }
840
841 // Clear outputs without deallocating the `outputs` vector.
842 const int num_outputs = item.num_outputs;
843 for (int i = 0; i < num_outputs; ++i) {
844 outputs[i].ClearVal();
845 }
846
847 if (stats) {
848 scheduled_nsec = nodestats::NowInNsec();
849 }
850 // Postprocess.
851 completed = NodeDone(s, &ready, stats, &inline_ready);
852 }
853 } // while !inline_ready.empty()
854
855 // This thread of computation is done if completed = true.
856 if (completed) ScheduleFinish();
857 }
858
859 template <class PropagatorStateType>
860 Status ExecutorState<PropagatorStateType>::PrepareInputs(
861 const NodeItem& item, Entry* first_input, TensorValueVec* inputs,
862 AllocatorAttributeVec* input_alloc_attrs, bool* is_input_dead) {
863 inputs->resize(item.num_inputs);
864 input_alloc_attrs->resize(item.num_inputs);
865
866 *is_input_dead = false;
867
868 for (int i = 0; i < item.num_inputs; ++i) {
869 const bool expect_ref = TF_PREDICT_FALSE(item.is_any_input_ref_typed) &&
870 IsRefType(item.input_type(i));
871 Entry* entry = first_input + i;
872 (*input_alloc_attrs)[i] = entry->alloc_attr;
873
874 // i-th input.
875 TensorValue* inp = &(*inputs)[i];
876
877 switch (entry->state) {
878 case Entry::State::NO_VALUE: {
879 // Only merge and transfer nodes can have no-value inputs.
880 inp->mutex_if_ref = nullptr;
881 if (item.is_merge) {
882 inp->tensor = nullptr;
883 } else {
884 DCHECK(item.is_transfer_node)
885 << item.kernel->name() << " - input " << i;
886 entry->state = Entry::State::HAS_CONST_TENSOR;
887 entry->const_tensor = kEmptyTensor;
888 // NOTE(mrry): This `const_cast` is necessary because `TensorValue`
889 // stores a non-const `Tensor*`, and relies on the `OpKernelContext`
890 // accessors making dynamic checks that prevent using an immutable
891 // tensor as a mutable tensor.
892 inp->tensor = const_cast<Tensor*>(kEmptyTensor);
893 *is_input_dead = true;
894 }
895 break;
896 }
897
898 case Entry::State::HAS_VALUE: {
899 if (TF_PREDICT_FALSE(expect_ref)) {
900 return AttachDef(
901 errors::InvalidArgument(i, "-th input expects a ref type"),
902 item.kernel->def());
903 }
904 inp->mutex_if_ref = nullptr;
905 inp->tensor = entry->val.get();
906 break;
907 }
908
909 case Entry::State::HAS_CONST_TENSOR: {
910 if (TF_PREDICT_FALSE(expect_ref)) {
911 return AttachDef(
912 errors::InvalidArgument(i, "-th input expects a ref type"),
913 item.kernel->def());
914 }
915 // NOTE(mrry): This `const_cast` is necessary because `TensorValue`
916 // stores a non-const `Tensor*`, and relies on the `OpKernelContext`
917 // accessors making dynamic checks that prevent using an immutable
918 // tensor as a mutable tensor.
919 inp->mutex_if_ref = nullptr;
920 inp->tensor = const_cast<Tensor*>(entry->const_tensor);
921 break;
922 }
923
924 case Entry::State::HAS_REF_TENSOR: {
925 {
926 tf_shared_lock ml(*entry->ref_tensor.mu);
927 if (TF_PREDICT_FALSE(!entry->ref_tensor.tensor->IsInitialized() &&
928 !item.is_initialization_op)) {
929 return AttachDef(errors::FailedPrecondition(
930 "Attempting to use uninitialized value ",
931 item.kernel->requested_input(i)),
932 item.kernel->def());
933 }
934 }
935
936 if (expect_ref) {
937 inp->mutex_if_ref = entry->ref_tensor.mu;
938 inp->tensor = entry->ref_tensor.tensor;
939 } else {
940 // Automatically deref the tensor ref when the op expects a
941 // tensor but is given a ref to a tensor. Need to deref it
942 // under the mutex.
943 {
944 mutex* ref_mu = entry->ref_tensor.mu;
945 Tensor* ref_tensor = entry->ref_tensor.tensor;
946 tf_shared_lock l(*ref_mu);
947 entry->val.Init(*ref_tensor);
948 }
949 entry->state = Entry::State::HAS_VALUE;
950
951 inp->mutex_if_ref = nullptr;
952 inp->tensor = entry->val.get();
953 // The dtype of entry->ref_tensor.tensor could have been changed by
954 // another operation that ran after the operation that "produced" it
955 // executed, so re-validate that the type of the dereferenced tensor
956 // matches the expected input type.
957 if (TF_PREDICT_FALSE(item.input_type(i) != inp->tensor->dtype())) {
958 return AttachDef(
959 errors::InvalidArgument(
960 i, "-th input expects type ",
961 DataTypeString(item.input_type(i)),
962 " but automatically dereferenced input tensor has type ",
963 DataTypeString(inp->tensor->dtype())),
964 item.kernel->def());
965 }
966 }
967 break;
968 }
969 }
970 }
971 return Status::OK();
972 }
973
974 template <class PropagatorStateType>
975 Status ExecutorState<PropagatorStateType>::ProcessOutputs(
976 const NodeItem& item, OpKernelContext* ctx, Entry* outputs,
977 NodeExecStatsInterface* stats) {
978 Status s = ctx->status();
979 if (!s.ok()) {
980 s = AttachDef(s, item.kernel->def());
981 // TODO(misard) Replace with a finer-grain enabling flag once we
982 // add better optional debugging support.
983 if (vlog_ && VLOG_IS_ON(1)) {
984 LOG(WARNING) << this << " Compute status: " << s;
985 }
986 if (s.code() == error::RESOURCE_EXHAUSTED) {
987 if (stats_collector_) {
988 string err = stats_collector_->ReportAllocsOnResourceExhausted(
989 s.error_message());
990 s = Status(s.code(), strings::StrCat(s.error_message(), err));
991 } else {
992 s = Status(
993 s.code(),
994 strings::StrCat(
995 s.error_message(),
996 "\nHint: If you want to see a list of allocated tensors when "
997 "OOM happens, add report_tensor_allocations_upon_oom "
998 "to RunOptions for current allocation info. This isn't "
999 "available when running in Eager mode.\n"));
1000 }
1001 } else if (s.code() == error::UNAVAILABLE &&
1002 !item.is_distributed_communication) {
1003 s = errors::ReplaceErrorFromNonCommunicationOps(s, item.kernel->name());
1004 }
1005 return s;
1006 }
1007
1008 for (int i = 0; i < item.num_outputs; ++i) {
1009 const TensorValue val = ctx->release_output(i);
1010 Entry* out = &outputs[i];
1011 DCHECK(out->state == Entry::State::NO_VALUE);
1012
1013 if (val.tensor == nullptr) {
1014 // Unless it's a Switch or a Recv, or the executor has marked the output
1015 // as not required, the node must produce a tensor value at the i-th output.
1016 if (!(item.is_recv_or_switch ||
1017 (item.outputs_required && !item.outputs_required[i]))) {
1018 s.Update(errors::Internal("Missing ", i, "-th output from ",
1019 FormatNodeDefForError(item.kernel->def())));
1020 }
1021 } else {
1022 // Set the allocator attributes of the output entry.
1023 out->alloc_attr = ctx->output_alloc_attr(i);
1024
1025 // Sanity check of output tensor types. We need to inspect this safely as
1026 // we are in the tensor buffer.
1027 DataType dtype = val.dtype_safe();
1028 if (dtype == item.output_type(i)) {
1029 if (stats && val.tensor->IsInitialized()) {
1030 nodestats::SetOutput(stats, i, val.tensor);
1031 }
1032 if (val.is_ref()) {
1033 out->state = Entry::State::HAS_REF_TENSOR;
1034 out->ref_tensor.tensor = val.tensor;
1035 out->ref_tensor.mu = val.mutex_if_ref;
1036 if (log_memory_) {
1037 Tensor to_log;
1038 {
1039 // Dereference the tensor under the lock.
1040 tf_shared_lock l(*out->ref_tensor.mu);
1041 to_log = *out->ref_tensor.tensor;
1042 }
1043 LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
1044 ctx->step_id(), i, to_log);
1045 }
1046 } else {
1047 // NOTE that std::move is used here, so val.tensor goes to
1048 // uninitialized state (val.tensor->IsInitialized() returns false).
1049 out->state = Entry::State::HAS_VALUE;
1050 out->val.Init(std::move(*val.tensor));
1051 if (log_memory_) {
1052 LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
1053 ctx->step_id(), i, *out->val);
1054 }
1055 }
1056 } else {
1057 s.Update(
1058 errors::Internal("Output ", i, " of type ", DataTypeString(dtype),
1059 " does not match declared output type ",
1060 DataTypeString(item.output_type(i)), " for node ",
1061 FormatNodeDefForError(item.kernel->def())));
1062 }
1063 }
1064 if (!val.is_ref()) {
1065 // If OpKernelContext returns outputs via pass-by-value, we
1066 // don't need this trouble.
1067 delete val.tensor;
1068 }
1069 }
1070 return s;
1071 }
1072
1073 template <class PropagatorStateType>
1074 bool ExecutorState<PropagatorStateType>::NodeDone(
1075 const Status& s, TaggedNodeSeq* ready, NodeExecStatsInterface* stats,
1076 TaggedNodeReadyQueue* inline_ready) {
1077 if (stats) {
1078 nodestats::SetAllEnd(stats);
1079 DCHECK_NE(stats_collector_, nullptr);
1080 stats->Done(immutable_state_.params().device->name());
1081 }
1082
1083 if (TF_PREDICT_TRUE(s.ok())) {
1084 const size_t ready_size = ready->size();
1085 if (ready_size == 0) {
1086 return num_outstanding_ops_.fetch_sub(1) == 1;
1087 } else {
1088 // NOTE: Avoid touching the atomic counter if only one node becomes ready.
1089 if (ready_size > 1) {
1090 num_outstanding_ops_.fetch_add(ready_size - 1,
1091 std::memory_order_relaxed);
1092 }
1093
1094 // Schedule the ready nodes in 'ready'.
1095 ScheduleReady(ready, inline_ready);
1096
1097 return false;
1098 }
1099 } else {
1100 bool abort_run = false;
1101
1102 // Some error happened. This thread of computation is done.
1103 {
1104 mutex_lock l(mu_);
1105 if (status_.ok()) {
1106 // If this is the first node to fail in this run, we are responsible for
1107 // aborting all other execution in the step.
1108 abort_run = true;
1109
1110 // If execution has been cancelled, mark cancelled or aborted errors as
1111 // being derived. Note that the original node that fails might also
1112 // trigger cancellation, and here we make sure the original error is
1113 // exposed to users and not buried as a derived error.
1114 if (cancellation_manager_ && cancellation_manager_->IsCancelled() &&
1115 (errors::IsCancelled(s) || errors::IsAborted(s))) {
1116 status_ = StatusGroup::MakeDerived(s);
1117 } else {
1118 status_ = s;
1119 }
1120 }
1121 }
1122
1123 if (abort_run) {
1124 TRACEPRINTF("StartAbort: %s", s.ToString().c_str());
1125 if (cancellation_manager_) {
1126 // Only log when the abort happens during the actual run time.
1127 // Use VLOG instead of LOG(warning) because error status is expected
1128 // when the executor is run under the grappler optimization phase or
1129 // when iterating through a tf.data input pipeline.
1130 VLOG(1) << "[" << immutable_state_.params().device->name()
1131 << "] Executor start aborting: " << s;
1132 }
1133
1134 if (rendezvous_) {
1135 rendezvous_->StartAbort(s);
1136 }
1137 if (cancellation_manager_) {
1138 cancellation_manager_->StartCancel();
1139 } else if (collective_executor_) {
1140 // If there's cancellation_manager_, collective ops abort
1141 // collective_executor_ upon cancellation; otherwise we need to abort
1142 // here.
1143 collective_executor_->StartAbort(s);
1144 }
1145 }
1146
1147 return num_outstanding_ops_.fetch_sub(1) == 1;
1148 }
1149 }
1150
1151 template <class PropagatorStateType>
1152 void ExecutorState<PropagatorStateType>::ScheduleReady(
1153 TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready) {
1154 DCHECK(!ready->empty());
1155
1156 int64_t scheduled_nsec = 0;
1157 if (stats_collector_) {
1158 scheduled_nsec = nodestats::NowInNsec();
1159 }
1160
1161 if (run_all_kernels_inline_) {
1162 if (inline_ready == nullptr) {
1163 // Schedule all ready kernels from a single closure. This ensures that,
1164 // regardless of the `runner_` implementation, all kernels will run
1165 // sequentially on the same thread, and thread wakeup overhead and
1166 // executor mutex contention will be minimized.
1167 RunTask([this, ready = std::move(*ready), scheduled_nsec]() {
1168 for (auto& tagged_node : ready) {
1169 Process(tagged_node, scheduled_nsec);
1170 }
1171 });
1172 } else {
1173 for (auto& tagged_node : *ready) {
1174 inline_ready->push_back(tagged_node);
1175 }
1176 }
1177 } else {
1178 const TaggedNode* curr_expensive_node = nullptr;
1179 if (inline_ready == nullptr) {
1180 // Schedule to run all the ready ops in thread pool.
1181 for (auto& tagged_node : *ready) {
1182 RunTask([=]() { Process(tagged_node, scheduled_nsec); });
1183 }
1184 } else {
1185 for (auto& tagged_node : *ready) {
1186 const NodeItem& item = *tagged_node.node_item;
1187 if (tagged_node.get_is_dead() || !kernel_stats_->IsExpensive(item)) {
1188 // Inline this inexpensive node.
1189 inline_ready->push_back(tagged_node);
1190 } else {
1191 if (curr_expensive_node) {
1192 // Dispatch to another thread since there is plenty of work to
1193 // do for this thread.
1194 RunTask(std::bind(&ExecutorState::Process, this,
1195 *curr_expensive_node, scheduled_nsec));
1196 }
1197 curr_expensive_node = &tagged_node;
1198 }
1199 }
1200 }
1201 if (curr_expensive_node) {
1202 if (inline_ready->empty()) {
1203 inline_ready->push_back(*curr_expensive_node);
1204 } else {
1205 // There are inline nodes to run already. We dispatch this expensive
1206 // node to another thread.
1207 RunTask(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
1208 scheduled_nsec));
1209 }
1210 }
1211 }
1212 ready->clear();
1213 }
1214
1215 template <class PropagatorStateType>
1216 void ExecutorState<PropagatorStateType>::ScheduleFinish() {
1217 // Checks the condition to decide whether it needs to invoke Finish(). If
1218 // there are in-flight deferred ops, wait until `num_deferred_ops_` reaches 0
1219 // to invoke Finish(). Otherwise, invoke Finish() directly.
1220 // Note that it is critical that the ScheduleFinish / Finish codepath does not
1221 // block, otherwise we might deadlock. See b/124523000 for details.
1222 {
1223 mutex_lock lock(num_deferred_ops_mu_);
1224 if (num_deferred_ops_ > 0) {
1225 finish_when_deferred_ops_done_ = true;
1226 return;
1227 }
1228 }
1229 // Finish is always called exactly once per ExecutorState, either here if
1230 // there aren't any deferred ops, or in the dec_num_deferred_ops_function if
1231 // there are deferred ops.
1232 Finish();
1233 }
1234
1235 template <class PropagatorStateType>
1236 void ExecutorState<PropagatorStateType>::Finish() {
1237 mu_.lock();
1238 auto status = status_;
1239 auto done_cb = std::move(done_cb_);
1240 auto runner = std::move(runner_);
1241 mu_.unlock();
1242 int64_t step_id = step_id_;
1243 CHECK(done_cb != nullptr);
1244 Device* device = immutable_state_.params().device;
1245
1246 if (vlog_ && !status.ok() && VLOG_IS_ON(1)) {
1247 // Logs verbose information about the current state of active and pending
1248 // nodes in the propagator.
1249 propagator_.DumpState();
1250 }
1251
1252 // There are several potential race conditions below. To name a few:
1253 // 1. Even if the device's status is OK at the precise moment when
1254 // num_deferred_ops_ reaches 0, it could go bad before device->RefreshStatus()
1255 // is called below, caused by work enqueued onto the same device by other
1256 // concurrent ExecutorState objects.
1257 // 2. Some implementations of Device::RefreshStatus, such as
1258 // XlaDevice::RefreshStatus, may be inherently racy because it releases the
1259 // device mutex after a stream pointer is acquired and before the stream is
1260 // queried for status.
1261 // 3. It's the same for some implementations of Device::Sync, such as
1262 // XlaDevice::Sync.
1263 //
1264 // However, these race conditions are acceptable because a stream (and
1265 // therefore an XlaDevice) can only go from OK to not-OK, never the opposite,
1266 // which means we will at worst report errors when there isn't any, never the
1267 // opposite.
1268
1269 // An early exit for devices that don't allow sync on completion. Ops that run
1270 // on these devices should have used num_deferred_ops correctly to ensure the
1271 // device has finished all relevant work at this point.
1272 if (!device->AllowsSyncOnCompletion()) {
1273 status.Update(device->RefreshStatus());
1274 if (!status.ok()) {
1275 // In device async execution mode, it's possible for device execution to
1276 // lag behind ExecutorState scheduling so much that this is the first
1277 // place a device execution error surfaces.
1278 // If so, all ExecutorState::NodeDone calls have already happened with OK
1279 // status. This is the last defense where StartCancel must be called to
1280 // abort all computation still running on any device.
1281 // TODO(b/124523000): Always call Finish in a separate thread, so even if
1282 // StartCancel blocks the current thread's execution, we won't encounter
1283 // deadlocks caused by inter-op thread exhaustion.
1284 if (rendezvous_) {
1285 rendezvous_->StartAbort(status);
1286 }
1287 if (cancellation_manager_) {
1288 cancellation_manager_->StartCancel();
1289 } else if (collective_executor_) {
1290 // If there's cancellation_manager_, collective ops abort
1291 // collective_executor_ upon cancellation; otherwise we need to abort
1292 // here.
1293 collective_executor_->StartAbort(status);
1294 }
1295 }
1296 delete this;
1297 runner([step_id, status, done_cb = std::move(done_cb)]() {
1298 profiler::TraceMeConsumer activity(
1299 // From TraceMeProducer in KernelAndDeviceFunc::RunAsync,
1300 // DirectSession::RunInternal or GraphMgr::ExecuteAsync.
1301 [&] {
1302 return profiler::TraceMeEncode("ExecutorDoneCallback",
1303 {{"id", step_id}});
1304 },
1305 profiler::ContextType::kTfExecutor, step_id,
1306 profiler::TraceMeLevel::kInfo);
1307 done_cb(status);
1308 });
1309 return;
1310 }
1311
1312 if (sync_on_finish_ && status.ok()) {
1313 // Block until the device has finished all queued operations. For
1314 // devices like GPUs that continue to execute Ops after their Compute
1315 // methods have completed, this ensures that control is not returned to
1316 // the user until the step (and its side-effects) has actually completed.
1317 device->Sync([this, step_id, runner = std::move(runner),
1318 done_cb = std::move(done_cb)](const Status& status) mutable {
1319 delete this;
1320 runner([step_id, status, done_cb = std::move(done_cb)]() {
1321 profiler::TraceMeConsumer activity(
1322 // From TraceMeProducer in KernelAndDeviceFunc::RunAsync,
1323 // DirectSession::RunInternal or GraphMgr::ExecuteAsync.
1324 [&] {
1325 return profiler::TraceMeEncode("ExecutorDoneCallback",
1326 {{"id", step_id}});
1327 },
1328 profiler::ContextType::kTfExecutor, step_id,
1329 profiler::TraceMeLevel::kInfo);
1330 done_cb(status);
1331 });
1332 });
1333 } else {
1334 delete this;
1335 runner([step_id, status, done_cb = std::move(done_cb)]() {
1336 profiler::TraceMeConsumer activity(
1337 // From TraceMeProducer in KernelAndDeviceFunc::RunAsync,
1338 // DirectSession::RunInternal or GraphMgr::ExecuteAsync.
1339 [&] {
1340 return profiler::TraceMeEncode("ExecutorDoneCallback",
1341 {{"id", step_id}});
1342 },
1343 profiler::ContextType::kTfExecutor, step_id,
1344 profiler::TraceMeLevel::kInfo);
1345 done_cb(status);
1346 });
1347 }
1348 }
1349
1350 void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
1351 if (immutable_state_.requires_control_flow_support()) {
1352 (new ExecutorState<PropagatorState>(args, immutable_state_, &kernel_stats_))
1353 ->RunAsync(std::move(done));
1354 } else {
1355 (new ExecutorState<SimplePropagatorState>(args, immutable_state_,
1356 &kernel_stats_))
1357 ->RunAsync(std::move(done));
1358 }
1359 }
1360
1361 } // namespace
1362
1363 Status NewLocalExecutor(const LocalExecutorParams& params, const Graph& graph,
1364 Executor** executor) {
1365 ExecutorImpl* impl = new ExecutorImpl(params);
1366 const Status s = impl->Initialize(graph);
1367 if (s.ok()) {
1368 *executor = impl;
1369 } else {
1370 delete impl;
1371 }
1372 return s;
1373 }
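// Illustrative call site (a sketch only; `params` initialization is elided,
// the `device`, `graph`, and `args` values come from the caller, and the field
// names follow LocalExecutorParams as declared in executor.h):
//
//   LocalExecutorParams params;
//   params.device = device;
//   params.create_kernel = /* ... */;
//   params.delete_kernel = /* ... */;
//   Executor* executor = nullptr;
//   TF_RETURN_IF_ERROR(NewLocalExecutor(params, graph, &executor));
//   executor->RunAsync(args, [](const Status& s) { /* step finished */ });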
1374
1375 Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
1376 const std::shared_ptr<const NodeProperties>& props,
1377 int graph_def_version, OpKernel** kernel) {
1378 const auto device_type = DeviceType(device->attributes().device_type());
1379 auto allocator = device->GetAllocator(AllocatorAttributes());
1380 return CreateOpKernel(device_type, device, allocator, flib,
1381 device->resource_manager(), props, graph_def_version,
1382 kernel);
1383 }
1384
1385 void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; }
1386
1387 namespace {
1388
1389 class DefaultExecutorRegistrar {
1390 public:
1391 DefaultExecutorRegistrar() {
1392 Factory* factory = new Factory;
1393 ExecutorFactory::Register("", factory);
1394 ExecutorFactory::Register("DEFAULT", factory);
1395 }
1396
1397 private:
1398 class Factory : public ExecutorFactory {
1399 Status NewExecutor(const LocalExecutorParams& params, const Graph& graph,
1400 std::unique_ptr<Executor>* out_executor) override {
1401 Executor* ret = nullptr;
1402 TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(graph), &ret));
1403 out_executor->reset(ret);
1404 return Status::OK();
1405 }
1406 };
1407 };
1408 static DefaultExecutorRegistrar registrar;
1409
1410 } // namespace
1411
1412 } // namespace tensorflow
1413