1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/executor.h"
17 
18 #include <atomic>
19 #include <memory>
20 #include <vector>
21 
22 #include "absl/memory/memory.h"
23 #include "tensorflow/core/common_runtime/costmodel_manager.h"
24 #include "tensorflow/core/common_runtime/entry.h"
25 #include "tensorflow/core/common_runtime/executor_factory.h"
26 #include "tensorflow/core/common_runtime/graph_view.h"
27 #include "tensorflow/core/common_runtime/immutable_executor_state.h"
28 #include "tensorflow/core/common_runtime/pending_counts.h"
29 #include "tensorflow/core/common_runtime/propagator_state.h"
30 #include "tensorflow/core/common_runtime/renamed_device.h"
31 #include "tensorflow/core/common_runtime/simple_propagator_state.h"
32 #include "tensorflow/core/common_runtime/step_stats_collector.h"
33 #include "tensorflow/core/framework/allocator.h"
34 #include "tensorflow/core/framework/cancellation.h"
35 #include "tensorflow/core/framework/collective.h"
36 #include "tensorflow/core/framework/control_flow.h"
37 #include "tensorflow/core/framework/device_attributes.pb.h"
38 #include "tensorflow/core/framework/log_memory.h"
39 #include "tensorflow/core/framework/metrics.h"
40 #include "tensorflow/core/framework/node_def_util.h"
41 #include "tensorflow/core/framework/op_kernel.h"
42 #include "tensorflow/core/framework/op_segment.h"
43 #include "tensorflow/core/framework/tensor.h"
44 #include "tensorflow/core/framework/tensor_reference.h"
45 #include "tensorflow/core/framework/types.h"
46 #include "tensorflow/core/framework/types.pb.h"
47 #include "tensorflow/core/graph/edgeset.h"
48 #include "tensorflow/core/graph/graph.h"
49 #include "tensorflow/core/graph/graph_node_util.h"
50 #include "tensorflow/core/lib/core/errors.h"
51 #include "tensorflow/core/lib/core/notification.h"
52 #include "tensorflow/core/lib/core/status.h"
53 #include "tensorflow/core/lib/core/threadpool.h"
54 #include "tensorflow/core/lib/gtl/flatmap.h"
55 #include "tensorflow/core/lib/gtl/inlined_vector.h"
56 #include "tensorflow/core/lib/gtl/manual_constructor.h"
57 #include "tensorflow/core/lib/hash/hash.h"
58 #include "tensorflow/core/platform/context.h"
59 #include "tensorflow/core/platform/env.h"
60 #include "tensorflow/core/platform/errors.h"
61 #include "tensorflow/core/platform/logging.h"
62 #include "tensorflow/core/platform/macros.h"
63 #include "tensorflow/core/platform/mutex.h"
64 #include "tensorflow/core/platform/profile_utils/cpu_utils.h"
65 #include "tensorflow/core/platform/thread_annotations.h"
66 #include "tensorflow/core/platform/tracing.h"
67 #include "tensorflow/core/platform/types.h"
68 #include "tensorflow/core/profiler/lib/annotated_traceme.h"
69 #include "tensorflow/core/profiler/lib/connected_traceme.h"
70 #include "tensorflow/core/profiler/lib/scoped_annotation.h"
71 #include "tensorflow/core/profiler/lib/traceme_encode.h"
72 #include "tensorflow/core/protobuf/error_codes.pb.h"
73 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
74 
75 namespace tensorflow {
76 
77 namespace {
78 
79 // 1-D, 0 element tensor.
80 static const Tensor* const kEmptyTensor = new Tensor;
81 
82 // Helper routines for collecting step stats.
83 namespace nodestats {
84 inline int64 NowInNsec() { return EnvTime::NowNanos(); }
85 
86 void SetScheduled(NodeExecStatsInterface* stats, int64_t micros) {
87   if (!stats) return;
88   stats->SetScheduled(micros * EnvTime::kMicrosToNanos);
89 }
90 
91 void SetAllStart(NodeExecStatsInterface* stats) {
92   if (!stats) return;
93   stats->RecordExecutorStarted();
94 }
95 
96 void SetOpStart(NodeExecStatsInterface* stats) {
97   if (!stats) return;
98   stats->RecordComputeStarted();
99 }
100 
101 void SetOpEnd(NodeExecStatsInterface* stats) {
102   if (!stats) return;
103   stats->RecordComputeEnded();
104 }
105 
106 void SetAllEnd(NodeExecStatsInterface* stats) {
107   if (!stats) return;
108   stats->RecordExecutorEnded();
109 }
110 
111 void SetOutput(NodeExecStatsInterface* stats, int slot, const Tensor* v) {
112   if (!stats) return;
113   stats->SetOutput(slot, v);
114 }
115 
116 void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) {
117   if (!stats) return;
118   stats->SetMemory(ctx);
119 }
120 
121 }  // namespace nodestats
122 
123 // Time the execution of kernels (in CPU cycles).  Used to dynamically identify
124 // inexpensive kernels which can be dispatched inline.
125 struct KernelTimer {
126   uint64 start_cycles = profile_utils::CpuUtils::GetCurrentClockCycle();
127 
128   uint64 ElapsedCycles() {
129     return profile_utils::CpuUtils::GetCurrentClockCycle() - start_cycles;
130   }
131 };
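// A minimal usage sketch of KernelTimer, mirroring what ProcessSync below does
// with it: time a single Compute call and feed the elapsed cycle count into
// KernelStats::UpdateCostEstimate.
//
//   KernelTimer timer;
//   device->Compute(op_kernel, &ctx);
//   kernel_stats_->UpdateCostEstimate(item, timer.ElapsedCycles());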
132 
133 // TODO(b/152925936): Re-evaluate these constants with current usage patterns.
134 typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
135 typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
136 
137 class ExecutorImpl : public Executor {
138  public:
139   explicit ExecutorImpl(const LocalExecutorParams& p) : immutable_state_(p) {}
140 
141   Status Initialize(const Graph& graph) {
142     TF_RETURN_IF_ERROR(immutable_state_.Initialize(graph));
143     kernel_stats_.Initialize(immutable_state_.graph_view());
144     return Status::OK();
145   }
146 
147   void RunAsync(const Args& args, DoneCallback done) override;
148 
149  private:
150   template <class PropagatorStateType>
151   friend class ExecutorState;
152 
153   // Stores execution time information about the kernels in an executor's graph.
154   class KernelStats {
155    public:
156     KernelStats() = default;
157 
158     void Initialize(const GraphView& gview) {
159       is_expensive_.resize(gview.num_nodes());
160       cost_estimates_ =
161           absl::make_unique<std::atomic_uint_fast64_t[]>(gview.num_nodes());
162       for (int32_t i = 0; i < gview.num_nodes(); ++i) {
163         if (gview.node(i)) {
164           is_expensive_[i] =
165               gview.node(i)->kernel && gview.node(i)->kernel->IsExpensive();
166           cost_estimates_[i] = kInitialCostEstimateCycles;
167         }
168       }
169     }
170 
171     // Returns true iff the given node is considered "expensive". The
172     // executor uses this flag to optimize graph execution, for example
173     // by "inlining" inexpensive kernels.
174     bool IsExpensive(const NodeItem& node) const {
175       return is_expensive_[node.node_id] &&
176              (cost_estimates_[node.node_id].load(std::memory_order_relaxed) >
177               kOpIsExpensiveThresholdCycles);
178     }
179 
180     // Returns the value of kernel->IsExpensive().
181     bool HasExpensiveMarker(const NodeItem& node) const {
182       return is_expensive_[node.node_id];
183     }
184 
185     // Updates the dynamic cost estimate, which is used to determine whether the
186     // given node is expensive. The new cost estimate is a weighted average of
187     // the old cost estimate and the latest cost. We only update cost estimates
188     // for kernels for which IsExpensive() returns true.
189     void UpdateCostEstimate(const NodeItem& node, uint64 elapsed_cycles) {
190       // N.B. Updates to `cost_estimate` are atomic but unlocked.  Simultaneous
191       // updates may result in one or more updates being ignored.  This does not
192       // affect correctness but may slow down the update frequency.
193       std::atomic_uint_fast64_t& cost_estimate = cost_estimates_[node.node_id];
194       auto prev_estimate = cost_estimate.load(std::memory_order_relaxed);
195 
196       uint64 new_estimate =
197           ((kCostDecay - 1) * prev_estimate + elapsed_cycles) / kCostDecay;
198 
199       cost_estimate.store(new_estimate, std::memory_order_relaxed);
200     }
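    // For example, with kCostDecay == 10 the update above is an exponential
    // moving average, roughly new = 0.9 * prev + 0.1 * elapsed: a kernel whose
    // estimate is still the initial 100,000,000 cycles and whose latest run
    // took 5,000 cycles gets (9 * 100,000,000 + 5,000) / 10 = 90,000,500
    // cycles. Repeated cheap runs therefore need on the order of a hundred
    // updates before the estimate drops below kOpIsExpensiveThresholdCycles
    // and IsExpensive() starts returning false.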
201 
202    private:
203     // Initial time (in CPU cycles) we expect an operation to take.  Used to
204     // determine whether an operation should be placed in a threadpool.
205     // Operations start out "expensive".
206     static constexpr uint64 kInitialCostEstimateCycles = 100 * 1000 * 1000;
207     static constexpr uint64 kOpIsExpensiveThresholdCycles = 8000;
208     static constexpr uint64 kCostDecay = 10;
209 
210     std::vector<bool> is_expensive_;
211     // std::unique_ptr<std::atomic<bool>[]> is_expensive_;
212     std::unique_ptr<std::atomic_uint_fast64_t[]> cost_estimates_;
213   };
214 
215   ImmutableExecutorState immutable_state_;
216   KernelStats kernel_stats_;
217 
218   TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl);
219 };
220 
221 // The state associated with one invocation of ExecutorImpl::Run.
222 //
223 // ExecutorState dispatches nodes when they become ready, and delegates to an
224 // instance of `PropagatorStateType` to keep track of how many predecessors of a
225 // node are still pending.
226 //
227 // The template argument `class PropagatorStateType` must define the following
228 // public members:
229 // * A type `TaggedNode`, representing a node to be processed, with public
230 //   members:
231 //   * `const NodeItem& get_node_item() const`
232 //   * `bool get_is_dead() const`
233 // * A type `TaggedNodeReadyQueue`, representing a queue of nodes to be
234 //   processed, with public members (having the same meanings as in an
235 //   `std::vector<TaggedNode>`):
236 //   * `void push_back(const TaggedNode& node)`
237 //   * `TaggedNode front() const`
238 //   * `void pop_front()`
239 //   * `bool empty() const`
240 // * A type `TaggedNodeSeq`, representing a list of nodes to be scheduled, with
241 //   public members (having the same meanings as in an
242 //   `std::vector<TaggedNode>`):
243 //   * `size_t size() const`
244 //   * `bool empty() const`
245 //   * `void clear()`
246 //   * `const_iterator begin() const`
247 //   * `const_iterator end() const`
248 // * A public constructor, `PropagatorStateType(const ImmutableExecutorState&
249 //   immutable_state, int64 step_id, bool vlog)`.
250 // * The following public methods:
251 //   * `void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
252 //     TaggedNodeSeq* ready)`, which creates `TaggedNode` instances for the
253 //     nodes in `roots` and adds them to `*ready`
254 //   * `void PropagateOutputs(const TaggedNode& tagged_node, EntryVector*
255 //     outputs, TaggedNodeSeq* ready)`, which propagates `outputs` from the
256 //     given `tagged_node` to the destinations of its output edges, and adds
257 //     any newly runnable nodes to `*ready`
258 //   * `Entry* GetInputTensors(const TaggedNode& tagged_node) const`, which
259 //     returns a pointer to the input tensors for the given `tagged_node`
260 //   * `FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const`,
261 //     which creates a `FrameAndIter` for the given `tagged_node`
262 //   * `void DumpState()`, which dumps the dynamic state of the executing graph
263 //   * `void MaybeMarkStarted(const TaggedNode& tagged_node)`, which records
264 //     that a node has started
265 //   * `void MaybeMarkCompleted(const TaggedNode& tagged_node)`, which records
266 //     that a node has completed
267 //
268 // See `PropagatorState` in "./propagator_state.h" for an example of a type that
269 // can be used to instantiate `PropagatorStateType`.
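//
// For illustration only -- a hypothetical skeleton (not part of TensorFlow)
// that satisfies the contract above might look roughly like this:
//
//   class MyPropagatorState {
//    public:
//     struct TaggedNode {
//       const NodeItem& get_node_item() const;
//       bool get_is_dead() const;
//       int64 get_iter_num() const;  // also used by Process() for tracing
//     };
//     using TaggedNodeSeq = std::vector<TaggedNode>;
//     using TaggedNodeReadyQueue = std::deque<TaggedNode>;
//
//     MyPropagatorState(const ImmutableExecutorState& immutable_state,
//                       int64 step_id, bool vlog);
//     void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
//                        TaggedNodeSeq* ready);
//     void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs,
//                           TaggedNodeSeq* ready);
//     Entry* GetInputTensors(const TaggedNode& tagged_node) const;
//     FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const;
//     void DumpState();
//     void MaybeMarkStarted(const TaggedNode& tagged_node);
//     void MaybeMarkCompleted(const TaggedNode& tagged_node);
//   };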
270 template <class PropagatorStateType>
271 class ExecutorState {
272  public:
273   ExecutorState(const Executor::Args& args,
274                 const ImmutableExecutorState& immutable_state_,
275                 ExecutorImpl::KernelStats* kernel_stats_);
276   ~ExecutorState();
277 
278   void RunAsync(Executor::DoneCallback done);
279 
280  private:
281   // Use `TaggedNode` types defined by `PropagatorStateType`.
282   typedef typename PropagatorStateType::TaggedNode TaggedNode;
283   typedef
284       typename PropagatorStateType::TaggedNodeReadyQueue TaggedNodeReadyQueue;
285   typedef typename PropagatorStateType::TaggedNodeSeq TaggedNodeSeq;
286 
287   struct AsyncState;
288 
289   // Process a ready node in current thread.
290   void Process(TaggedNode node, int64_t scheduled_nsec);
291 
292   Status ProcessSync(const NodeItem& item, OpKernelContext::Params* params,
293                      EntryVector* outputs, NodeExecStatsInterface* stats);
294   void ProcessAsync(const NodeItem& item, const OpKernelContext::Params& params,
295                     const TaggedNode& tagged_node, Entry* first_input,
296                     NodeExecStatsInterface* stats);
297   void ProcessNoop(NodeExecStatsInterface* stats);
298   void ProcessConstTensor(const NodeItem& item, EntryVector* outputs,
299                           NodeExecStatsInterface* stats);
300 
301   // Before invoking item->kernel, fills in its "inputs".
302   Status PrepareInputs(const NodeItem& item, Entry* first_input,
303                        TensorValueVec* inputs,
304                        AllocatorAttributeVec* input_alloc_attrs,
305                        bool* is_input_dead);
306 
307   // After item->kernel computation is done, processes its outputs.
308   Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
309                         Entry* outputs, NodeExecStatsInterface* stats);
310 
311   // Called after each node finishes. Takes ownership of "stats". Returns true
312   // if execution has completed.
313   //
314   // This method will clear `*ready` before returning.
315   bool NodeDone(const Status& s, TaggedNodeSeq* ready,
316                 NodeExecStatsInterface* stats,
317                 TaggedNodeReadyQueue* inline_ready);
318 
319   // Schedule all the expensive nodes in '*ready', and put all the inexpensive
320   // nodes in 'ready' into 'inline_ready'.
321   //
322   // This method will clear `*ready` before returning.
323   //
324   // REQUIRES: `!ready->empty()`.
325   void ScheduleReady(TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready);
326 
327   // A wrapper for runner_ to keep track of the pending queue length. Op
328   // execution should dispatch work using this function instead of using runner_
329   // directly.
330   template <typename Closure>
331   void RunTask(Closure&& c);
332 
333   // Clean up when this executor is done.
334   void Finish();
335   void ScheduleFinish();
336 
337   // Contains the device context assigned by the device at the beginning of a
338   // step.
339   DeviceContext* device_context_ = nullptr;
340 
341   const bool vlog_;  // true if VLOG_IS_ON(1). Used to check vlog cheaply.
342 
343   // true if LogMemory::IsEnabled(). Used to check memory enabled cheaply.
344   const bool log_memory_;
345 
346   int64 step_id_;
347   int64 start_time_usecs_ = 0;
348 
349   // Not owned.
350   RendezvousInterface* rendezvous_;
351   CollectiveExecutor* collective_executor_ = nullptr;
352   SessionState* session_state_;
353   string session_handle_;
354   const SessionMetadata* session_metadata_ = nullptr;
355   TensorStore* tensor_store_;
356   // Step-local container.
357   ScopedStepContainer* step_container_;
358   StepStatsCollectorInterface* const stats_collector_;
359   const tracing::EventCollector* const event_collector_;
360   Context context_;
361 
362   // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
363   // instead of a pointer?  (avoids having to delete).
364   checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
365   CallFrameInterface* call_frame_;
366   const ImmutableExecutorState& immutable_state_;
367   ExecutorImpl::KernelStats* const kernel_stats_;
368   CancellationManager* cancellation_manager_;
369   CoordinationServiceAgent* coordination_service_agent_;
370   // If not null, use this device to schedule intra-op operations.
371   std::unique_ptr<DeviceBase> user_device_;
372   Executor::Args::Runner runner_;
373   bool sync_on_finish_;
374   const bool run_all_kernels_inline_;
375 
376   PropagatorStateType propagator_;
377 
378   // Invoked when the execution finishes.
379   Executor::DoneCallback done_cb_;
380 
381   std::atomic_int_fast32_t num_outstanding_ops_;
382 
383   // Available via OpKernelContext to every OpKernel invocation.
384   mutex num_deferred_ops_mu_;
385   int64 num_deferred_ops_ TF_GUARDED_BY(num_deferred_ops_mu_) = 0;
386   bool finish_when_deferred_ops_done_ TF_GUARDED_BY(num_deferred_ops_mu_) =
387       false;
388 
389   mutex mu_;
390   Status status_ TF_GUARDED_BY(mu_);
391 };
392 
393 template <class PropagatorStateType>
394 ExecutorState<PropagatorStateType>::ExecutorState(
395     const Executor::Args& args, const ImmutableExecutorState& immutable_state,
396     ExecutorImpl::KernelStats* kernel_stats)
397     : vlog_(VLOG_IS_ON(1)),
398       log_memory_(LogMemory::IsEnabled()),
399       step_id_(args.step_id),
400       start_time_usecs_(args.start_time_usecs),
401       rendezvous_(args.rendezvous),
402       collective_executor_(args.collective_executor),
403       session_state_(args.session_state),
404       session_handle_(args.session_handle),
405       session_metadata_(immutable_state.params().session_metadata),
406       tensor_store_(args.tensor_store),
407       step_container_(args.step_container),
408       stats_collector_(args.stats_collector),
409       event_collector_(
410           tracing::GetEventCollector(tracing::EventCategory::kCompute)),
411       context_(ContextKind::kThread),
412       slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
413       call_frame_(args.call_frame),
414       immutable_state_(immutable_state),
415       kernel_stats_(kernel_stats),
416       cancellation_manager_(args.cancellation_manager),
417       coordination_service_agent_(args.coordination_service_agent),
418       runner_(args.runner),
419       sync_on_finish_(args.sync_on_finish),
420       run_all_kernels_inline_(args.run_all_kernels_inline),
421       propagator_(immutable_state, step_id_, vlog_),
422       num_outstanding_ops_(0) {
423   if (args.user_intra_op_threadpool != nullptr) {
424     Device* device = immutable_state_.params().device;
425     user_device_ = RenamedDevice::NewRenamedDevice(
426         device->name(), device, false, false, args.user_intra_op_threadpool);
427   }
428 }
429 
430 template <class PropagatorStateType>
431 ExecutorState<PropagatorStateType>::~ExecutorState() {
432   if (device_context_) {
433     device_context_->Unref();
434   }
435   delete slice_reader_cache_;
436 }
437 
438 template <class PropagatorStateType>
439 template <typename Closure>
440 void ExecutorState<PropagatorStateType>::RunTask(Closure&& c) {
441   // Align the atomic variables at 64 bytes to avoid false-sharing, assuming the
442   // cacheline size is 64 bytes or smaller.
443   alignas(64) static std::atomic<int64_t> num_enqueue_ops{0};
444   alignas(64) static std::atomic<int64_t> num_dequeue_ops{0};
445 
446   auto n_enqueues = num_enqueue_ops.fetch_add(1, std::memory_order_relaxed);
447   // Sample the queue length on every 16 enqueue operations. This amortizes the
448   // cost of metric updates across 16 operations.
449   if (n_enqueues % 16 == 0) {
450     auto n_dequeues = num_dequeue_ops.load(std::memory_order_relaxed);
451     metrics::UpdateGraphPendingQueueLength(n_enqueues - n_dequeues);
452   }
453 
454   // mutable is needed because std::forward<Closure> in the lambda body may move
455   // the Closure `c`.
456   runner_([c = std::forward<Closure>(c)]() mutable {
457     num_dequeue_ops.fetch_add(1, std::memory_order_relaxed);
458     std::forward<Closure>(c)();
459   });
460 }
461 
462 template <class PropagatorStateType>
463 void ExecutorState<PropagatorStateType>::RunAsync(Executor::DoneCallback done) {
464   TaggedNodeSeq ready;
465 
466   // Ask the device to fill in the device context map.
467   Device* device = immutable_state_.params().device;
468   const Status get_context_status =
469       device->TryGetDeviceContext(&device_context_);
470   if (!get_context_status.ok()) {
471     delete this;
472     done(get_context_status);
473     return;
474   }
475 
476   // Initialize the ready queue.
477   ready.reserve(immutable_state_.root_nodes().size());
478   propagator_.ActivateRoots(immutable_state_.root_nodes(), &ready);
479   num_outstanding_ops_ = ready.size();
480   if (ready.empty()) {
481     delete this;
482     done(Status::OK());
483   } else {
484     done_cb_ = std::move(done);
485     // Schedule to run all the ready ops in thread pool.
486     ScheduleReady(&ready, nullptr);
487   }
488 }
489 
490 // State kept alive for executing an asynchronous node in another
491 // thread.  NOTE: We need to make a copy of p.inputs and p.input_alloc_attrs for
492 // asynchronous kernels because OpKernelContext methods like input_type(i) need
493 // the params to point to a valid input type vector. This is not an issue for
494 // sync kernels because these vectors are kept on the stack.
495 template <class PropagatorStateType>
496 struct ExecutorState<PropagatorStateType>::AsyncState {
497   AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node,
498              const NodeItem* _item, Entry* _first_input,
499              NodeExecStatsInterface* _stats)
500       : saved_inputs(*p.inputs),
501         saved_input_alloc_attrs(*p.input_alloc_attrs),
502         params(p),
503         tagged_node(_tagged_node),
504         item(_item),
505         first_input(_first_input),
506         // ParamsButClearingEigenGPUDevice does equivalent of
507         //   params.eigen_gpu_device = nullptr;
508         ctx(ParamsButClearingEigenGPUDevice(&params), item->num_outputs),
509         stats(_stats) {
510     params.inputs = &saved_inputs;
511     params.input_alloc_attrs = &saved_input_alloc_attrs;
512   }
513 
514   TensorValueVec saved_inputs;
515   AllocatorAttributeVec saved_input_alloc_attrs;
516   OpKernelContext::Params params;
517   TaggedNode tagged_node;
518   const NodeItem* item;
519   Entry* first_input;
520   OpKernelContext ctx;
521   NodeExecStatsInterface* stats;
522 
523  private:
524   OpKernelContext::Params* ParamsButClearingEigenGPUDevice(
525       OpKernelContext::Params* p) {
526     // Ensure OpKernelContext constructor will make a new eigen GPU device if
527     // necessary.
528     p->eigen_gpu_device = nullptr;  // Force allocation
529     return p;
530   }
531 };
532 
533 // Returns true if `item` might be traced by the given trace and event
534 // collectors. Returns false only if `item` definitely will not be traced.
535 bool MightTrace(const tracing::EventCollector* event_collector,
536                 bool is_expensive) {
537   // Tracing will only be enabled if either `event_collector` is non null,
538   // or `trace_collector` is non-null and enabled for this particular kernel.
539   // Although `profiler::TraceMe`, `profiler::ScopedAnnotation`, and
540   // `tracing::ScopedRegion` check subsets of these properties internally in
541   // their constructors, the cost of passing the necessary arguments to them can
542   // be significant, so we avoid constructing them in the common case (when we
543   // know they will not be used).
544   if (event_collector != nullptr) {
545     return true;
546   }
547 
548   if (profiler::ScopedAnnotation::IsEnabled()) return true;
549 
550   return profiler::TraceMe::Active(profiler::GetTFTraceMeLevel(is_expensive));
551 }
552 
553 template <class PropagatorStateType>
554 Status ExecutorState<PropagatorStateType>::ProcessSync(
555     const NodeItem& item, OpKernelContext::Params* params, EntryVector* outputs,
556     NodeExecStatsInterface* stats) {
557   Status s;
558   OpKernelContext ctx(params, item.num_outputs);
559   nodestats::SetOpStart(stats);
560 
561   OpKernel* op_kernel = item.kernel;
562   Device* device = immutable_state_.params().device;
563   const bool is_expensive = kernel_stats_->IsExpensive(item);
564 
565   if (TF_PREDICT_FALSE(MightTrace(event_collector_, is_expensive))) {
566     tracing::ScopedRegion region(tracing::EventCategory::kCompute,
567                                  op_kernel->name_view());
568     profiler::AnnotatedTraceMe activity(
569         [op_kernel, &ctx] {
570           return op_kernel->TraceString(
571               ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
572         },
573         profiler::GetTFTraceMeLevel(is_expensive));
574     device->Compute(op_kernel, &ctx);
575   } else if (kernel_stats_->HasExpensiveMarker(item)) {
576     KernelTimer timer;
577     device->Compute(op_kernel, &ctx);
578     // For expensive kernels, always update the cost estimate. For inexpensive
579     // kernels, update the cost estimate with ~1/16 probability. This assumes
580     // that the last 4 bits of the CPU cycle count are uniformly distributed.
581     constexpr int kKernelExecutionTrackingInvocationSkipCount = 16;
582     if (is_expensive ||
583         timer.start_cycles % kKernelExecutionTrackingInvocationSkipCount == 0) {
584       kernel_stats_->UpdateCostEstimate(item, timer.ElapsedCycles());
585     }
586   } else {
587     device->Compute(op_kernel, &ctx);
588   }
589   nodestats::SetOpEnd(stats);
590   if (outputs->size() < item.num_outputs) outputs->resize(item.num_outputs);
591   s = ProcessOutputs(item, &ctx, outputs->data(), stats);
592   nodestats::SetMemory(stats, &ctx);
593   return s;
594 }
595 
596 template <class PropagatorStateType>
597 void ExecutorState<PropagatorStateType>::ProcessAsync(
598     const NodeItem& item, const OpKernelContext::Params& params,
599     const TaggedNode& tagged_node, Entry* first_input,
600     NodeExecStatsInterface* stats) {
601   AsyncOpKernel* async_kernel = item.kernel->AsAsync();
602   DCHECK(async_kernel != nullptr);
603   AsyncState* state =
604       new AsyncState(params, tagged_node, &item, first_input, stats);
605 
606   auto done = [this, state]() {
607     Device* device = immutable_state_.params().device;
608     NodeExecStatsInterface* stats = state->stats;  // Shorthand
609     Entry* first_input = state->first_input;       // Shorthand
610 
611     nodestats::SetOpEnd(stats);
612     EntryVector outputs(state->item->num_outputs);
613     Status s = ProcessOutputs(*state->item, &state->ctx, outputs.data(), stats);
614     nodestats::SetMemory(stats, &state->ctx);
615     if (vlog_) {
616       VLOG(2) << "Async kernel done: " << state->item->node_id << " step "
617               << step_id_ << " " << SummarizeNodeDef(state->item->kernel->def())
618               << (state->tagged_node.get_is_dead() ? " is dead" : "")
619               << " device: " << device->name();
620     }
621 
622     // Clears inputs.
623     const int num_inputs = state->item->num_inputs;
624     for (int i = 0; i < num_inputs; ++i) {
625       (first_input + i)->ClearVal();
626     }
627     propagator_.MaybeMarkCompleted(state->tagged_node);
628     TaggedNodeSeq ready;
629     if (s.ok()) {
630       propagator_.PropagateOutputs(state->tagged_node, &outputs, &ready);
631     }
632     outputs.clear();
633     const bool completed = NodeDone(s, &ready, stats, nullptr);
634     delete state;
635     if (completed) ScheduleFinish();
636   };
637   nodestats::SetOpStart(stats);
638   {
639     profiler::AnnotatedTraceMe activity(
640         [async_kernel, state] {
641           return async_kernel->TraceString(
642               state->ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
643         },
644         profiler::GetTFTraceMeLevel(kernel_stats_->IsExpensive(item)));
645     immutable_state_.params().device->ComputeAsync(async_kernel, &state->ctx,
646                                                    std::move(done));
647   }
648 }
649 
650 template <class PropagatorStateType>
651 void ExecutorState<PropagatorStateType>::ProcessNoop(
652     NodeExecStatsInterface* stats) {
653   nodestats::SetOpStart(stats);
654   nodestats::SetOpEnd(stats);
655 }
656 
657 template <class PropagatorStateType>
658 void ExecutorState<PropagatorStateType>::ProcessConstTensor(
659     const NodeItem& item, EntryVector* outputs, NodeExecStatsInterface* stats) {
660   nodestats::SetOpStart(stats);
661   nodestats::SetOpEnd(stats);
662   Entry& output = (*outputs)[0];
663   output.state = Entry::State::HAS_CONST_TENSOR;
664   output.const_tensor = item.const_tensor;
665   output.alloc_attr = item.output_attrs()[0];
666 }
667 
668 template <class PropagatorStateType>
669 void ExecutorState<PropagatorStateType>::Process(TaggedNode tagged_node,
670                                                  int64_t scheduled_nsec) {
671   profiler::TraceMeConsumer activity(
672       // From TraceMeProducer in DirectSession::RunInternal,
673       // GraphMgr::ExecuteAsync, or FunctionLibraryRuntime::Run.
674       [&] {
675         // NOTE: This tracing uses the iteration number from the first tagged
676         // node that executes during this call to `Process()`. In principle,
677         // subsequent nodes could have different values of `iter_num` that
678         // will not be traced.
679         return profiler::TraceMeEncode(
680             "ExecutorState::Process",
681             {{"id", step_id_}, {"iter_num", tagged_node.get_iter_num()}});
682       },
683       profiler::ContextType::kTfExecutor, step_id_,
684       profiler::TraceMeLevel::kInfo);
685   WithContext wc(context_);
686   TaggedNodeSeq ready;
687   TaggedNodeReadyQueue inline_ready;
688 
689   // Parameters passed to OpKernel::Compute.
690   TensorValueVec inputs;
691   AllocatorAttributeVec input_alloc_attrs;
692 
693   OpKernelContext::Params params;
694   params.step_id = step_id_;
695   // Override device's threadpool if user provides an intra_op_threadpool
696   Device* device = immutable_state_.params().device;
697   if (user_device_) {
698     params.device = user_device_.get();
699   } else {
700     params.device = device;
701   }
702   params.start_time_usecs = start_time_usecs_;
703   params.log_memory = log_memory_;
704   params.rendezvous = rendezvous_;
705   params.collective_executor = collective_executor_;
706   params.session_state = session_state_;
707   params.session_handle = session_handle_;
708   params.session_metadata = session_metadata_;
709   params.tensor_store = tensor_store_;
710   params.cancellation_manager = cancellation_manager_;
711   params.coordination_service_agent = coordination_service_agent_;
712   params.call_frame = call_frame_;
713   params.function_library = immutable_state_.params().function_library;
714   params.resource_manager = device->resource_manager();
715   params.step_container = step_container_;
716   params.slice_reader_cache = slice_reader_cache_;
717   params.inputs = &inputs;
718   params.input_alloc_attrs = &input_alloc_attrs;
719   params.runner = &runner_;
720   params.run_all_kernels_inline = run_all_kernels_inline_;
721   params.stats_collector = stats_collector_;
722   params.inc_num_deferred_ops_function = [this]() {
723     mutex_lock lock(num_deferred_ops_mu_);
724     num_deferred_ops_++;
725   };
726   params.dec_num_deferred_ops_function = [this]() {
727     bool finish_when_deferred_ops_done = false;
728     {
729       mutex_lock lock(num_deferred_ops_mu_);
730       num_deferred_ops_--;
731       if (num_deferred_ops_ == 0) {
732         finish_when_deferred_ops_done = finish_when_deferred_ops_done_;
733       }
734     }
735     // Invoke Finish if the graph processing has completed. Finish is always
736     // called exactly once per ExecutorState, either here if there are any
737     // deferred ops, or in ScheduleFinish if there aren't any deferred ops.
738     if (finish_when_deferred_ops_done) Finish();
739   };
740 
741   // Set the device_context for this device, if it exists.
742   params.op_device_context = device_context_;
743 
744   Status s;
745   NodeExecStatsInterface* stats = nullptr;
746 
747   EntryVector outputs(1);
748 
749   bool completed = false;
750   inline_ready.push_back(tagged_node);
751   while (!inline_ready.empty()) {
752     tagged_node = inline_ready.front();
753     inline_ready.pop_front();
754     const NodeItem& item = tagged_node.get_node_item();
755     const int id = item.node_id;
756 
757     propagator_.MaybeMarkStarted(tagged_node);
758 
759     params.track_allocations = false;
760     stats = nullptr;
761     if (stats_collector_ && !tagged_node.get_is_dead()) {
762       stats = stats_collector_->CreateNodeExecStats(&item.kernel->def());
763       // Track allocations if and only if we are collecting statistics, and
764       // the `stats` object is expecting allocations to be tracked.
765       params.track_allocations = stats ? stats->TrackAllocations() : false;
766       nodestats::SetScheduled(stats, scheduled_nsec);
767       nodestats::SetAllStart(stats);
768     }
769 
770     if (vlog_) {
771       VLOG(1) << "Process node: " << id << " step " << params.step_id << " "
772               << SummarizeNodeDef(item.kernel->def())
773               << (tagged_node.get_is_dead() ? " is dead" : "")
774               << " device: " << device->name();
775     }
776 
777     Entry* first_input = propagator_.GetInputTensors(tagged_node);
778 
779     // Only execute this node if it is not dead or it is a send/recv
780     // transfer node. For transfer nodes, we need to propagate the "dead"
781     // bit even when the node is dead.
782     bool launched_asynchronously = false;
783     if (tagged_node.get_is_dead() && !item.is_transfer_node) {
784       if (outputs.size() < item.num_outputs) outputs.resize(item.num_outputs);
785     } else if (TF_PREDICT_FALSE(item.is_noop)) {
786       ProcessNoop(stats);
787     } else if (item.const_tensor != nullptr && !params.track_allocations) {
788       ProcessConstTensor(item, &outputs, stats);
789     } else {
790       // Prepares inputs.
791       bool is_input_dead = false;
792       s = PrepareInputs(item, first_input, &inputs, &input_alloc_attrs,
793                         &is_input_dead);
794       if (!s.ok()) {
795         // Clear inputs.
796         const int num_inputs = item.num_inputs;
797         for (int i = 0; i < num_inputs; ++i) {
798           (first_input + i)->ClearVal();
799         }
800         propagator_.MaybeMarkCompleted(tagged_node);
801         // Continue to process the nodes in 'inline_ready'.
802         completed = NodeDone(s, &ready, stats, &inline_ready);
803         continue;
804       }
805 
806       // Set up compute params.
807       params.op_kernel = item.kernel;
808       params.frame_iter = propagator_.GetFrameAndIter(tagged_node);
809       params.is_input_dead = is_input_dead;
810       params.output_attr_array = item.output_attrs();
811       params.forward_from_array = item.forward_from();
812       params.outputs_required_array = item.outputs_required.get();
813 
814       if (item.kernel_is_async) {
815         ProcessAsync(item, params, tagged_node, first_input, stats);
816         launched_asynchronously = true;
817       } else {
818         s = ProcessSync(item, &params, &outputs, stats);
819       }
820     }
821 
822     if (!launched_asynchronously) {
823       if (vlog_) {
824         VLOG(2) << "Synchronous kernel done: " << id << " step "
825                 << params.step_id << " " << SummarizeNodeDef(item.kernel->def())
826                 << (tagged_node.get_is_dead() ? " is dead: " : "")
827                 << " device: " << device->name();
828       }
829 
830       // Clears inputs.
831       const int num_inputs = item.num_inputs;
832       for (int i = 0; i < num_inputs; ++i) {
833         (first_input + i)->ClearVal();
834       }
835       propagator_.MaybeMarkCompleted(tagged_node);
836       // Propagates outputs.
837       if (s.ok()) {
838         propagator_.PropagateOutputs(tagged_node, &outputs, &ready);
839       }
840 
841       // Clear outputs without deallocating the `outputs` vector.
842       const int num_outputs = item.num_outputs;
843       for (int i = 0; i < num_outputs; ++i) {
844         outputs[i].ClearVal();
845       }
846 
847       if (stats) {
848         scheduled_nsec = nodestats::NowInNsec();
849       }
850       // Postprocess.
851       completed = NodeDone(s, &ready, stats, &inline_ready);
852     }
853   }  // while !inline_ready.empty()
854 
855   // This thread of computation is done if completed = true.
856   if (completed) ScheduleFinish();
857 }
858 
859 template <class PropagatorStateType>
860 Status ExecutorState<PropagatorStateType>::PrepareInputs(
861     const NodeItem& item, Entry* first_input, TensorValueVec* inputs,
862     AllocatorAttributeVec* input_alloc_attrs, bool* is_input_dead) {
863   inputs->resize(item.num_inputs);
864   input_alloc_attrs->resize(item.num_inputs);
865 
866   *is_input_dead = false;
867 
868   for (int i = 0; i < item.num_inputs; ++i) {
869     const bool expect_ref = TF_PREDICT_FALSE(item.is_any_input_ref_typed) &&
870                             IsRefType(item.input_type(i));
871     Entry* entry = first_input + i;
872     (*input_alloc_attrs)[i] = entry->alloc_attr;
873 
874     // i-th input.
875     TensorValue* inp = &(*inputs)[i];
876 
877     switch (entry->state) {
878       case Entry::State::NO_VALUE: {
879         // Only merge and transfer nodes can have no-value inputs.
880         inp->mutex_if_ref = nullptr;
881         if (item.is_merge) {
882           inp->tensor = nullptr;
883         } else {
884           DCHECK(item.is_transfer_node)
885               << item.kernel->name() << " - input " << i;
886           entry->state = Entry::State::HAS_CONST_TENSOR;
887           entry->const_tensor = kEmptyTensor;
888           // NOTE(mrry): This `const_cast` is necessary because `TensorValue`
889           // stores a non-const `Tensor*`, and relies on the `OpKernelContext`
890           // accessors making dynamic checks that prevent using an immutable
891           // tensor as a mutable tensor.
892           inp->tensor = const_cast<Tensor*>(kEmptyTensor);
893           *is_input_dead = true;
894         }
895         break;
896       }
897 
898       case Entry::State::HAS_VALUE: {
899         if (TF_PREDICT_FALSE(expect_ref)) {
900           return AttachDef(
901               errors::InvalidArgument(i, "-th input expects a ref type"),
902               item.kernel->def());
903         }
904         inp->mutex_if_ref = nullptr;
905         inp->tensor = entry->val.get();
906         break;
907       }
908 
909       case Entry::State::HAS_CONST_TENSOR: {
910         if (TF_PREDICT_FALSE(expect_ref)) {
911           return AttachDef(
912               errors::InvalidArgument(i, "-th input expects a ref type"),
913               item.kernel->def());
914         }
915         // NOTE(mrry): This `const_cast` is necessary because `TensorValue`
916         // stores a non-const `Tensor*`, and relies on the `OpKernelContext`
917         // accessors making dynamic checks that prevent using an immutable
918         // tensor as a mutable tensor.
919         inp->mutex_if_ref = nullptr;
920         inp->tensor = const_cast<Tensor*>(entry->const_tensor);
921         break;
922       }
923 
924       case Entry::State::HAS_REF_TENSOR: {
925         {
926           tf_shared_lock ml(*entry->ref_tensor.mu);
927           if (TF_PREDICT_FALSE(!entry->ref_tensor.tensor->IsInitialized() &&
928                                !item.is_initialization_op)) {
929             return AttachDef(errors::FailedPrecondition(
930                                  "Attempting to use uninitialized value ",
931                                  item.kernel->requested_input(i)),
932                              item.kernel->def());
933           }
934         }
935 
936         if (expect_ref) {
937           inp->mutex_if_ref = entry->ref_tensor.mu;
938           inp->tensor = entry->ref_tensor.tensor;
939         } else {
940           // Automatically deref the tensor ref when the op expects a
941           // tensor but is given a ref to a tensor.  Need to deref it
942           // under the mutex.
943           {
944             mutex* ref_mu = entry->ref_tensor.mu;
945             Tensor* ref_tensor = entry->ref_tensor.tensor;
946             tf_shared_lock l(*ref_mu);
947             entry->val.Init(*ref_tensor);
948           }
949           entry->state = Entry::State::HAS_VALUE;
950 
951           inp->mutex_if_ref = nullptr;
952           inp->tensor = entry->val.get();
953           // The dtype of entry->ref_tensor.tensor could have been changed by
954           // another operation that ran after the operation that "produced" it
955           // executed, so re-validate that the type of the dereferenced tensor
956           // matches the expected input type.
957           if (TF_PREDICT_FALSE(item.input_type(i) != inp->tensor->dtype())) {
958             return AttachDef(
959                 errors::InvalidArgument(
960                     i, "-th input expects type ",
961                     DataTypeString(item.input_type(i)),
962                     " but automatically dereferenced input tensor has type ",
963                     DataTypeString(inp->tensor->dtype())),
964                 item.kernel->def());
965           }
966         }
967         break;
968       }
969     }
970   }
971   return Status::OK();
972 }
973 
974 template <class PropagatorStateType>
975 Status ExecutorState<PropagatorStateType>::ProcessOutputs(
976     const NodeItem& item, OpKernelContext* ctx, Entry* outputs,
977     NodeExecStatsInterface* stats) {
978   Status s = ctx->status();
979   if (!s.ok()) {
980     s = AttachDef(s, item.kernel->def());
981     // TODO(misard) Replace with a finer-grain enabling flag once we
982     // add better optional debugging support.
983     if (vlog_ && VLOG_IS_ON(1)) {
984       LOG(WARNING) << this << " Compute status: " << s;
985     }
986     if (s.code() == error::RESOURCE_EXHAUSTED) {
987       if (stats_collector_) {
988         string err = stats_collector_->ReportAllocsOnResourceExhausted(
989             s.error_message());
990         s = Status(s.code(), strings::StrCat(s.error_message(), err));
991       } else {
992         s = Status(
993             s.code(),
994             strings::StrCat(
995                 s.error_message(),
996                 "\nHint: If you want to see a list of allocated tensors when "
997                 "OOM happens, add report_tensor_allocations_upon_oom "
998                 "to RunOptions for current allocation info. This isn't "
999                 "available when running in Eager mode.\n"));
1000       }
1001     } else if (s.code() == error::UNAVAILABLE &&
1002                !item.is_distributed_communication) {
1003       s = errors::ReplaceErrorFromNonCommunicationOps(s, item.kernel->name());
1004     }
1005     return s;
1006   }
1007 
1008   for (int i = 0; i < item.num_outputs; ++i) {
1009     const TensorValue val = ctx->release_output(i);
1010     Entry* out = &outputs[i];
1011     DCHECK(out->state == Entry::State::NO_VALUE);
1012 
1013     if (val.tensor == nullptr) {
1014       // Unless it's a Switch or a Recv, or the executor has marked the output
1015       // as not required, the node must produce a tensor value at the i-th output.
1016       if (!(item.is_recv_or_switch ||
1017             (item.outputs_required && !item.outputs_required[i]))) {
1018         s.Update(errors::Internal("Missing ", i, "-th output from ",
1019                                   FormatNodeDefForError(item.kernel->def())));
1020       }
1021     } else {
1022       // Set the allocator attributes of the output entry.
1023       out->alloc_attr = ctx->output_alloc_attr(i);
1024 
1025       // Sanity check of output tensor types. We need to inspect this safely as
1026       // we are in the tensor buffer.
1027       DataType dtype = val.dtype_safe();
1028       if (dtype == item.output_type(i)) {
1029         if (stats && val.tensor->IsInitialized()) {
1030           nodestats::SetOutput(stats, i, val.tensor);
1031         }
1032         if (val.is_ref()) {
1033           out->state = Entry::State::HAS_REF_TENSOR;
1034           out->ref_tensor.tensor = val.tensor;
1035           out->ref_tensor.mu = val.mutex_if_ref;
1036           if (log_memory_) {
1037             Tensor to_log;
1038             {
1039               // Dereference the tensor under the lock.
1040               tf_shared_lock l(*out->ref_tensor.mu);
1041               to_log = *out->ref_tensor.tensor;
1042             }
1043             LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
1044                                           ctx->step_id(), i, to_log);
1045           }
1046         } else {
1047           // NOTE that std::move is used here, so val.tensor goes to an
1048           // uninitialized state (val.tensor->IsInitialized() returns false).
1049           out->state = Entry::State::HAS_VALUE;
1050           out->val.Init(std::move(*val.tensor));
1051           if (log_memory_) {
1052             LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
1053                                           ctx->step_id(), i, *out->val);
1054           }
1055         }
1056       } else {
1057         s.Update(
1058             errors::Internal("Output ", i, " of type ", DataTypeString(dtype),
1059                              " does not match declared output type ",
1060                              DataTypeString(item.output_type(i)), " for node ",
1061                              FormatNodeDefForError(item.kernel->def())));
1062       }
1063     }
1064     if (!val.is_ref()) {
1065       // If OpKernelContext returns outputs via pass-by-value, we
1066       // don't need this trouble.
1067       delete val.tensor;
1068     }
1069   }
1070   return s;
1071 }
1072 
1073 template <class PropagatorStateType>
1074 bool ExecutorState<PropagatorStateType>::NodeDone(
1075     const Status& s, TaggedNodeSeq* ready, NodeExecStatsInterface* stats,
1076     TaggedNodeReadyQueue* inline_ready) {
1077   if (stats) {
1078     nodestats::SetAllEnd(stats);
1079     DCHECK_NE(stats_collector_, nullptr);
1080     stats->Done(immutable_state_.params().device->name());
1081   }
1082 
1083   if (TF_PREDICT_TRUE(s.ok())) {
1084     const size_t ready_size = ready->size();
1085     if (ready_size == 0) {
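      // fetch_sub returns the value held *before* the decrement, so a result
      // of 1 means this was the last outstanding op and execution is complete.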
1086       return num_outstanding_ops_.fetch_sub(1) == 1;
1087     } else {
1088       // NOTE: Avoid touching the atomic counter if only one node becomes ready.
1089       if (ready_size > 1) {
1090         num_outstanding_ops_.fetch_add(ready_size - 1,
1091                                        std::memory_order_relaxed);
1092       }
1093 
1094       // Schedule the ready nodes in 'ready'.
1095       ScheduleReady(ready, inline_ready);
1096 
1097       return false;
1098     }
1099   } else {
1100     bool abort_run = false;
1101 
1102     // Some error happened. This thread of computation is done.
1103     {
1104       mutex_lock l(mu_);
1105       if (status_.ok()) {
1106         // If this is the first node to fail in this run, we are responsible for
1107         // aborting all other execution in the step.
1108         abort_run = true;
1109 
1110         // If execution has been cancelled, mark cancelled or aborted errors as
1111         // being derived. Note that the original node that fails might also
1112         // trigger cancellation, and here we make sure the original error is
1113         // exposed to users and not buried as a derived error.
1114         if (cancellation_manager_ && cancellation_manager_->IsCancelled() &&
1115             (errors::IsCancelled(s) || errors::IsAborted(s))) {
1116           status_ = StatusGroup::MakeDerived(s);
1117         } else {
1118           status_ = s;
1119         }
1120       }
1121     }
1122 
1123     if (abort_run) {
1124       TRACEPRINTF("StartAbort: %s", s.ToString().c_str());
1125       if (cancellation_manager_) {
1126         // Only log when the abort happens during the actual run time.
1127         // Use VLOG instead of LOG(warning) because error status is expected
1128         // when the executor is run under the grappler optimization phase or
1129         // when iterating through a tf.data input pipeline.
1130         VLOG(1) << "[" << immutable_state_.params().device->name()
1131                 << "] Executor start aborting: " << s;
1132       }
1133 
1134       if (rendezvous_) {
1135         rendezvous_->StartAbort(s);
1136       }
1137       if (cancellation_manager_) {
1138         cancellation_manager_->StartCancel();
1139       } else if (collective_executor_) {
1140         // If there's a cancellation_manager_, collective ops abort
1141         // collective_executor_ upon cancellation; otherwise we need to abort
1142         // here.
1143         collective_executor_->StartAbort(s);
1144       }
1145     }
1146 
1147     return num_outstanding_ops_.fetch_sub(1) == 1;
1148   }
1149 }
1150 
1151 template <class PropagatorStateType>
1152 void ExecutorState<PropagatorStateType>::ScheduleReady(
1153     TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready) {
1154   DCHECK(!ready->empty());
1155 
1156   int64_t scheduled_nsec = 0;
1157   if (stats_collector_) {
1158     scheduled_nsec = nodestats::NowInNsec();
1159   }
1160 
1161   if (run_all_kernels_inline_) {
1162     if (inline_ready == nullptr) {
1163       // Schedule all ready kernels from a single closure. This ensures that,
1164       // regardless of the `runner_` implementation, all kernels will run
1165       // sequentially on the same thread, and thread wakeup overhead and
1166       // executor mutex contention will be minimized.
1167       RunTask([this, ready = std::move(*ready), scheduled_nsec]() {
1168         for (auto& tagged_node : ready) {
1169           Process(tagged_node, scheduled_nsec);
1170         }
1171       });
1172     } else {
1173       for (auto& tagged_node : *ready) {
1174         inline_ready->push_back(tagged_node);
1175       }
1176     }
1177   } else {
1178     const TaggedNode* curr_expensive_node = nullptr;
1179     if (inline_ready == nullptr) {
1180       // Schedule to run all the ready ops in thread pool.
1181       for (auto& tagged_node : *ready) {
1182         RunTask([=]() { Process(tagged_node, scheduled_nsec); });
1183       }
1184     } else {
1185       for (auto& tagged_node : *ready) {
1186         const NodeItem& item = *tagged_node.node_item;
1187         if (tagged_node.get_is_dead() || !kernel_stats_->IsExpensive(item)) {
1188           // Inline this inexpensive node.
1189           inline_ready->push_back(tagged_node);
1190         } else {
1191           if (curr_expensive_node) {
1192             // Dispatch to another thread since there is plenty of work to
1193             // do for this thread.
1194             RunTask(std::bind(&ExecutorState::Process, this,
1195                               *curr_expensive_node, scheduled_nsec));
1196           }
1197           curr_expensive_node = &tagged_node;
1198         }
1199       }
1200     }
1201     if (curr_expensive_node) {
1202       if (inline_ready->empty()) {
1203         inline_ready->push_back(*curr_expensive_node);
1204       } else {
1205         // There are inline nodes to run already. We dispatch this expensive
1206         // node to other thread.
1207         RunTask(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
1208                           scheduled_nsec));
1209       }
1210     }
1211   }
1212   ready->clear();
1213 }
1214 
1215 template <class PropagatorStateType>
1216 void ExecutorState<PropagatorStateType>::ScheduleFinish() {
1217   // Checks the condition to decide whether to invoke Finish(). If there are
1218   // in-flight deferred ops, wait for `num_deferred_ops_` to reach 0 before
1219   // invoking Finish(). Otherwise, invoke Finish() directly.
1220   // Note that it is critical that the ScheduleFinish / Finish codepath does not
1221   // block, otherwise we might deadlock.  See b/124523000 for details.
1222   {
1223     mutex_lock lock(num_deferred_ops_mu_);
1224     if (num_deferred_ops_ > 0) {
1225       finish_when_deferred_ops_done_ = true;
1226       return;
1227     }
1228   }
1229   // Finish is always called exactly once per ExecutorState, either here if
1230   // there aren't any deferred ops, or in the dec_num_deferred_ops_function if
1231   // there are deferred ops.
1232   Finish();
1233 }
1234 
1235 template <class PropagatorStateType>
1236 void ExecutorState<PropagatorStateType>::Finish() {
1237   mu_.lock();
1238   auto status = status_;
1239   auto done_cb = std::move(done_cb_);
1240   auto runner = std::move(runner_);
1241   mu_.unlock();
1242   int64_t step_id = step_id_;
1243   CHECK(done_cb != nullptr);
1244   Device* device = immutable_state_.params().device;
1245 
1246   if (vlog_ && !status.ok() && VLOG_IS_ON(1)) {
1247     // Logs verbose information about the current state of active and pending
1248     // nodes in the propagator.
1249     propagator_.DumpState();
1250   }
1251 
1252   // There are several potential race conditions below. To name a few:
1253   // 1. Even if the device's status is OK at the precise moment when
1254   // num_deferred_ops_ reaches 0, it could go bad before device->RefreshStatus()
1255   // is called below, caused by work enqueued onto the same device by other
1256   // concurrent ExecutorState objects.
1257   // 2. Some implementations of Device::RefreshStatus, such as
1258   // XlaDevice::RefreshStatus, may be inherently racy because it releases the
1259   // device mutex after a stream pointer is acquired and before the stream is
1260   // queried for status.
1261   // 3. It's the same for some implementations of Device::Sync, such as
1262   // XlaDevice::Sync.
1263   //
1264   // However, these race conditions are acceptable because a stream (and
1265   // therefore an XlaDevice) can only go from OK to not-OK, never the opposite,
1266   // which means we will at worst report errors when there isn't any, never the
1267   // opposite.
1268 
1269   // An early exit for devices that don't allow sync on completion. Ops that run
1270   // on these devices should have used num_deferred_ops correctly to ensure the
1271   // device has finished all relevant work at this point.
1272   if (!device->AllowsSyncOnCompletion()) {
1273     status.Update(device->RefreshStatus());
1274     if (!status.ok()) {
1275       // In device async execution mode, it's possible for device execution to
1276       // lag behind ExecutorState scheduling so much that this is the first
1277       // place a device execution error surfaces.
1278       // If so, all ExecutorState::NodeDone calls have already happened with OK
1279       // status. This is the last defense where StartCancel must be called to
1280       // abort all computation still running on any device.
1281       // TODO(b/124523000): Always call Finish in a separate thread, so even if
1282       // StartCancel blocks the current thread's execution, we won't encounter
1283       // deadlocks caused by inter-op thread exhaustion.
1284       if (rendezvous_) {
1285         rendezvous_->StartAbort(status);
1286       }
1287       if (cancellation_manager_) {
1288         cancellation_manager_->StartCancel();
1289       } else if (collective_executor_) {
1290         // If there's a cancellation_manager_, collective ops abort
1291         // collective_executor_ upon cancellation; otherwise we need to abort
1292         // here.
1293         collective_executor_->StartAbort(status);
1294       }
1295     }
1296     delete this;
1297     runner([step_id, status, done_cb = std::move(done_cb)]() {
1298       profiler::TraceMeConsumer activity(
1299           // From TraceMeProducer in KernelAndDeviceFunc::RunAsync,
1300           // DirectSession::RunInternal or GraphMgr::ExecuteAsync.
1301           [&] {
1302             return profiler::TraceMeEncode("ExecutorDoneCallback",
1303                                            {{"id", step_id}});
1304           },
1305           profiler::ContextType::kTfExecutor, step_id,
1306           profiler::TraceMeLevel::kInfo);
1307       done_cb(status);
1308     });
1309     return;
1310   }
1311 
1312   if (sync_on_finish_ && status.ok()) {
1313     // Block until the device has finished all queued operations. For
1314     // devices like GPUs that continue to execute Ops after their Compute
1315     // methods have completed, this ensures that control is not returned to
1316     // the user until the step (and its side-effects) has actually completed.
1317     device->Sync([this, step_id, runner = std::move(runner),
1318                   done_cb = std::move(done_cb)](const Status& status) mutable {
1319       delete this;
1320       runner([step_id, status, done_cb = std::move(done_cb)]() {
1321         profiler::TraceMeConsumer activity(
1322             // From TraceMeProducer in KernelAndDeviceFunc::RunAsync,
1323             // DirectSession::RunInternal or GraphMgr::ExecuteAsync.
1324             [&] {
1325               return profiler::TraceMeEncode("ExecutorDoneCallback",
1326                                              {{"id", step_id}});
1327             },
1328             profiler::ContextType::kTfExecutor, step_id,
1329             profiler::TraceMeLevel::kInfo);
1330         done_cb(status);
1331       });
1332     });
1333   } else {
1334     delete this;
1335     runner([step_id, status, done_cb = std::move(done_cb)]() {
1336       profiler::TraceMeConsumer activity(
1337           // From TraceMeProducer in KernelAndDeviceFunc::RunAsync,
1338           // DirectSession::RunInternal or GraphMgr::ExecuteAsync.
1339           [&] {
1340             return profiler::TraceMeEncode("ExecutorDoneCallback",
1341                                            {{"id", step_id}});
1342           },
1343           profiler::ContextType::kTfExecutor, step_id,
1344           profiler::TraceMeLevel::kInfo);
1345       done_cb(status);
1346     });
1347   }
1348 }
1349 
1350 void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
1351   if (immutable_state_.requires_control_flow_support()) {
1352     (new ExecutorState<PropagatorState>(args, immutable_state_, &kernel_stats_))
1353         ->RunAsync(std::move(done));
1354   } else {
1355     (new ExecutorState<SimplePropagatorState>(args, immutable_state_,
1356                                               &kernel_stats_))
1357         ->RunAsync(std::move(done));
1358   }
1359 }
1360 
1361 }  // namespace
1362 
1363 Status NewLocalExecutor(const LocalExecutorParams& params, const Graph& graph,
1364                         Executor** executor) {
1365   ExecutorImpl* impl = new ExecutorImpl(params);
1366   const Status s = impl->Initialize(graph);
1367   if (s.ok()) {
1368     *executor = impl;
1369   } else {
1370     delete impl;
1371   }
1372   return s;
1373 }
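// Illustrative caller-side sketch (hypothetical, not part of this file): a
// client running inside a function that returns Status fills in
// LocalExecutorParams, takes ownership of the new executor, and drives it via
// RunAsync. The concrete kernel-creation/deletion callbacks are left elided.
//
//   LocalExecutorParams params;
//   params.device = device;                  // Device* owned by the caller
//   params.function_library = flib_runtime;  // FunctionLibraryRuntime*
//   params.create_kernel = /* kernel-creation callback */;
//   params.delete_kernel = /* kernel-deletion callback */;
//   Executor* raw_executor = nullptr;
//   TF_RETURN_IF_ERROR(NewLocalExecutor(params, graph, &raw_executor));
//   std::unique_ptr<Executor> executor(raw_executor);
//   executor->RunAsync(args, std::move(done_callback));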
1374 
1375 Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
1376                              const std::shared_ptr<const NodeProperties>& props,
1377                              int graph_def_version, OpKernel** kernel) {
1378   const auto device_type = DeviceType(device->attributes().device_type());
1379   auto allocator = device->GetAllocator(AllocatorAttributes());
1380   return CreateOpKernel(device_type, device, allocator, flib,
1381                         device->resource_manager(), props, graph_def_version,
1382                         kernel);
1383 }
1384 
1385 void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; }
1386 
1387 namespace {
1388 
1389 class DefaultExecutorRegistrar {
1390  public:
1391   DefaultExecutorRegistrar() {
1392     Factory* factory = new Factory;
1393     ExecutorFactory::Register("", factory);
1394     ExecutorFactory::Register("DEFAULT", factory);
1395   }
1396 
1397  private:
1398   class Factory : public ExecutorFactory {
1399     Status NewExecutor(const LocalExecutorParams& params, const Graph& graph,
1400                        std::unique_ptr<Executor>* out_executor) override {
1401       Executor* ret = nullptr;
1402       TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(graph), &ret));
1403       out_executor->reset(ret);
1404       return Status::OK();
1405     }
1406   };
1407 };
1408 static DefaultExecutorRegistrar registrar;
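// With this registrar in place, executor lookups through ExecutorFactory for
// either the empty executor type or "DEFAULT" resolve to the Factory above,
// which simply wraps NewLocalExecutor.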
1409 
1410 }  // namespace
1411 
1412 }  // namespace tensorflow
1413