1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
18 
19 #include <functional>
20 #include <unordered_set>
21 #include <utility>
22 #include <vector>
23 
24 #include "tensorflow/core/framework/allocator.h"
25 #include "tensorflow/core/framework/cancellation.h"
26 #include "tensorflow/core/framework/control_flow.h"
27 #include "tensorflow/core/framework/device_base.h"
28 #include "tensorflow/core/framework/graph.pb.h"
29 #include "tensorflow/core/framework/kernel_def.pb.h"
30 #include "tensorflow/core/framework/kernel_def_builder.h"
31 #include "tensorflow/core/framework/node_def.pb.h"
32 #include "tensorflow/core/framework/node_def_util.h"
33 #include "tensorflow/core/framework/node_properties.h"
34 #include "tensorflow/core/framework/op.h"  // TODO(b/62899350): Remove
35 #include "tensorflow/core/framework/op_requires.h"
36 #include "tensorflow/core/framework/registration/registration.h"
37 #include "tensorflow/core/framework/rendezvous.h"
38 #include "tensorflow/core/framework/session_state.h"
39 #include "tensorflow/core/framework/tensor.h"
40 #include "tensorflow/core/framework/tensor_shape.h"
41 #include "tensorflow/core/framework/tensor_shape.pb.h"  // TODO(b/62899350): Remove
42 #include "tensorflow/core/framework/tracking_allocator.h"
43 #include "tensorflow/core/framework/types.h"
44 #include "tensorflow/core/framework/types.pb.h"
45 #include "tensorflow/core/lib/core/errors.h"
46 #include "tensorflow/core/lib/core/status.h"
47 #include "tensorflow/core/lib/gtl/array_slice.h"
48 #include "tensorflow/core/lib/gtl/manual_constructor.h"
49 #include "tensorflow/core/platform/env.h"
50 #include "tensorflow/core/platform/logging.h"
51 #include "tensorflow/core/platform/macros.h"
52 #include "tensorflow/core/platform/mutex.h"
53 #include "tensorflow/core/platform/profile_utils/cpu_utils.h"
54 #include "tensorflow/core/platform/thread_annotations.h"
55 #include "tensorflow/core/platform/types.h"
56 #include "tensorflow/core/protobuf/config.pb.h"
57 #include "tensorflow/core/util/managed_stack_trace.h"
58 
59 namespace Eigen {
60 struct ThreadPoolDevice;
61 struct GpuDevice;
62 }  // end namespace Eigen
63 
64 namespace tensorflow {
65 
66 namespace checkpoint {
67 class TensorSliceReaderCacheWrapper;
68 }  // namespace checkpoint
69 
70 class AsyncOpKernel;
71 class CallFrameInterface;
72 class DeviceMgr;
73 class FunctionLibraryRuntime;
74 class OpKernelConstruction;  // declared below
75 class OpKernelContext;       // declared below
76 class OpRegistryInterface;
77 class ResourceMgr;
78 class ScopedStepContainer;
79 class CollectiveExecutor;
80 class StepStatsCollectorInterface;
81 class CoordinationServiceAgent;
82 
83 class OpKernel {
84  public:
85   // OpKernel won't be instantiated by the scheduler, so you may perform
86   // expensive initialization in the descendant's constructor.
87   explicit OpKernel(OpKernelConstruction* context);
88 
89   // Specialized constructor that allows a kernel implementation to mark itself
90   // as a "deferred" op. If true, the executor will provide access to the
91   // `OpKernelContext::inc_num_deferred_ops_function()` and
92   // `OpKernelContext::dec_num_deferred_ops_function()` methods at run-time.
93   OpKernel(OpKernelConstruction* context, bool is_deferred);
94 
95   // Specialized constructor that enables the descendant to provide a custom
96   // `NodeDef` value. For example, this constructor can be used to provide a
97   // stripped-down `NodeDef` that does not contain the full set of attrs (such
98   // as tensor values) if the descendant stores them in a different form.
99   OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
100            bool is_deferred);
101 
102   virtual ~OpKernel();
103 
104   // An OpKernel's computation can be either synchronous or
105   // asynchronous. All OpKernel Compute() methods must be thread-safe as they
106   // may be called concurrently (e.g. by multiple executions of the same graph
107   // concurrently).
108   //
109   // Most OpKernels should compute synchronously. They should
110   // subclass OpKernel and override the Compute() method and have it
111   // return after completing the supplied work.
112   //
113   // A synchronous OpKernel *MUST NOT* block the calling thread on a
114   // synchronization mechanism (condition variable, Notification, etc.) that
115   // will be unblocked by the execution of another OpKernel. Execution may
116   // deadlock in that case, because the executor may use a bounded number of
117   // threads.
118   //
119   // If an OpKernel must block on the execution of another OpKernel (e.g. a
120   // RecvOp, or a DequeueOp), the implementation *MUST* subclass AsyncOpKernel,
121   // and override `AsyncOpKernel::ComputeAsync()`. In addition, because the
122   // unblocking kernel may never run (due to an error or cancellation), in most
123   // cases the AsyncOpKernel should implement cancellation support via
124   // `ctx->cancellation_manager()`.
125   //
126   // In both cases, implementations of Compute() and ComputeAsync()
127   // get inputs and write outputs through the given OpKernelContext
128   // and return a status via context->SetStatus(). They must be
129   // thread-safe.
130 
131   // Synchronous compute.
132   //
133   // "context" is guaranteed to be alive until Compute() returns.
134   virtual void Compute(OpKernelContext* context) = 0;
135 
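  // Example (a minimal sketch; the kernel below is hypothetical and not part
  // of this header):
  //
  //   class PassThroughOp : public OpKernel {
  //    public:
  //     explicit PassThroughOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  //     void Compute(OpKernelContext* ctx) override {
  //       const Tensor& input = ctx->input(0);
  //       Tensor* output = nullptr;
  //       OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
  //       // ... fill *output from input ...
  //     }
  //   };
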
136   // Returns nullptr iff this op kernel is synchronous.
137   virtual AsyncOpKernel* AsAsync() { return nullptr; }
138 
139   // Returns true iff this op kernel is considered "expensive". The
140   // runtime may use this flag to optimize graph execution, for example,
141   // to "inline" inexpensive kernels.
142   virtual bool IsExpensive() { return expensive_; }
143 
144   // Returns a pointer to the tensor stored inside constant ops.
145   virtual const Tensor* const_tensor() const { return nullptr; }
146 
147   // Accessors.
148   const NodeDef& def() const { return props_->node_def; }
149   const std::string& name() const { return props_->node_def.name(); }
150   absl::string_view name_view() const { return name_view_; }
151   const std::string& type_string() const { return props_->node_def.op(); }
152   absl::string_view type_string_view() const { return type_string_view_; }
153   const std::string& requested_input(int i) const {
154     return props_->node_def.input(i);
155   }
156   const std::string& requested_device() const {
157     return props_->node_def.device();
158   }
159 
160   int num_inputs() const { return props_->input_types.size(); }
161   DataType input_type(int i) const { return props_->input_types[i]; }
162   const DataTypeVector& input_types() const { return props_->input_types; }
163   const MemoryTypeVector& input_memory_types() const {
164     return input_memory_types_;
165   }
166 
167   int num_outputs() const { return props_->output_types.size(); }
168   DataType output_type(int o) const { return props_->output_types[o]; }
169   const DataTypeVector& output_types() const { return props_->output_types; }
170   const MemoryTypeVector& output_memory_types() const {
171     return output_memory_types_;
172   }
173 
174   Status InputRange(StringPiece input_name, int* start, int* stop) const;
175   Status OutputRange(StringPiece output_name, int* start, int* stop) const;
176 
177   // Returns `true` if and only if this kernel uses deferred execution.
178   bool is_deferred() const { return is_deferred_; }
179 
180   // Returns a trace string for the current computation; the op name/type and
181   // input tensor shape/dtype are encoded for profiler cost analysis. Most
182   // OpKernels should use the default implementation.
183   virtual std::string TraceString(const OpKernelContext& ctx,
184                                   bool verbose) const;
185 
186  protected:
187   std::string ShapeTraceString(const OpKernelContext& ctx) const;
188 
189  private:
190   const std::shared_ptr<const NodeProperties> props_;
191   const MemoryTypeVector input_memory_types_;
192   const MemoryTypeVector output_memory_types_;
193   NameRangeMap input_name_map_;
194   NameRangeMap output_name_map_;
195   const absl::string_view name_view_;
196   const absl::string_view type_string_view_;
197   const int graph_def_version_;
198   const bool is_deferred_;
199   bool expensive_;
200 
201   TF_DISALLOW_COPY_AND_ASSIGN(OpKernel);
202 };
203 
204 class AsyncOpKernel : public OpKernel {
205  public:
206   using OpKernel::OpKernel;  // Lift OpKernel constructors.
207 
208   // Asynchronous compute.
209   //
210   // Implementations of ComputeAsync() must ensure that `done` is (eventually)
211   // called exactly once to signal the completion of the computation. The
212   // implementation of ComputeAsync() must not block on the execution of another
213   // OpKernel. `done` may be called by the current thread, or by another thread.
214   // `context` is guaranteed to stay alive until the `done` callback starts.
215   //
216   // Since it is possible that the unblocking kernel may never run (due to an
217   // error or cancellation), in most cases the AsyncOpKernel should implement
218   // cancellation support via `context->cancellation_manager()`.
219   //
220   // WARNING: As soon as the `done` callback starts, `context` and `this` may be
221   // deleted. No code depending on these objects should execute after the call
222   // to `done`.
223   typedef std::function<void()> DoneCallback;
224   virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0;
225 
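  // Example (a minimal sketch; `ScheduleWork` is a hypothetical stand-in for
  // whatever mechanism defers the work):
  //
  //   class WaitThenCopyOp : public AsyncOpKernel {
  //    public:
  //     using AsyncOpKernel::AsyncOpKernel;
  //     void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
  //       ScheduleWork(ctx, [ctx, done = std::move(done)]() {
  //         // ... produce outputs or call ctx->SetStatus(...) ...
  //         done();  // Must be called exactly once.
  //       });
  //     }
  //   };
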
226   AsyncOpKernel* AsAsync() override { return this; }
227 
228   void Compute(OpKernelContext* context) override;
229 };
230 
231 class OpKernelConstruction {
232  public:
233   OpKernelConstruction(DeviceType device_type, DeviceBase* device,
234                        Allocator* allocator, FunctionLibraryRuntime* flib,
235                        ResourceMgr* resource_mgr,
236                        const std::shared_ptr<const NodeProperties>& props,
237                        const MemoryTypeSlice& input_memory_types,
238                        const MemoryTypeSlice& output_memory_types,
239                        int graph_def_version, Status* status);
240 
241   Env* env() const { return device_->env(); }
242 
243   // Allocation of tensors during kernel construction:
244   //
245   // It is legal to temporarily allocate scratch tensor storage during
246   // Op kernel construction. Scratch tensors should be allocated using
247   // allocate_temp below. Some kernels need to keep tensors in between
248   // invocations. If such a Tensor is allocated during kernel
249   // construction this also must be done using allocate_temp, and the
250   // Op may only store the returned Tensor object.
251 
252   // Allocates a temporary Tensor of the specified type and shape. The
253   // Tensor must not be used after kernel construction is
254   // complete. See comment above.
255   Status allocate_temp(DataType type, const TensorShape& shape,
256                        Tensor* out_temp);
257   Status allocate_temp(DataType type, const TensorShape& shape,
258                        Tensor* out_temp, AllocatorAttributes allocator_attr);
259 
260   // User-supplied configuration of this operation.
261   const NodeDef& def() const { return props_->node_def; }
262 
263   // For inspecting the inputs to this operation.
264   int num_inputs() const { return props_->input_types.size(); }
265   DataType input_type(int i) const { return props_->input_types[i]; }
266   const DataTypeSlice& input_types() const { return props_->input_types_slice; }
267   const MemoryTypeSlice& input_memory_types() const {
268     return input_memory_types_;
269   }
270 
271   // For inspecting the outputs expected from this operation.
272   int num_outputs() const { return props_->output_types.size(); }
273   DataType output_type(int i) const { return props_->output_types[i]; }
274   const DataTypeSlice& output_types() const {
275     return props_->output_types_slice;
276   }
277   const MemoryTypeSlice& output_memory_types() const {
278     return output_memory_types_;
279   }
280 
281   // If expected_inputs == inputs() and expected_outputs == output_types(),
282   // returns OK, else returns INVALID_ARGUMENT with an error message.
283   // Recommended for Ops with dynamic signatures.
284   Status MatchSignature(const DataTypeSlice expected_inputs,
285                         const DataTypeSlice expected_outputs);
286 
287   // For recording configuration errors during construction.
288   void SetStatus(const Status& status);
289   const Status& status() const { return *status_; }
290 
291   // Look up the attr with name attr_name and set *value to its value.  If no
292   // attr with attr_name is found in def(), or the attr does not have
293   // a matching type, a non-ok status will be returned.
294   template <class T>
295   Status GetAttr(StringPiece attr_name, T* value) const;
296 
297   // Return true if the attr_name is defined in def().
298   bool HasAttr(StringPiece attr_name) const;
299 
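  // Example (illustrative; the attr names "axis" and "keep_dims" and the
  // members they populate are hypothetical):
  //
  //   explicit MyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  //     OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
  //     if (ctx->HasAttr("keep_dims")) {
  //       OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
  //     }
  //   }
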
300   // Return the device type.
301   const DeviceType& device_type() const { return device_type_; }
302 
303   // If not nullptr, the kernel can instantiate functions defined in
304   // the library. E.g.,
305   // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...).
306   FunctionLibraryRuntime* function_library() const { return flib_; }
307 
308   // Shared resources accessible to this kernel.
309   ResourceMgr* resource_manager() const { return resource_mgr_; }
310 
311   // The GraphDef version whose behavior we should follow.
312   int graph_def_version() const { return graph_def_version_; }
313 
314   // Helper routines for the OP_REQUIRES macros
315   void CtxFailure(const Status& s);
316   void CtxFailureWithWarning(const Status& s);
317   void CtxFailure(const char* file, int line, const Status& s);
318   void CtxFailureWithWarning(const char* file, int line, const Status& s);
319 
320   // Unrecommended functions: these are functions that have some
321   // current uses but are not recommended for use, and may go away at
322   // some future major version release.
323 
324   // May be used, e.g., to get GPU handles, etc.
325   //
326   // Currently only used to call MakeTensorFromProto() for
327   // implementing ConstantOp for every device.  See comments
328   // on Device::MakeTensorFromProto for longer-term replacement
329   // ideas.
330   DeviceBase* device() const { return device_; }
331 
332  private:
333   const DeviceType device_type_;
334   DeviceBase* const device_;
335   Allocator* allocator_;
336   FunctionLibraryRuntime* flib_;
337   ResourceMgr* const resource_mgr_;
338   std::shared_ptr<const NodeProperties> props_;
339   MemoryTypeSlice input_memory_types_;
340   MemoryTypeSlice output_memory_types_;
341   const int graph_def_version_;
342   Status* status_;
343 
344   // Allow access from OpKernel ctor.
345   friend class OpKernel;
346 
347   TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction);
348 };
349 
350 // TODO(mrry): Consider converting to a random_access_iterator, and upgrading
351 // tensorflow::gtl::iterator_range to make the below container classes
352 // unnecessary.
353 template <typename ListType, typename ElementType>
354 class OpArgIterator {
355  public:
356   using iterator_category = std::forward_iterator_tag;
357   using value_type = ElementType;
358   using pointer = ElementType*;
359   using const_pointer = const ElementType*;
360   using reference = ElementType&;
361   using const_reference = const ElementType&;
362   using difference_type = ptrdiff_t;
363 
364   OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {}
365 
366   bool operator==(const OpArgIterator& rhs) {
367     DCHECK(list_ == rhs.list_);
368     return i_ == rhs.i_;
369   }
370 
371   bool operator!=(const OpArgIterator& rhs) {
372     DCHECK(list_ == rhs.list_);
373     return i_ != rhs.i_;
374   }
375 
376   OpArgIterator operator++() {  // prefix ++it
377     ++i_;
378     return *this;
379   }
380 
381   OpArgIterator operator++(int) {  // postfix it++
382     OpArgIterator old_value = *this;
383     ++i_;
384     return old_value;
385   }
386 
387   reference operator*() { return (*list_)[i_]; }
388   pointer operator->() { return &(*list_)[i_]; }
389 
390   const_reference operator*() const { return (*list_)[i_]; }
391   const_pointer operator->() const { return &(*list_)[i_]; }
392 
393  private:
394   const ListType* const list_;
395   int i_;
396 };
397 
398 // Utility class for representing a list of immutable input tensors
399 // that are passed to the op as a single named argument.
400 class OpInputList {
401  public:
402   typedef OpArgIterator<OpInputList, const Tensor> Iterator;
403   OpInputList() : ctx_(nullptr), start_(0), stop_(0) {}
404   OpInputList(OpKernelContext* ctx, int start, int stop)
405       : ctx_(ctx), start_(start), stop_(stop) {}
406   OpInputList& operator=(const OpInputList& other) = default;
407   const Tensor& operator[](int i) const;
408   int size() const { return stop_ - start_; }
409   Iterator begin() const { return Iterator(this, 0); }
410   Iterator end() const { return Iterator(this, size()); }
411 
412  private:
413   OpKernelContext* ctx_;  // not owned
414   int start_;
415   int stop_;
416 };
417 
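// Example use inside a kernel's Compute() (illustrative; the input name
// "values" is hypothetical):
//
//   OpInputList values;
//   OP_REQUIRES_OK(ctx, ctx->input_list("values", &values));
//   for (const Tensor& t : values) {
//     // ... use t ...
//   }
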
418 // Utility class for representing a list of mutable ("ref") input tensors
419 // that are passed to the op as a single named argument.
420 class OpMutableInputList {
421  public:
422   typedef OpArgIterator<OpMutableInputList, Tensor*> Iterator;
423   OpMutableInputList(OpKernelContext* ctx, int start, int stop)
424       : ctx_(ctx), start_(start), stop_(stop) {}
425   OpMutableInputList() : ctx_(nullptr), start_(0), stop_(0) {}
426   OpMutableInputList& operator=(const OpMutableInputList& other) = default;
427   Tensor at(int i, bool lock_held);
428   mutex* ref_mutex(int i);
429   int size() const { return stop_ - start_; }
430   Iterator begin() const { return Iterator(this, 0); }
431   Iterator end() const { return Iterator(this, size()); }
432 
433  private:
434   OpKernelContext* ctx_;  // not owned
435   int start_;
436   int stop_;
437 };
438 
439 // Utility class for representing a list of output tensors that are
440 // grouped as a single named output.
441 class OpOutputList {
442  public:
443   typedef OpArgIterator<OpOutputList, const Tensor*> Iterator;
444   OpOutputList() : ctx_(nullptr), start_(0), stop_(0) {}
445   OpOutputList(OpKernelContext* ctx, int start, int stop)
446       : ctx_(ctx), start_(start), stop_(stop) {}
447   OpOutputList& operator=(const OpOutputList& other) = default;
448   Tensor* operator[](int i);
449   bool required(int i) const;
450   DataType expected_output_dtype(int i) const;
451   Status allocate(int i, const TensorShape& shape, Tensor** output);
452   void set(int i, const Tensor& tensor);
453   void set(int i, Tensor&& tensor);
454   void set_ref(int i, mutex* mu, Tensor* tensor_for_ref);
455   int size() const { return stop_ - start_; }
456   Iterator begin() const { return Iterator(this, 0); }
457   Iterator end() const { return Iterator(this, size()); }
458 
459  private:
460   OpKernelContext* ctx_;  // not owned
461   int start_;
462   int stop_;
463 };
464 
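// Example use inside a kernel's Compute() (illustrative; the output name
// "outputs" is hypothetical):
//
//   OpOutputList outputs;
//   OP_REQUIRES_OK(ctx, ctx->output_list("outputs", &outputs));
//   for (int i = 0; i < outputs.size(); ++i) {
//     Tensor* out = nullptr;
//     OP_REQUIRES_OK(ctx, outputs.allocate(i, TensorShape({}), &out));
//     // ... fill *out ...
//   }
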
465 // Holds a tensor or tensor reference. For tensor references, we need
466 // a mutex to prevent concurrent access to the tensor.
467 struct TensorValue {
468   TensorValue() : mutex_if_ref(nullptr), tensor(nullptr) {}
469   explicit TensorValue(Tensor* t) : mutex_if_ref(nullptr), tensor(t) {}
470   TensorValue(mutex* mu, Tensor* t) : mutex_if_ref(mu), tensor(t) {}
471   Tensor* operator->() const { return tensor; }
472   bool is_ref() const { return mutex_if_ref != nullptr; }
473 
474   // Return the dtype of the Tensor. For references, return the corresponding ref type.
475   DataType dtype() const {
476     if (is_ref()) {
477       return MakeRefType(tensor->dtype());
478     } else {
479       return tensor->dtype();
480     }
481   }
482 
483   // Return the dtype of the Tensor. For references, return the corresponding ref type.
484   // This variation on the dtype() acquires the lock for references.
485   //
486   // TODO(b/133843385): Disallow dtype modifications
487   DataType dtype_safe() const {
488     if (is_ref()) {
489       tf_shared_lock ml(*mutex_if_ref);
490       return MakeRefType(tensor->dtype());
491     } else {
492       return tensor->dtype();
493     }
494   }
495 
496   mutex* mutex_if_ref;  // nullptr if not a ref, != nullptr if a ref
497   Tensor* tensor;
498 };
499 
500 // Used to store partitioned graphs from function-calling ops.
501 struct GraphCollector {
502   mutex mu;
503   std::vector<GraphDef> partitioned_graphs TF_GUARDED_BY(mu);
504   GraphDef raw_graph TF_GUARDED_BY(mu);
505   GraphDef optimized_graph TF_GUARDED_BY(mu);
506 
507   bool dirty TF_GUARDED_BY(mu);
508 
509   GraphCollector() : dirty(false) {}
510 
511   void CollectRawGraph(const GraphDef& graph) {
512     mutex_lock ml(mu);
513     raw_graph.MergeFrom(graph);
514     dirty = true;
515   }
516 
517   void CollectOptimizedGraph(const GraphDef& graph) {
518     mutex_lock ml(mu);
519     optimized_graph.MergeFrom(graph);
520     dirty = true;
521   }
522 
523   void CollectPartitionedGraph(const GraphDef& graph) {
524     mutex_lock ml(mu);
525     partitioned_graphs.push_back(graph);
526     dirty = true;
527   }
528 
529   void ClearGraphs() TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
530     raw_graph.Clear();
531     optimized_graph.Clear();
532     partitioned_graphs.clear();
533     dirty = false;
534   }
535 
536   bool HasUpdatedGraphs() {
537     mutex_lock ml(mu);
538     return dirty;
539   }
540 };
541 
542 class OpKernelContext {
543  public:
544   // The first element of a WrappedAllocator is a "base" Allocator and
545   // the second element is that Allocator wrapped by a
546   // TrackingAllocator
547   typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator;
548 
549   // TODO(zhifengc): Do some cleanup of Params.
550   // The Params struct is passed in to initialize an OpKernelContext,
551   // and must outlive the OpKernelContext.
552   struct Params {
553     ~Params() { delete eigen_gpu_device; }
554 
555     // The step being executed.
556     int64 step_id = 0;
557 
558     // Timestamp for the start of graph execution. Used for latency metrics.
559     int64 start_time_usecs = 0;
560 
561     // The op kernel being computed.
562     OpKernel* op_kernel = nullptr;
563 
564     // The device on which the kernel is running.
565     DeviceBase* device = nullptr;
566 
567     // The Eigen GPU device wrapper, which may include a per-op
568     // wrapped allocator. The concrete type of this object depends on
569     // the type of this->device, so eigen_gpu_device can't be an
570     // inline member and must be heap allocated. However, we don't
571     // want to allocate a new eigen_gpu_device for every Op that is
572     // executed. Instead this member is allocated on first use using
573     // ensure_eigen_gpu_device, and then if the Params structure is
574     // re-used for subsequent Ops, the eigen_gpu_device is
575     // ReInitialized in the OpKernelContext constructor. Unlike the
576     // other pointers in Params, this one is owned by Params.
577     PerOpGpuDevice* eigen_gpu_device = nullptr;
578 
579     inline void ensure_eigen_gpu_device() {
580       DCHECK(device);
581       if (nullptr == eigen_gpu_device) {
582         // Surprisingly, MakeGpuDevice will return nullptr if the
583         // device is not a GPU device. This is ok, since those devices
584         // will never use eigen_gpu_device. It seems better to have
585         // ensure_eigen_gpu_device fall through and regenerate the
586         // nullptr every time an OpKernelContext is instantiated, than
587         // to do an unnecessary allocation of a dummy eigen GPU
588         // device for CPU device Ops.
589         eigen_gpu_device = device->MakeGpuDevice();
590       }
591     }
592 
593     bool track_allocations = false;
594     bool log_memory = false;
595 
596     // Array indexed by output number for this node
597     const AllocatorAttributes* output_attr_array = nullptr;
598 
599     // Shared resources accessible by this op kernel invocation.
600     ResourceMgr* resource_manager = nullptr;
601 
602     // Per-step resources accessible by this op kernel invocation should be
603     // stored in this container.
604     ScopedStepContainer* step_container = nullptr;
605 
606     // Mechanism used by this op kernel invocation to communicate with
607     // computations running on other devices.
608     RendezvousInterface* rendezvous = nullptr;
609 
610     // Mechanism for executing a collective op that needs to coordinate
611     // with parallel instances running on other devices.
612     CollectiveExecutor* collective_executor = nullptr;
613 
614     // The session state for this op.
615     SessionState* session_state = nullptr;
616 
617     // Unique session identifier. Can be empty.
618     std::string session_handle;
619 
620     // Metadata about the session. Can be nullptr.
621     const SessionMetadata* session_metadata = nullptr;
622 
623     // The tensor store for this op.
624     TensorStore* tensor_store = nullptr;
625 
626     // Mechanism used by this op kernel invocation to register a callback
627     // for its cancellation.
628     CancellationManager* cancellation_manager = nullptr;
629 
630     // Inputs to this op kernel.
631     const gtl::InlinedVector<TensorValue, 4>* inputs = nullptr;
632     bool is_input_dead = false;
633 
634     const gtl::InlinedVector<AllocatorAttributes, 4>* input_alloc_attrs =
635         nullptr;
636 
637     // Device context.
638     DeviceContext* op_device_context = nullptr;
639 
640     // Control-flow op supports.
641     FrameAndIter frame_iter;
642 
643     // Function call supports.
644     CallFrameInterface* call_frame = nullptr;
645     FunctionLibraryRuntime* function_library = nullptr;
646     std::function<void(std::function<void()>)>* runner = nullptr;
647     StepStatsCollectorInterface* stats_collector = nullptr;
648     GraphCollector* graph_collector = nullptr;
649     bool run_all_kernels_inline = false;
650     const std::string* executor_type = nullptr;
651 
652     // TensorSliceReaderCache support.
653     checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
654 
655     // Support for forwarding reservations (used by ScopedAllocator).
656     static constexpr int kNeverForward = -2;
657     static constexpr int kNoReservation = -1;
658     // Values in [0,...) represent reservations for the indexed output.
659     const int* forward_from_array = nullptr;
660 
661     // For tracking actively running deferred ops.
662     std::function<void()> inc_num_deferred_ops_function;
663     std::function<void()> dec_num_deferred_ops_function;
664 
665     absl::optional<ManagedStackTrace> stack_trace = {};
666 
667     // For implementing `OpKernelContext::output_required()`. If null, all
668     // outputs are required.
669     bool* outputs_required_array = nullptr;
670 
671     // For access to distributed coordination service.
672     CoordinationServiceAgent* coordination_service_agent = nullptr;
673   };
674 
675   // params must outlive the OpKernelContext.
676   explicit OpKernelContext(Params* params);
677   OpKernelContext(Params* params, int num_outputs);
678   ~OpKernelContext();
679 
680   Env* env() const { return params_->device->env(); }
681 
682   int64 step_id() const { return params_->step_id; }
683 
684   int64 start_time_usecs() const { return params_->start_time_usecs; }
685 
686   const OpKernel& op_kernel() const { return *params_->op_kernel; }
687 
688   // Stack trace of where the op was defined (if defined in eager mode).
689   const absl::optional<ManagedStackTrace>& stack_trace() const {
690     return params_->stack_trace;
691   }
692 
693   // Input/output signature.
694 
695   int num_inputs() const { return params_->inputs->size(); }
696   DataType input_dtype(int index) const;
697   Status input_dtype(StringPiece name, DataType* dtype) const;
698   MemoryType input_memory_type(int index) const;
699 
700   int num_outputs() const { return outputs_.size(); }
701   DataType expected_output_dtype(int index) const;
702   MemoryType output_memory_type(int index) const;
703 
704   // Input
705 
706   // Returns an immutable input tensor. May only be used for non-Ref
707   // inputs. For Ref inputs use mutable_input below.
708   // REQUIRES: !IsRefType(input_dtype(index))
709   // TODO(mrry): Convert this to return Status.
710   const Tensor& input(int index) const;
711 
712   // Returns the named immutable input tensor in "tensor", as defined
713   // in the OpDef. May only be used for non-Ref inputs. For Ref inputs
714   // use mutable_input below.
715   // REQUIRES: !IsRefType(input_dtype(index))
716   // REQUIRES: the named input must not be a list.
717   Status input(StringPiece name, const Tensor** tensor);
718 
719   // Returns the named list-valued immutable input in "list", as
720   // defined in the OpDef.  If the named input is not list-valued,
721   // returns a one-element list. May only be used for non-Ref
722   // inputs. For Ref inputs use mutable_input below.
723   // REQUIRES: !IsRefType(input_dtype(index))
724   Status input_list(StringPiece name, OpInputList* list);
725 
726   // For mutable inputs, use the following together to make sure there
727   // is no concurrent access to mutable_input(), e.g.:
728   // {
729   //   Tensor& t = context->mutable_input(index);
730   //   mutex_lock lock(*context->input_ref_mutex(index));
731   //   // modify the values in t
732   // }
733   // REQUIRES: IsRefType(input_dtype(index))
734   Status input_ref_mutex(StringPiece name, mutex** out_mutex);
735 
736   // Returns a mutable input tensor. Must be used to access Ref
737   // inputs.  REQUIRES: IsRefType(input_dtype(index)). The caller may
738   // modify the values stored in the Tensor buffer, and modifications
739   // will be visible to other Ops reading the same ref tensor. If
740   // !lock_held the input mutex will be acquired before returning the
741   // Tensor.
742   // TODO(mrry): Convert this to return Status.
743   Tensor mutable_input(int index, bool lock_held);
744 
745   // Returns the named mutable input tensor in "tensor", as defined in
746   // the OpDef. Must be used to access Ref inputs. The values stored
747   // in the Tensor buffer may be modified, and modifications will be
748   // visible to other Ops reading the same ref tensor. If !lock_held
749   // the input mutex will be acquired before returning the Tensor.
750   // REQUIRES: the named input must not be a list.
751   // REQUIRES: the named input must be a ref tensor.
752   Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held);
753 
754   // Returns the named list-valued mutable input in "list", as defined
755   // in the OpDef.  If the named input is not list-valued, returns a
756   // one-element list. Must be used to access Ref inputs. The values
757   // stored in the Tensor buffer may be modified, and modifications
758   // will be visible to other Ops reading the same ref tensor.
759   // REQUIRES: the named input must be a ref tensor.
760   Status mutable_input_list(StringPiece name, OpMutableInputList* list);
761 
762   // Replace the corresponding Ref Input to use the storage buffer
763   // used by tensor. If !lock_held the input mutex will be acquired
764   // before returning the Tensor.
765   // REQUIRES: IsRefType(input_dtype(index)).
766   void replace_ref_input(int index, const Tensor& tensor, bool lock_held);
767 
768   // Replace the corresponding named Ref Input to use the storage
769   // buffer used by tensor. If !lock_held the input mutex will be
770   // acquired before returning the Tensor.
771   // REQUIRES: IsRefType(input_dtype(index)).
772   Status replace_ref_input(StringPiece name, const Tensor& tensor,
773                            bool lock_held);
774 
775   // Deletes the Tensor object used as the Ref Input at
776   // input_index. This is not usually necessary and should be used
777   // with caution. If !lock_held the input mutex will be acquired
778   // before returning the Tensor.
779   // REQUIRES: IsRefType(input_dtype(input_index)).
780   void delete_ref_input(int input_index, bool lock_held);
781 
782   // Return true if there is input at the given index. An operator has no
783   // input at index if its tensor is null. This is primarily used by the
784   // merge operator.
785   // TODO(mrry): Convert this to return Status.
786   bool has_input(int index) const;
787 
788   // Returns true if all inputs are the same shape, otherwise sets the
789   // status to a non-OK value and returns false.
790   // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
791   bool ValidateInputsAreSameShape(OpKernel* op);
792 
793   // If non-null, kernels should populate with any partition subgraphs created.
794   GraphCollector* graph_collector() { return params_->graph_collector; }
795 
796   // If true, hints that all kernels in functions called by this kernel should
797   // be treated as "inexpensive", and hence executed on the scheduling thread.
798   bool run_all_kernels_inline() const {
799     return params_->run_all_kernels_inline;
800   }
801 
802   // Returns the registered name for the executor type that is executing the
803   // current kernel. If empty, the default executor is used.
804   const std::string& executor_type() const;
805 
806   // Input to output forwarding.
807 
808   // Set the output Ref Tensor at output_index to be an alias of the
809   // input Ref Tensor at input_index.
810   // REQUIRES: IsRefType(input_dtype(input_index)).
811   // REQUIRES: IsRefType(output_dtype(output_index)).
812   void forward_ref_input_to_ref_output(int input_index, int output_index);
813 
814   // Returns true when an alias to input[input_index], reshaped to output_shape
815   // and safe to use for in-place computation, has been written to *output.
816   // Returns false if input[input_index] has a refcount greater than one, or if
817   // its type does not match the expected output type of output[output_index],
818   // or the number of elements in input[input_index] does not equal the number
819   // of elements in output_shape.
820   bool forward_input_to_output_with_shape(int input_index, int output_index,
821                                           const TensorShape& output_shape,
822                                           Tensor** output) TF_MUST_USE_RESULT;
823   Status forward_input_to_output_with_shape(StringPiece input_name,
824                                             StringPiece output_name,
825                                             const TensorShape& output_shape,
826                                             Tensor** output) TF_MUST_USE_RESULT;
827 
828   // Returns a pointer to a Tensor aliasing the underlying buffer backing
829   // input[input_index] iff
830   //   * input[input_index] is not a ref,
831   //   * the data type, shape, memory type, and allocator attributes of
832   //     input[input_index] are compatible with those given in dtype, shape,
833   //     memory_type, and attr,
834   //   * refcount on the underlying buffer is one.
835   //   * Either there is no forwarding reservation for either input_index
836   //     or output_index or the specified input is reserved for the specified
837   //     output. More precisely:
838   //
839   //     These cases mean neither input nor output has a reservation:
840   //        forward_from_array = nullptr
841   //     OR (input_index is not in forward_from_array AND
842   //         (output_index == kNoReservation OR
843   //          forward_from_array[output_index] == kNoReservation))
844   //
845   //     This case means that input_index is reserved for output_index:
846   //        forward_from_array[output_index] == input_index
847   //
848   //     This case means the output is reserved to always be allocated,
849   //     never assigned a forwarded input:
850   //        forward_from_array[output_index] == kNeverForward
851   //
852   // Otherwise returns nullptr.
853   // NOTE: For Cuda kernels that read inputs using the __ldg() intrinsic,
854   // forwarding is only safe if there are no reads via __ldg() after writes
855   // to the same address.
856   std::unique_ptr<Tensor> forward_input(
857       int input_index, int output_index, DataType output_dtype,
858       const TensorShape& output_shape, MemoryType output_memory_type,
859       const AllocatorAttributes& output_attr) TF_MUST_USE_RESULT;
860 
861   // Tries to forward one of the inputs given in input_indices to
862   // output[output_index]. If none of the given inputs can be forwarded, calls
863   // allocate_output() to allocate a new output buffer. The index of the
864   // forwarded input will be assigned to the output argument forwarded_input (if
865   // it is not nullptr). If no inputs are forwarded, forwarded_input will be
866   // assigned -1.
867   Status forward_input_or_allocate_output(
868       gtl::ArraySlice<int> candidate_input_indices, int output_index,
869       const TensorShape& output_shape, Tensor** output,
870       int* forwarded_input = nullptr) TF_MUST_USE_RESULT;
871   Status forward_input_or_allocate_output(
872       gtl::ArraySlice<StringPiece> candidate_input_names,
873       StringPiece output_name, const TensorShape& output_shape,
874       Tensor** output) TF_MUST_USE_RESULT;
875 
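  // Example (illustrative): an elementwise kernel that reuses the input buffer
  // for output 0 when it is safe, and allocates a fresh buffer otherwise:
  //
  //   Tensor* output = nullptr;
  //   OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
  //                           {0}, 0, ctx->input(0).shape(), &output));
  //   // ... write results into *output ...
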
876   // Tries to reuse one of the inputs given in input_indices as a temporary.
877   // If none of the given inputs can be forwarded, calls
878   // allocate_temp() to allocate a new temporary buffer.
879   Status forward_input_or_allocate_temp(
880       gtl::ArraySlice<int> candidate_input_indices, DataType type,
881       const TensorShape& shape, const AllocatorAttributes& allocator_attr,
882       Tensor* out_temp) TF_MUST_USE_RESULT;
883 
884   Status forward_input_or_allocate_temp(
885       gtl::ArraySlice<int> candidate_input_indices, DataType type,
886       const TensorShape& shape, Tensor* out_temp) TF_MUST_USE_RESULT {
887     return forward_input_or_allocate_temp(candidate_input_indices, type, shape,
888                                           AllocatorAttributes(), out_temp);
889   }
890 
891   // Output
892 
893   // Returns the named list-valued output in "list", as defined in the OpDef.
894   // If the named output is not list-valued, returns a one-element list.
895   Status output_list(StringPiece name, OpOutputList* list);
896 
897   // If output_required(index) returns true, the OpKernel's Compute() method
898   // should call allocate_output(index, ...), set_output(index, ...),
899   // set_output_ref(index, ...), or set the status to a non-ok value.
900   // If it returns false, it may output, but is not required to do so.
901   bool output_required(int index) const {
902     return !params_->outputs_required_array ||
903            params_->outputs_required_array[index];
904   }
905 
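  // Example (illustrative): skipping work for an optional secondary output:
  //
  //   if (ctx->output_required(1)) {
  //     Tensor* summary = nullptr;
  //     OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape({}), &summary));
  //     // ... fill *summary ...
  //   }
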
906   // If output_expects_forwarding returns true, the OpKernel's Compute() method
907   // should not allocate the output with allocate_output but instead needs to
908   // use forward_input.
909   bool output_expects_forwarding(int index) const {
910     return params_->forward_from_array != nullptr &&
911            params_->forward_from_array[index] >= 0;
912   }
913 
914   // Allocation of tensors during kernel execution inside the Compute
915   // method:
916   //
917   // There are two methods to allocate Tensors when an Op kernel
918   // executes.
919   //
920   // 1) allocate_output. This should be used to allocate any tensor
921   // that is going to be used as an output from the Op at the end of
922   // the current execution. The caller indicates which output the
923   // Tensor will be assigned to, and the call returns the
924   // newly-allocated Tensor. The Tensor can subsequently be assigned
925   // to during kernel execution, and will be used as the designated
926   // output when the kernel execution completes.
927   //
928   // 2) allocate_temp. This should be used to allocate any scratch
929   // storage that is needed while the kernel is executing, and will
930   // not be retained by the Op.
931   //
932   // In some cases a Tensor needs to be used as an output even though
933   // it was previously allocated elsewhere. The Tensor may have been
934   // passed as an input, or stored in a Tensor during a
935   // previous kernel execution, or allocated earlier in the kernel
936   // execution at a time when it was not known which output it would
937   // be assigned to. In this case the kernel can use set_output or
938   // set_output_ref to indicate that the tensor should be used as the
939   // designated output. It is legal to use any previously-allocated
940   // Tensor as an argument to set_output or set_output_ref, including
941   // Tensors allocated via allocate_temp. There may be a performance
942   // penalty to using a Tensor that was not allocated using
943   // allocate_output. This is because allocate_output uses the
944   // AllocatorAttributes stored in output_attr_array for the
945   // designated output. In some cases, using the wrong attributes may
946   // cause an extra copy of the Tensor's buffer.
947 
948   // Allocates output for the specified output index with shape.
949   // OpKernelContext retains ownership of the returned pointer. See
950   // comment above.
951   //
952   // If memory allocation fails, returns an error status.
953   //
954   // REQUIRES: !IsRefType(expected_output_dtype(index))
955   Status allocate_output(int index, const TensorShape& shape,
956                          Tensor** tensor) TF_MUST_USE_RESULT;
957   Status allocate_output(StringPiece name, const TensorShape& shape,
958                          Tensor** tensor) TF_MUST_USE_RESULT;
959   // The following methods use the supplied attributes instead of
960   // those in output_attr_array. The caller is responsible for
961   // ensuring that the attributes are "compatible" with the
962   // output_attr_array, e.g. the tensor is allocated on the correct
963   // device. See comment above.
964   Status allocate_output(int index, const TensorShape& shape, Tensor** tensor,
965                          AllocatorAttributes attr) TF_MUST_USE_RESULT;
966   Status allocate_output(StringPiece name, const TensorShape& shape,
967                          Tensor** tensor,
968                          AllocatorAttributes attr) TF_MUST_USE_RESULT;
969 
970   // Allocates a temporary Tensor of the specified type and
971   // shape. Devices such as GPUs that enqueue Ops for lazy execution
972   // may retain references to the temporary tensors after the Op's
973   // Compute method has run. See comment above.
974   Status allocate_temp(DataType type, const TensorShape& shape,
975                        Tensor* out_temp, AllocatorAttributes allocator_attr,
976                        const AllocationAttributes& allocation_attr);
977   Status allocate_temp(DataType type, const TensorShape& shape,
978                        Tensor* out_temp, AllocatorAttributes allocator_attr) {
979     return allocate_temp(type, shape, out_temp, allocator_attr,
980                          AllocationAttributes());
981   }
982   Status allocate_temp(DataType type, const TensorShape& shape,
983                        Tensor* out_temp) {
984     return allocate_temp(type, shape, out_temp, AllocatorAttributes());
985   }
986 
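  // Example (illustrative; `out_shape` is a TensorShape assumed to have been
  // computed by the kernel):
  //
  //   Tensor* output = nullptr;
  //   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &output));
  //   Tensor scratch;
  //   OP_REQUIRES_OK(
  //       ctx, ctx->allocate_temp(DT_FLOAT, TensorShape({128}), &scratch));
  //   // `scratch` is scratch storage for this Compute() call only; `output`
  //   // becomes the designated output 0 when Compute() returns.
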
987   // Copies a tensor (allocated by the caller) to the specified output
988   // index.  REQUIRES: !IsRefType(expected_output_dtype(index))
989   // REQUIRES: 'tensor' must have the same MemoryType as
990   // output_memory_types[index]. See comment above.
991   Status set_output(StringPiece name, const Tensor& tensor);
992   Status set_output(StringPiece name, Tensor&& tensor);
993   void set_output(int index, const Tensor& tensor);
994   void set_output(int index, Tensor&& tensor);
995 
996   // To output a reference.  Caller retains ownership of mu and tensor_for_ref,
997   // and they must outlive all uses within the step. See comment above.
998   // REQUIRES: IsRefType(expected_output_dtype(index))
999   Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref);
1000 
1001   // Returns nullptr if allocate_output() or set_output() have not been called.
1002   Status mutable_output(StringPiece name, Tensor** tensor);
1003 
1004   // Return the DeviceContext that should be used for this Op.
1005   //
1006   // If using the templated function, the type must be a subclass
1007   // of DeviceContext.
1008   //
1009   // Returns nullptr if the device did not provide one.
1010   template <typename T>
1011   T* op_device_context();
1012   DeviceContext* op_device_context() {
1013     DeviceContext* ret = params_->op_device_context;
1014     if (ret == nullptr) {
1015       auto* dev_info = device()->tensorflow_gpu_device_info();
1016       if (dev_info) ret = dev_info->default_context;
1017     }
1018     return ret;
1019   }
1020 
1021   AllocatorAttributes input_alloc_attr(int index) const {
1022     if (params_->input_alloc_attrs == nullptr) {
1023       return AllocatorAttributes();
1024     } else {
1025       DCHECK_GE(index, 0);
1026       DCHECK_LT(index, params_->input_alloc_attrs->size());
1027       return (*params_->input_alloc_attrs)[index];
1028     }
1029   }
1030 
1031   AllocatorAttributes output_alloc_attr(int index) const {
1032     return params_->output_attr_array[index];
1033   }
1034 
1035   gtl::InlinedVector<WrappedAllocator, 4> ConsumeWrappedAllocators() {
1036     gtl::InlinedVector<WrappedAllocator, 4> retrieved;
1037     if (tracking_state_) {
1038       mutex_lock lock(tracking_state_->mu);
1039       retrieved.swap(tracking_state_->wrapped_allocators);
1040     }
1041     return retrieved;
1042   }
1043 
1044   // Communication.
1045   //
1046   // An op kernel communicates with outside environment through
1047   // Rendezvous Send() and Recv().
1048   RendezvousInterface* rendezvous() const { return params_->rendezvous; }
1049 
1050   CollectiveExecutor* collective_executor() const {
1051     return params_->collective_executor;
1052   }
1053 
1054   // An op kernel can access the session state it belongs to.
1055   SessionState* session_state() const { return params_->session_state; }
1056 
1057   // Unique identifier of the session it belongs to. Can be empty.
1058   std::string session_handle() const { return params_->session_handle; }
1059 
1060   // Metadata about the session. Can be nullptr.
1061   const SessionMetadata* session_metadata() const {
1062     return params_->session_metadata;
1063   }
1064 
1065   // An op kernel can access the tensor store of the run it belongs to.
1066   TensorStore* tensor_store() const { return params_->tensor_store; }
1067 
1068   // Function call support.
1069   //
1070   // If this kernel invocation is within a function execution,
1071   // call_frame() returns the call frame for the function call.
1072   CallFrameInterface* call_frame() const { return params_->call_frame; }
1073 
1074   // If not nullptr, the kernel can invoke functions defined in the
1075   // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...).
1076   FunctionLibraryRuntime* function_library() const {
1077     return params_->function_library;
1078   }
1079 
1080   std::function<void(std::function<void()>)>* runner() const {
1081     return params_->runner;
1082   }
1083   StepStatsCollectorInterface* stats_collector() const {
1084     return params_->stats_collector;
1085   }
1086 
1087   // Shared resources accessible to this kernel.
1088   ResourceMgr* resource_manager() const { return params_->resource_manager; }
1089 
1090   checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const {
1091     return params_->slice_reader_cache;
1092   }
1093 
1094   // Execution.
1095   //
1096   // OpKernels can use these eigen devices to carry out their
1097   // numerical computation.
1098   const Eigen::ThreadPoolDevice& eigen_cpu_device() const {
1099     return *device()->eigen_cpu_device();
1100   }
1101   const Eigen::GpuDevice& eigen_gpu_device() const {
1102     return params_->eigen_gpu_device->device();
1103   }
1104   template <typename EigenDeviceType>
1105   const EigenDeviceType& eigen_device() const;
1106 
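  // Example (illustrative; assumes `output` was already allocated with the
  // same number of elements as input 0):
  //
  //   auto in = ctx->input(0).flat<float>();
  //   auto out = output->flat<float>();
  //   out.device(ctx->eigen_cpu_device()) = in * in;
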
1107   // Error handling.
1108 
1109   // If expected_inputs == inputs() and expected_outputs == output_types(),
1110   // returns OK, else returns INVALID_ARGUMENT with an error message.
1111   // Recommended for Ops with dynamic signatures, where validation can only
1112   // be performed at runtime.
1113   Status MatchSignature(const DataTypeSlice expected_inputs,
1114                         const DataTypeSlice expected_outputs);
1115 
1116   // An OpKernel should call SetStatus() if Compute() encounters an
1117   // error.
1118   void SetStatus(const Status& status);
1119   const Status& status() const { return status_; }
1120 
1121   // Cancellation.
1122   //
1123   // EXPERIMENTAL. See the implementation in tensorflow::FIFOQueue for an
1124   // example of how to use this API.
1125   CancellationManager* cancellation_manager() const {
1126     return params_->cancellation_manager;
1127   }
1128 
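  // Example (a sketch, assuming the CancellationManager API declared in
  // cancellation.h; the manager may be null in some execution contexts):
  //
  //   CancellationManager* cm = ctx->cancellation_manager();
  //   CancellationToken token = cm->get_cancellation_token();
  //   const bool already_cancelled =
  //       !cm->RegisterCallback(token, [this]() { /* abort pending work */ });
  //   // ... later, when the work can no longer be cancelled:
  //   cm->DeregisterCallback(token);
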
1129   // Other accessors.
1130 
1131   // For control flow.
1132   FrameAndIter frame_iter() const { return params_->frame_iter; }
1133   bool is_input_dead() const { return params_->is_input_dead; }
1134 
1135   // May be used, e.g., to get GPU handles, etc.
1136   // TODO(tucker): Add example usage.
1137   DeviceBase* device() const { return params_->device; }
1138 
1139   // Per-step container for use by white-listed internal ops.
1140   ScopedStepContainer* step_container() const {
1141     return params_->step_container;
1142   }
1143 
1144   // Access to distributed coordination service.
1145   CoordinationServiceAgent* coordination_service_agent() const {
1146     return params_->coordination_service_agent;
1147   }
1148 
1149   // Helper routines for the OP_REQUIRES macros
1150   void CtxFailure(const Status& s);
1151   void CtxFailureWithWarning(const Status& s);
1152   void CtxFailure(const char* file, int line, const Status& s);
1153   void CtxFailureWithWarning(const char* file, int line, const Status& s);
1154 
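  // These are normally reached through the OP_REQUIRES* macros from
  // op_requires.h rather than called directly, e.g. (illustrative;
  // `SomeStatusReturningHelper` is hypothetical):
  //
  //   OP_REQUIRES(ctx, input.dims() == 2,
  //               errors::InvalidArgument("input must be a matrix"));
  //   OP_REQUIRES_OK(ctx, SomeStatusReturningHelper(ctx));
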
1155   // Unrecommended functions: these are functions that have some
1156   // current uses but are not recommended for use, and may go away at
1157   // some future major version release.
1158   //
1159   // The following functions all have versions that return Status
1160   // to capture error conditions, and are strongly preferred.
1161   Tensor* mutable_output(int index);
1162   mutex* input_ref_mutex(int index);
1163   void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref);
1164   TensorValue release_output(int index);
1165 
1166   bool track_allocations() const { return params_->track_allocations; }
1167 
1168   // Records temp memory allocation. The Tensor object is recorded to identify
1169   // the case where temp memory is used as output memory.
1170   void record_temp_memory_allocation(int64_t size, const Tensor& t)
1171       TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
1172 
1173   // Returns the recorded size of temporary memory.
1174   int64 temp_memory_allocated() const
1175       TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
1176 
1177   // Records persistent memory allocation; size can be negative, indicating
1178   // deallocation.
1179   void record_persistent_memory_allocation(int64_t size, int64_t alloc_id = -1)
1180       TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
1181 
1182   // Returns recorded size and ids of persistent memory.
1183   int64 persistent_memory_allocated() const
1184       TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
1185 
1186   std::vector<int64> persistent_alloc_ids() const
1187       TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
1188 
1189   // Resets counters for temp and persistent memory and recorded ids.
1190   void clear_recorded_memory() TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
1191 
1192   bool input_is_ref(int index) const;
1193 
1194   void set_record_memory_consumption(bool v);
1195 
1196   // Used by OpKernel implementations to track actively running deferred ops.
1197   //
1198   // A deferred op is one whose Compute method returns (or whose ComputeAsync
1199   // method invokes the callback) when work is scheduled onto a device. At that
1200   // point, we don't know when the work will actually complete (or if it has
1201   // already completed) on the device. These functions allow the executor to
1202   // track the status of deferred ops and act accordingly.
1203   //
1204   // Deferred OpKernel implementations must use these methods to get two
1205   // functions. They must then call these two functions in pairs, before and
1206   // after device execution, respectively.
1207   TF_MUST_USE_RESULT std::function<void()> inc_num_deferred_ops_function() {
1208     DCHECK(params_->op_kernel->is_deferred());
1209     return params_->inc_num_deferred_ops_function
1210                ? params_->inc_num_deferred_ops_function
1211                : []() {};
1212   }
1213   TF_MUST_USE_RESULT std::function<void()> dec_num_deferred_ops_function() {
1214     DCHECK(params_->op_kernel->is_deferred());
1215     return params_->dec_num_deferred_ops_function
1216                ? params_->dec_num_deferred_ops_function
1217                : []() {};
1218   }
1219 
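  // Example (a sketch; `EnqueueAsyncDeviceWork` is a hypothetical stand-in for
  // whatever schedules the deferred device work):
  //
  //   auto inc = ctx->inc_num_deferred_ops_function();
  //   auto dec = ctx->dec_num_deferred_ops_function();
  //   inc();
  //   EnqueueAsyncDeviceWork([dec]() {
  //     // ... runs when the device work has finished ...
  //     dec();
  //   });
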
1220   Allocator* get_allocator(AllocatorAttributes attr);
1221 
1222  private:
1223   bool record_memory_consumption_ = false;
1224 
1225   // Internal common method used when allocating tensor memory
1226   Status allocate_tensor(DataType type, const TensorShape& shape,
1227                          Tensor* out_tensor,
1228                          AllocatorAttributes allocator_attr) {
1229     return allocate_tensor(type, shape, out_tensor, allocator_attr,
1230                            AllocationAttributes());
1231   }
1232 
1233   Status allocate_tensor(DataType type, const TensorShape& shape,
1234                          Tensor* out_tensor, AllocatorAttributes allocator_attr,
1235                          const AllocationAttributes& allocation_attr);
1236 
1237   // Helpers for `set_output()`.
1238 
1239   // Returns `true` if the tensor was copied into an allocated output.
1240   bool maybe_set_output_by_allocate_and_copy(int index, const Tensor& tensor);
1241 
1242   void maybe_track_allocations_for_set_output(const Tensor& tensor);
1243 
1244   Status get_input_index(StringPiece name, int* out_index) const;
1245   Status get_output_index(StringPiece name, int* out_index) const;
1246 
1247   // Initialize the allocated_scope_ids_ set the first time this method is
1248   // called.
1249   void maybe_initialize_scope_id_set();
1250 
1251   Status status_;
1252   friend class CollectiveExecutor;  // for access to params_
1253   Params* params_;                  // not owned
1254   gtl::InlinedVector<TensorValue, 4> outputs_;
1255 
1256   // Keep track of calls to ScopedAllocator.
1257   // TODO(ayushd): change to absl::flat_hash_set.
1258   std::unique_ptr<std::unordered_set<int32>> allocated_scope_ids_;
1259 
1260   // The following data members are only used when allocation tracking is
1261   // enabled, memory consumption is being recorded, or tensor access is being
1262   // recorded.
1263   struct TrackingState {
1264     mutable mutex mu;
1265     gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators
1266         TF_GUARDED_BY(mu);
1267 
1268     mutable mutex stats_mu;
1269     int64 temp_memory_allocated TF_GUARDED_BY(stats_mu) = 0;
1270 
1271     int64 persistent_memory_allocated TF_GUARDED_BY(stats_mu) = 0;
1272     gtl::InlinedVector<std::pair<const void*, int64>, 2>
1273         temp_tensor_buffer_and_size TF_GUARDED_BY(stats_mu);
1274     gtl::InlinedVector<int64, 2> persistent_alloc_ids TF_GUARDED_BY(stats_mu);
1275   };
1276   std::unique_ptr<TrackingState> tracking_state_;
1277 
1278   // For access to `params_->op_kernel`.
1279   friend void CheckNotInComputeAsync(OpKernelContext* ctx,
1280                                      const char* correct_macro_name);
1281 
1282   TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext);
1283 };
1284 
1285 template <>
1286 const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const;
1287 
1288 template <>
1289 const Eigen::GpuDevice& OpKernelContext::eigen_device() const;
1290 
1291 // Register your OpKernel by specifying the Op's name, the device the
1292 // kernel runs on, any type attr constraints for this kernel, any
1293 // host-memory args, and the class to instantiate.  Examples:
1294 //
1295 //  // A kernel that supports all types.
1296 //  REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp);
1297 //
1298 //  // The following are equivalent ways of specifying that the kernel only
1299 //  // works if the "T" type attr is set to DT_FLOAT.
1300 //  REGISTER_KERNEL_BUILDER(
1301 //      Name("Sub").Device(DEVICE_CPU).TypeConstraint<float>("T"),
1302 //      SubOp<float>);
1303 //  // (You would then repeat this for every type supported by "Sub".)
1304 //
1305 //  // This form allows you to specify a list of types as the constraint.
1306 //  REGISTER_KERNEL_BUILDER(Name("Sub")
1307 //                              .Device(DEVICE_CPU)
1308 //                              .TypeConstraint("T", {DT_FLOAT}),
1309 //                          SubOp<float>);
1310 //
1311 //  // A kernel that expects one of the input tensors in host memory.
1312 //  REGISTER_KERNEL_BUILDER(
1313 //      Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp);
1314 //
1315 // See kernel_def_builder for details.
1316 
1317 // Instantiate an OpKernel that has been registered.  Returns nullptr
1318 // if there is no kernel registered for that device / input signature
1319 // combination (and sets *status to NOT_FOUND), or if there is an error in
1320 // construction (and sets *status to INVALID_ARGUMENT).  Otherwise, the caller takes ownership
1321 // of the returned pointer.
1322 // EXPECTED USAGE: unique_ptr<OpKernel> op = CreateOpKernel(...);
1323 // REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
1324 std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
1325                                          DeviceBase* device,
1326                                          Allocator* allocator,
1327                                          const NodeDef& node_def,
1328                                          int graph_def_version, Status* status);
1329 
1330 std::unique_ptr<OpKernel> CreateOpKernel(
1331     DeviceType device_type, DeviceBase* device, Allocator* allocator,
1332     const std::shared_ptr<const NodeProperties>& props, int graph_def_version,
1333     Status* status);
1334 
1335 Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
1336                       Allocator* allocator, FunctionLibraryRuntime* flib,
1337                       const std::shared_ptr<const NodeProperties>& props,
1338                       int graph_def_version, OpKernel** kernel);
1339 
1340 Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
1341                       Allocator* allocator, FunctionLibraryRuntime* flib,
1342                       ResourceMgr* resource_mgr,
1343                       const std::shared_ptr<const NodeProperties>& props,
1344                       int graph_def_version, OpKernel** kernel);
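// A minimal sketch (illustrative only) of the expected usage noted above,
// assuming `device`, `allocator` and a fully-attributed `node_def` already
// exist in the caller:
//
//   Status status;
//   std::unique_ptr<OpKernel> op =
//       CreateOpKernel(DeviceType(DEVICE_CPU), device, allocator, node_def,
//                      TF_GRAPH_DEF_VERSION, &status);
//   if (!status.ok()) { /* NOT_FOUND or INVALID_ARGUMENT; op is nullptr */ }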
1345 
1346 // Returns into 'device_types' the subset of prioritized_types that this
1347 // binary has registered for the given NodeDef.
1348 //
1349 // REQUIRES: * 'device_types' is not nullptr.
1350 //           * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
1351 Status SupportedDeviceTypesForNode(
1352     const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
1353     PrioritizedDeviceTypeVector* device_types,
1354     const DeviceNameUtils::ParsedName* local_address_spec = nullptr);
1355 
1356 // Returns a message with a description of the kernels registered for op
1357 // `op_name`.
1358 std::string KernelsRegisteredForOp(StringPiece op_name);
1359 
1360 // Call once after Op registration has completed.
1361 Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry);
1362 
1363 // -----------------------------------------------------------------------------
1364 // OpKernel registration implementation follows, please ignore.
1365 
1366 // Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax.
1367 namespace register_kernel {
1368 
1369 class Name : public KernelDefBuilder {
1370  public:
1371   explicit Name(const char* op) : KernelDefBuilder(op) {}
1372 };
1373 
1374 }  // namespace register_kernel
1375 
1376 // Kernel registration appears as:
1377 //   REGISTER_KERNEL_BUILDER(Name("OpName").Device(DEVICE_CPU)..., OpImpl)
1378 // We'd like to have "OpName" as a constant-expression, without requiring that
1379 // of the overall KernelDefBuilder expression (beginning with the
1380 // register_kernel::Name constructor above).
1381 //
1382 // So, we pull the "OpName" part to a separate macro-level argument. This
1383 // involves treating Name("OpName") as a macro call, via token-pasting (e.g.
1384 // M_## =>  M_Name("OpName")), and having it expand to '"OpName",
1385 // Name("OpName")' which is then usable as two arguments.
1386 #define TF_EXTRACT_KERNEL_NAME_Name(name_str) \
1387   name_str, ::tensorflow::register_kernel::Name(name_str)
1388 #define TF_EXTRACT_KERNEL_NAME_IMPL(m, ...) m(__VA_ARGS__)
1389 #define TF_EXTRACT_KERNEL_NAME(m, kernel_builder, ...)                    \
1390   TF_EXTRACT_KERNEL_NAME_IMPL(m, TF_EXTRACT_KERNEL_NAME_##kernel_builder, \
1391                               __VA_ARGS__)
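// For example, the macro call below (an illustrative expansion, shown only to
// clarify the token-pasting above):
//
//   TF_EXTRACT_KERNEL_NAME(M, Name("Sub").Device(DEVICE_CPU), SubOp)
//
// pastes kernel_builder into TF_EXTRACT_KERNEL_NAME_Name("Sub").Device(...),
// which then expands to:
//
//   M("Sub", ::tensorflow::register_kernel::Name("Sub").Device(DEVICE_CPU), SubOp)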
1392 
1393 // REGISTER_KERNEL_BUILDER_IMPL_2, with a unique 'ctr' as the first argument.
1394 // TODO(dodgen): There are some uses of this macro inside functions, where
1395 // kernel_builder refers to (non-const) locals (they should be fixed). To
1396 // accommodate those, kernel_builder.Build() appears as an argument to an
1397 // immediately-called lambda (not in the lambda itself).
1398 #define REGISTER_KERNEL_BUILDER_IMPL_3(ctr, op_name, kernel_builder_expr,   \
1399                                        is_system_kernel, ...)               \
1400   static ::tensorflow::InitOnStartupMarker const register_kernel_##ctr      \
1401       TF_ATTRIBUTE_UNUSED =                                                 \
1402           TF_INIT_ON_STARTUP_IF(is_system_kernel ||                         \
1403                                 (SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__) && \
1404                                  SHOULD_REGISTER_OP(op_name)))              \
1405           << ([](::tensorflow::KernelDef const* kernel_def) {               \
1406                ::tensorflow::kernel_factory::OpKernelRegistrar registrar(   \
1407                    kernel_def, #__VA_ARGS__,                                \
1408                    [](::tensorflow::OpKernelConstruction* context)          \
1409                        -> ::tensorflow::OpKernel* {                         \
1410                      return new __VA_ARGS__(context);                       \
1411                    });                                                      \
1412                (void)registrar;                                             \
1413                return ::tensorflow::InitOnStartupMarker{};                  \
1414              })(kernel_builder_expr.Build());
1415 
1416 // REGISTER_KERNEL_BUILDER_IMPL, but with kernel_builder split to op_name,
1417 // kernel_builder_expr.
1418 #define REGISTER_KERNEL_BUILDER_IMPL_2(op_name, kernel_builder_expr, \
1419                                        is_system_kernel, ...)        \
1420   TF_NEW_ID_FOR_INIT(REGISTER_KERNEL_BUILDER_IMPL_3, op_name,        \
1421                      kernel_builder_expr, is_system_kernel, __VA_ARGS__)
1422 
1423 // REGISTER_KERNEL_BUILDER, but with is_system_kernel bound.
1424 #define REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, is_system_kernel, ...) \
1425   TF_EXTRACT_KERNEL_NAME(REGISTER_KERNEL_BUILDER_IMPL_2, kernel_builder,    \
1426                          is_system_kernel, __VA_ARGS__)
1427 
1428 #define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
1429   TF_ATTRIBUTE_ANNOTATE("tf:kernel")                 \
1430   REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, false, __VA_ARGS__)
1431 
1432 // The `REGISTER_SYSTEM_KERNEL_BUILDER()` macro acts as
1433 // `REGISTER_KERNEL_BUILDER()` except that the kernel is registered
1434 // unconditionally even when selective registration is used.
1435 #define REGISTER_SYSTEM_KERNEL_BUILDER(kernel_builder, ...) \
1436   TF_ATTRIBUTE_ANNOTATE("tf:kernel")                        \
1437   TF_ATTRIBUTE_ANNOTATE("tf:kernel:system")                 \
1438   REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, true, __VA_ARGS__)
1439 
1440 // Checks whether a given kernel is registered on device_type.
1441 bool KernelDefAvailable(const DeviceType& device_type, const NodeDef& node_def);
1442 
1443 // If a node with the given node_name, experimental_debug_info, node_op,
1444 // node_device and node_attrs has a corresponding kernel registered on
1445 // device_type, returns OK and fills in the kernel def and kernel_class_name. <def> and
1446 // <kernel_class_name> may be null.
1447 Status FindKernelDef(
1448     const DeviceType& device_type, StringPiece node_name,
1449     bool has_experimental_debug_info,
1450     const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
1451     StringPiece node_op, StringPiece node_device, AttrSlice node_attrs,
1452     const KernelDef** def, std::string* kernel_class_name);
1453 
1454 // If node_def has a corresponding kernel registered on device_type,
1455 // returns OK and fills in the kernel def and kernel_class_name. <def> and
1456 // <kernel_class_name> may be null.
1457 Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
1458                      const KernelDef** def, std::string* kernel_class_name);
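// A minimal sketch (illustrative only) of looking up the kernel registered for
// an existing `node_def` on CPU:
//
//   const KernelDef* kernel_def = nullptr;
//   std::string kernel_class_name;
//   Status s = FindKernelDef(DeviceType(DEVICE_CPU), node_def, &kernel_def,
//                            &kernel_class_name);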
1459 
1460 // Writes a list of all registered kernels to LOG(INFO), to help users debug
1461 // missing kernel errors.
1462 void LogAllRegisteredKernels();
1463 
1464 // Gets a list of all registered kernels.
1465 KernelList GetAllRegisteredKernels();
1466 
1467 // Gets a list of all registered kernels for which the predicate returns true.
1468 KernelList GetFilteredRegisteredKernels(
1469     const std::function<bool(const KernelDef&)>& predicate);
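// For example (illustrative only), listing just the kernels registered for the
// GPU device type:
//
//   KernelList gpu_kernels = GetFilteredRegisteredKernels(
//       [](const KernelDef& k) { return k.device_type() == "GPU"; });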
1470 
1471 // Gets a list of all registered kernels for a given op.
1472 KernelList GetRegisteredKernelsForOp(StringPiece op_name);
1473 
1474 namespace kernel_factory {
1475 
1476 // OpKernelFactory is responsible for creating OpKernels when TensorFlow needs
1477 // them. You register factories with the TensorFlow core by constructing an
1478 // OpKernelRegistrar and passing the factory as a constructor parameter.
1479 class OpKernelFactory {
1480  public:
1481   virtual OpKernel* Create(OpKernelConstruction* context) = 0;
1482   virtual ~OpKernelFactory() = default;
1483 };
1484 
1485 class OpKernelRegistrar {
1486  public:
1487   // Registers the given kernel factory with TensorFlow. TF will call the
1488   // factory Create() method when it determines that a kernel matching the given
1489   // KernelDef is required.
1490   OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
1491                     std::unique_ptr<OpKernelFactory> factory) {
1492     InitInternal(kernel_def, kernel_class_name, std::move(factory));
1493   }
1494 
1495   // Registers the given factory function with TensorFlow. This is equivalent
1496   // to registering a factory whose Create function invokes `create_fn`.
1497   OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
1498                     OpKernel* (*create_fn)(OpKernelConstruction*)) {
1499     InitInternal(kernel_def, kernel_class_name,
1500                  absl::make_unique<PtrOpKernelFactory>(create_fn));
1501   }
1502 
1503  private:
1504   struct PtrOpKernelFactory : public OpKernelFactory {
1505     explicit PtrOpKernelFactory(OpKernel* (*create_func)(OpKernelConstruction*))
1506         : create_func_(create_func) {}
1507 
1508     OpKernel* Create(OpKernelConstruction* context) override;
1509 
1510     OpKernel* (*create_func_)(OpKernelConstruction*);
1511   };
1512 
1513   void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
1514                     std::unique_ptr<OpKernelFactory> factory);
1515 };
1516 
1517 }  // namespace kernel_factory
1518 
1519 // -----------------------------------------------------------------------------
1520 // Template and inline method implementations, please ignore.
1521 
1522 template <class T>
1523 Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const {
1524   return GetNodeAttr(def(), attr_name, value);
1525 }
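// A minimal sketch (illustrative only) of the usual GetAttr pattern in a
// kernel constructor; MyOp and its member `axis_` are hypothetical:
//
//   explicit MyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
//     OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
//   }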
1526 
1527 inline DataType OpKernelContext::input_dtype(int index) const {
1528   DCHECK_GE(index, 0);
1529   DCHECK_LT(index, num_inputs());
1530   const TensorValue& value((*params_->inputs)[index]);
1531   return value.dtype();
1532 }
1533 
1534 inline MemoryType OpKernelContext::input_memory_type(int index) const {
1535   DCHECK_GE(index, 0);
1536   DCHECK_LT(index, num_inputs());
1537   return op_kernel().input_memory_types()[index];
1538 }
1539 
1540 inline DataType OpKernelContext::expected_output_dtype(int index) const {
1541   DCHECK_GE(index, 0);
1542   DCHECK_LT(index, num_outputs());
1543   return params_->op_kernel->output_type(index);
1544 }
1545 
1546 inline MemoryType OpKernelContext::output_memory_type(int index) const {
1547   DCHECK_GE(index, 0);
1548   DCHECK_LT(index, num_outputs());
1549   return op_kernel().output_memory_types()[index];
1550 }
1551 
1552 inline bool OpKernelContext::input_is_ref(int index) const {
1553   const TensorValue& value((*params_->inputs)[index]);
1554   return value.is_ref();
1555 }
1556 
1557 // There is no input if tensor == nullptr.
1558 inline bool OpKernelContext::has_input(int index) const {
1559   DCHECK_GE(index, 0);
1560   DCHECK_LT(index, num_inputs());
1561   return (*params_->inputs)[index].tensor != nullptr;
1562 }
1563 
1564 inline mutex* OpKernelContext::input_ref_mutex(int index) {
1565   DCHECK_GE(index, 0);
1566   DCHECK_LT(index, num_inputs());
1567   DCHECK(input_is_ref(index));
1568   return (*params_->inputs)[index].mutex_if_ref;
1569 }
1570 
1571 inline Tensor* OpKernelContext::mutable_output(int index) {
1572   DCHECK_GE(index, 0);
1573   DCHECK_LT(index, num_outputs());
1574   return outputs_[index].tensor;
1575 }
1576 
1577 inline TensorValue OpKernelContext::release_output(int index) {
1578   DCHECK_GE(index, 0);
1579   DCHECK_LT(index, num_outputs());
1580   TensorValue value = outputs_[index];
1581   outputs_[index] = TensorValue();
1582   return value;
1583 }
1584 
1585 inline Status OpKernelContext::forward_input_or_allocate_output(
1586     gtl::ArraySlice<int> candidate_input_indices, int output_index,
1587     const TensorShape& output_shape, Tensor** output, int* forwarded_input) {
1588   for (int input_index : candidate_input_indices) {
1589     if (forward_input_to_output_with_shape(input_index, output_index,
1590                                            output_shape, output)) {
1591       if (forwarded_input != nullptr) {
1592         *forwarded_input = input_index;
1593       }
1594       return Status::OK();
1595     }
1596   }
1597   if (forwarded_input != nullptr) {
1598     *forwarded_input = -1;
1599   }
1600   return allocate_output(output_index, output_shape, output);
1601 }
1602 
1603 inline Status OpKernelContext::forward_input_or_allocate_output(
1604     gtl::ArraySlice<StringPiece> candidate_input_names, StringPiece output_name,
1605     const TensorShape& output_shape, Tensor** output) {
1606   for (const StringPiece& input_name : candidate_input_names) {
1607     if (forward_input_to_output_with_shape(input_name, output_name,
1608                                            output_shape, output)
1609             .ok()) {
1610       return Status::OK();
1611     }
1612   }
1613   return allocate_output(output_name, output_shape, output);
1614 }
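// A minimal sketch (illustrative only) of the index-based overload above, as a
// unary element-wise kernel's Compute() might use it to reuse input 0's buffer
// when possible:
//
//   Tensor* out = nullptr;
//   OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
//                           {0}, 0, ctx->input(0).shape(), &out));
//   // `out` now either aliases input 0 or points to a freshly allocated tensor.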
1615 
1616 template <typename T>
1617 T* OpKernelContext::op_device_context() {
1618   static_assert(std::is_base_of<DeviceContext, T>::value,
1619                 "T is not a subclass of DeviceContext");
1620   return static_cast<T*>(op_device_context());
1621 }
1622 
1623 inline const Tensor& OpInputList::operator[](int i) const {
1624   DCHECK_GE(i, 0);
1625   DCHECK_LT(i, stop_ - start_);
1626   return ctx_->input(start_ + i);
1627 }
1628 
1629 inline mutex* OpMutableInputList::ref_mutex(int i) {
1630   DCHECK_GE(i, 0);
1631   DCHECK_LT(i, stop_ - start_);
1632   return ctx_->input_ref_mutex(start_ + i);
1633 }
1634 
1635 inline Tensor OpMutableInputList::at(int i, bool lock_held) {
1636   DCHECK_GE(i, 0);
1637   DCHECK_LT(i, stop_ - start_);
1638   return ctx_->mutable_input(start_ + i, lock_held);
1639 }
1640 
1641 inline Tensor* OpOutputList::operator[](int i) {
1642   DCHECK_GE(i, 0);
1643   DCHECK_LT(i, stop_ - start_);
1644   return ctx_->mutable_output(start_ + i);
1645 }
1646 
1647 inline bool OpOutputList::required(int i) const {
1648   DCHECK_GE(i, 0);
1649   DCHECK_LT(i, stop_ - start_);
1650   return ctx_->output_required(start_ + i);
1651 }
1652 
1653 inline DataType OpOutputList::expected_output_dtype(int i) const {
1654   DCHECK_GE(i, 0);
1655   DCHECK_LT(i, stop_ - start_);
1656   return ctx_->expected_output_dtype(start_ + i);
1657 }
1658 
1659 inline Status OpOutputList::allocate(int i, const TensorShape& shape,
1660                                      Tensor** output) {
1661   DCHECK_GE(i, 0);
1662   DCHECK_LT(i, stop_ - start_);
1663   return ctx_->allocate_output(start_ + i, shape, output);
1664 }
1665 
1666 inline void OpOutputList::set(int i, const Tensor& tensor) {
1667   DCHECK_GE(i, 0);
1668   DCHECK_LT(i, stop_ - start_);
1669   ctx_->set_output(start_ + i, tensor);
1670 }
1671 
1672 inline void OpOutputList::set(int i, Tensor&& tensor) {
1673   DCHECK_GE(i, 0);
1674   DCHECK_LT(i, stop_ - start_);
1675   ctx_->set_output(start_ + i, std::move(tensor));
1676 }
1677 
1678 inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
1679   DCHECK_GE(i, 0);
1680   DCHECK_LT(i, stop_ - start_);
1681   ctx_->set_output_ref(i, mu, tensor_for_ref);
1682 }
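// A minimal sketch (illustrative only) of producing a list output through
// OpOutputList, assuming the op defines a list output named "outputs" and a
// suitable `shape` for each element:
//
//   OpOutputList outputs;
//   OP_REQUIRES_OK(ctx, ctx->output_list("outputs", &outputs));
//   for (int i = 0; i < outputs.size(); ++i) {
//     Tensor* t = nullptr;
//     OP_REQUIRES_OK(ctx, outputs.allocate(i, shape, &t));
//   }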
1683 
1684 // Generate a fatal error if OP_REQUIRES or OP_REQUIRES_OK are used in
1685 // AsyncOpKernel implementations. If these macros are used and the condition
1686 // does not hold, the `done` callback will never be called and the system will
1687 // deadlock, so a crash failure is preferable. Since the OP_REQUIRES[_OK] macros
1688 // are legal to use in AsyncOpKernel constructors, we use overload resolution
1689 // to distinguish between OpKernelConstruction* and OpKernelContext* context
1690 // types.
1691 class XlaOpKernelContext;
1692 inline void CheckNotInComputeAsync(XlaOpKernelContext*, const char*) {}
1693 inline void CheckNotInComputeAsync(OpKernelConstruction*, const char*) {}
1694 void CheckNotInComputeAsync(OpKernelContext* ctx,
1695                             const char* correct_macro_name);
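// Inside ComputeAsync() itself, the *_ASYNC variants of these macros (e.g.
// OP_REQUIRES_OK_ASYNC from op_requires.h) should be used instead, so that the
// `done` callback still runs on failure. A minimal sketch (illustrative only;
// SomeStatusReturningCall is hypothetical):
//
//   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
//     OP_REQUIRES_OK_ASYNC(ctx, SomeStatusReturningCall(), done);
//     // ... schedule device work, then invoke done() ...
//   }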
1696 
1697 }  // namespace tensorflow
1698 
1699 #endif  // TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
1700