1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
18 
19 #include <atomic>
20 #include <functional>
21 
22 #include <utility>
23 #include <vector>
24 #include "tensorflow/core/framework/allocator.h"
25 #include "tensorflow/core/framework/cancellation.h"
26 #include "tensorflow/core/framework/control_flow.h"
27 #include "tensorflow/core/framework/device_base.h"
28 #include "tensorflow/core/framework/graph.pb.h"
29 #include "tensorflow/core/framework/kernel_def.pb.h"
30 #include "tensorflow/core/framework/kernel_def_builder.h"
31 #include "tensorflow/core/framework/node_def_util.h"
32 #include "tensorflow/core/framework/op.h"  // TODO(b/62899350): Remove
33 #include "tensorflow/core/framework/rendezvous.h"
34 #include "tensorflow/core/framework/selective_registration.h"
35 #include "tensorflow/core/framework/session_state.h"
36 #include "tensorflow/core/framework/tensor.h"
37 #include "tensorflow/core/framework/tensor_shape.h"
38 #include "tensorflow/core/framework/tensor_shape.pb.h"  // TODO(b/62899350): Remove
39 #include "tensorflow/core/framework/tracking_allocator.h"
40 #include "tensorflow/core/framework/types.h"
41 #include "tensorflow/core/framework/types.pb.h"
42 #include "tensorflow/core/framework/unique_tensor_references.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/lib/core/status.h"
45 #include "tensorflow/core/lib/gtl/array_slice.h"
46 #include "tensorflow/core/lib/gtl/manual_constructor.h"
47 #include "tensorflow/core/platform/env.h"
48 #include "tensorflow/core/platform/logging.h"
49 #include "tensorflow/core/platform/macros.h"
50 #include "tensorflow/core/platform/mutex.h"
51 #include "tensorflow/core/platform/profile_utils/cpu_utils.h"
52 #include "tensorflow/core/platform/thread_annotations.h"
53 #include "tensorflow/core/platform/types.h"
54 
55 namespace Eigen {
56 struct ThreadPoolDevice;
57 struct GpuDevice;
58 struct SyclDevice;
59 }  // end namespace Eigen
60 
61 namespace tensorflow {
62 
63 namespace checkpoint {
64 class TensorSliceReaderCacheWrapper;
65 }  // namespace checkpoint
66 
67 class AsyncOpKernel;
68 class CallFrameInterface;
69 class FunctionLibraryRuntime;
70 class OpKernelConstruction;  // declared below
71 class OpKernelContext;       // declared below
72 class OpRegistryInterface;
73 class ResourceMgr;
74 class ScopedStepContainer;
75 class CollectiveExecutor;
76 class StepStatsCollectorInterface;
77 
78 class OpKernel {
79  public:
80   // OpKernel won't be instantiated by the scheduler, so you may perform
81   // expensive initialization in the descendant's constructor.
82   explicit OpKernel(OpKernelConstruction* context);
83 
84   // Specialized constructor that enables the descendant to provide a different
85   // `NodeDef` value. For example, this constructor can be used to provide a
86   // stripped-down `NodeDef` that does not contain the full set of attrs (such
87   // as tensor values) if the descendant stores them in a different form.
88   explicit OpKernel(OpKernelConstruction* context,
89                     std::unique_ptr<const NodeDef> node_def);
90 
91   virtual ~OpKernel();
92 
93   // An OpKernel's computation can be either synchronous or
94   // asynchronous. All OpKernel Compute() methods must be thread-safe as they
95   // may be called concurrently (e.g. by multiple executions of the same graph
96   // concurrently).
97   //
98   // Most OpKernels should compute synchronously.  They should
99   // subclass OpKernel and override the Compute() method and have it
100   // return after completing the supplied work.
101   //
102   // A few special kernels might need to be asynchronous to bound the
103   // number of threads (e.g., network receive operations). These
104   // kernels must subclass AsyncOpKernel and override
105   // AsyncOpKernel::ComputeAsync().
106   //
107   // In both cases, implementations of Compute() and ComputeAsync()
108   // get inputs and write outputs through the given OpKernelContext,
109   // and return a status via context->SetStatus(). They must be
110   // thread-safe. (An example sketch follows this class definition.)
111 
112   // Synchronous compute.
113   //
114   // "context" is guaranteed to be alive until Compute() returns.
115   virtual void Compute(OpKernelContext* context) = 0;
116 
117   // Returns nullptr iff this op kernel is synchronous.
118   virtual AsyncOpKernel* AsAsync() { return nullptr; }
119   virtual const AsyncOpKernel* AsAsync() const { return nullptr; }
120 
121   // Initial time (in CPU cycles) we expect an operation to take.  Used to
122   // determine whether an operation should be placed in a threadpool.  Operations
123   // start out "expensive".
124   static const uint64 kInitialCostEstimateCycles = 100 * 1000 * 1000;
125   static const uint64 kOpIsExpensiveThresholdCycles = 5000;
126   static const uint64 kCostDecay = 10;
127 
128   // Returns true iff this op kernel is considered "expensive". The
129   // runtime may use this flag to optimize graph execution, for example
130   // by "inlining" inexpensive kernels.
131   virtual bool IsExpensive() {
132     return expensive_ && (cost_estimate_.load(std::memory_order_relaxed) >
133                           kOpIsExpensiveThresholdCycles);
134   }
135 
136   // Updates the dynamic cost estimate, which is used to determine whether this
137   // op is expensive. The new cost estimate is a weighted average of the old
138   // cost estimate and the latest cost.
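  // Written out, the update below computes (an illustrative restatement of
  // the expression in the body, not additional behavior):
  //   new_estimate = old_estimate * (kCostDecay - 1) / kCostDecay
  //                  + elapsed_cycles / kCostDecay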
139   void UpdateCostEstimate(uint64 elapsed_cycles) {
140     // N.B. Updates to `cost_estimate_` are atomic but unlocked.  Simultaneous
141     // updates may result in one or more updates being ignored.  This does not
142     // affect correctness but may reduce the effective update frequency.
143     cost_estimate_.store(
144         (kCostDecay - 1) * cost_estimate_.load(std::memory_order_relaxed) /
145                 kCostDecay +
146             (elapsed_cycles / kCostDecay),
147         std::memory_order_relaxed);
148   }
149 
150   // Accessors.
151   const NodeDef& def() const { return *def_; }
152   const string& name() const;              // Same as def().name()
153   const string& type_string() const;       // Same as def().op()
154   const string& requested_device() const;  // Same as def().device()
155   bool is_internal() const { return is_internal_; }
156 
157   int num_inputs() const { return input_types_.size(); }
158   DataType input_type(int i) const { return input_types_[i]; }
159   const DataTypeVector& input_types() const { return input_types_; }
160   const MemoryTypeVector& input_memory_types() const {
161     return input_memory_types_;
162   }
163   const string& requested_input(int i) const;  // Same as def().input(i)
164 
165   int num_outputs() const { return output_types_.size(); }
166   DataType output_type(int o) const { return output_types_[o]; }
167   const DataTypeVector& output_types() const { return output_types_; }
168   const MemoryTypeVector& output_memory_types() const {
169     return output_memory_types_;
170   }
171 
172   Status InputRange(StringPiece input_name, int* start, int* stop) const;
173   Status OutputRange(StringPiece output_name, int* start, int* stop) const;
174 
175   // We allow legacy scalars within Google up until GraphDef version 6.
176   // TODO(irving): Remove when we can drop support for GraphDef version 5.
177   bool allow_legacy_scalars() const {
178 #if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID)
179     return graph_def_version_ < 6;
180 #else
181     return false;
182 #endif
183   }
184 
185   // Allow either scalars or (if allowing legacy scalars) shape (1,).
186   bool IsLegacyScalar(const TensorShape& shape) const {
187     return shape.dims() == 0 || (allow_legacy_scalars() && shape.dims() == 1 &&
188                                  shape.dim_size(0) == 1);
189   }
190 
191   // Allow rank 1 or (if allowing legacy scalars) rank 0.
192   bool IsLegacyVector(const TensorShape& shape) const {
193     return shape.dims() == 1 || (allow_legacy_scalars() && shape.dims() == 0);
194   }
195 
196   // Turn a shape Tensor into a TensorShape
197   // TODO(irving): Move to TensorShapeUtils once !allow_legacy_scalars
198   Status MakeShape(const Tensor& shape, TensorShape* out) const;
199 
200   static int DeviceNumaNode(const DeviceBase* device);
201 
202  private:
203   const std::unique_ptr<const NodeDef> def_;
204   const DataTypeVector input_types_;
205   const MemoryTypeVector input_memory_types_;
206   const DataTypeVector output_types_;
207   const MemoryTypeVector output_memory_types_;
208   const int graph_def_version_;
209   const bool is_internal_;  // True if this is an internal operation
210   NameRangeMap input_name_map_;
211   NameRangeMap output_name_map_;
212   bool expensive_;
213   std::atomic_uint_fast64_t cost_estimate_;
214 
215   TF_DISALLOW_COPY_AND_ASSIGN(OpKernel);
216 };
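// Example: a minimal synchronous kernel. This is an illustrative sketch, not
// part of this header; "MyAddOne" is a hypothetical op name and the kernel
// assumes a single float input.
//
//   class AddOneOp : public OpKernel {
//    public:
//     explicit AddOneOp(OpKernelConstruction* context) : OpKernel(context) {}
//
//     void Compute(OpKernelContext* context) override {
//       const Tensor& input = context->input(0);
//       Tensor* output = nullptr;
//       OP_REQUIRES_OK(context,
//                      context->allocate_output(0, input.shape(), &output));
//       output->flat<float>() = input.flat<float>() + 1.0f;
//     }
//   };
//
//   REGISTER_KERNEL_BUILDER(Name("MyAddOne").Device(DEVICE_CPU), AddOneOp);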
217 
218 class AsyncOpKernel : public OpKernel {
219  public:
220   using OpKernel::OpKernel;  // Lift OpKernel constructors.
221 
222   // Asynchronous compute.
223   //
224   // Implementations of ComputeAsync() must run "done" to signal the
225   // completion of the computation. "context" is guaranteed to be
226   // alive until the "done" callback starts. (A sketch follows this class.)
227   typedef std::function<void()> DoneCallback;
228   virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0;
229 
230   AsyncOpKernel* AsAsync() final { return this; }
231   const AsyncOpKernel* AsAsync() const final { return this; }
232 
233   void Compute(OpKernelContext* context) final;
234 };
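// Example: a minimal asynchronous kernel. This is an illustrative sketch, not
// part of this header; ScheduleWork is a hypothetical helper that runs a
// closure on another thread.
//
//   class MyAsyncOp : public AsyncOpKernel {
//    public:
//     using AsyncOpKernel::AsyncOpKernel;
//
//     void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
//       ScheduleWork([context, done]() {
//         Tensor* output = nullptr;
//         OP_REQUIRES_OK_ASYNC(
//             context, context->allocate_output(0, TensorShape({}), &output),
//             done);
//         output->scalar<float>()() = 0.0f;
//         done();  // Must run exactly once to signal completion.
//       });
//     }
//   };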
235 
236 // Wraps a tensor that is held by an Op across calls to Compute(). For
237 // memory safety when using asynchronous devices like GPUs, the system
238 // must be notified when a Tensor is used inside an Op execution. The
239 // wrapper ensures that all uses of the Tensor are tracked, because in
240 // order to retrieve the Tensor the caller must use AccessTensor which
241 // notifies the context.
242 class PersistentTensor {
243  public:
244   PersistentTensor() {}
245   explicit PersistentTensor(const Tensor& tensor) : tensor_(tensor) {}
246 
247   // Caller does not own the returned Tensor*.
248   Tensor* AccessTensor(OpKernelConstruction* context);
249   // Caller does not own the returned Tensor*.
250   Tensor* AccessTensor(OpKernelContext* context);
251 
252   // The check for initialization does not need to access the
253   // underlying tensor buffer.
254   bool IsInitialized() const { return tensor_.IsInitialized(); }
255 
256   int64 NumElements() const { return tensor_.NumElements(); }
257 
258   int64 AllocatedBytes() const { return tensor_.AllocatedBytes(); }
259 
260  private:
261   Tensor tensor_;
262 };
263 
264 class OpKernelConstruction {
265  public:
266   OpKernelConstruction(DeviceType device_type, DeviceBase* device,
267                        Allocator* allocator, const NodeDef* node_def,
268                        const OpDef* op_def, FunctionLibraryRuntime* flib,
269                        const DataTypeSlice& input_types,
270                        const MemoryTypeSlice& input_memory_types,
271                        const DataTypeSlice& output_types,
272                        const MemoryTypeSlice& output_memory_types,
273                        int graph_def_version, Status* status);
274 
275   Env* env() const { return device_->env(); }
276 
277   // Allocation of tensors during kernel construction:
278   //
279   // It is legal to temporarily allocate scratch tensor storage during
280   // Op kernel construction. Scratch tensors should be allocated using
281   // allocate_temp below. Some kernels need to keep tensors in between
282   // invocations. If such a Tensor is allocated during kernel
283   // construction this must be done using allocate_persistent, and the
284   // Op may only store the returned PersistentTensor object. When the
285   // Tensor is needed in a subsequent invocation, it can be retrieved
286   // from the PersistentTensor using the AccessTensor method. This
287   // ensures that the system is made aware of any use of the tensor's
288   // allocated memory, which is needed for correctness on asynchronous
289   // devices such as GPUs.
290 
291   // Allocates a temporary Tensor of the specified type and shape. The
292   // Tensor must not be used after kernel construction is
293   // complete. See comment above.
294   Status allocate_temp(DataType type, const TensorShape& shape,
295                        Tensor* out_temp);
296 
297   // Allocates a Tensor of the specified type and shape which the Op
298   // plans to maintain as persistent state. out_persistent holds the
299   // PersistentTensor which is the object the caller should store. For
300   // convenience, if out_tensor is non-null then it will be filled in
301   // with a Tensor* pointing to the newly-allocated tensor which the
302   // caller can use instead of calling
303   // out_persistent->AccessTensor. The caller does not own out_tensor
304   // and should not keep a copy of it. See comment above.
305   Status allocate_persistent(DataType type, const TensorShape& shape,
306                              PersistentTensor* out_persistent,
307                              Tensor** out_tensor);
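  // Example of the intended usage (an illustrative sketch; state_ and
  // kNumUnits are hypothetical members of the kernel using this API):
  //   Tensor* unused = nullptr;
  //   OP_REQUIRES_OK(context,
  //                  context->allocate_persistent(DT_FLOAT,
  //                                               TensorShape({kNumUnits}),
  //                                               &state_, &unused));
  //   // Later, in Compute():
  //   Tensor* state = state_.AccessTensor(ctx);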
308 
309   // User-supplied configuration of this operation.
310   const NodeDef& def() const { return *def_; }
311 
312   // For inspecting the inputs to this operation.
313   int num_inputs() const { return input_types_.size(); }
314   DataType input_type(int i) const { return input_types_[i]; }
315   const DataTypeSlice& input_types() const { return input_types_; }
316   const MemoryTypeSlice& input_memory_types() const {
317     return input_memory_types_;
318   }
319 
320   // For inspecting the outputs expected from this operation.
321   int num_outputs() const { return output_types_.size(); }
322   DataType output_type(int i) const { return output_types_[i]; }
323   const DataTypeSlice& output_types() const { return output_types_; }
324   const MemoryTypeSlice& output_memory_types() const {
325     return output_memory_types_;
326   }
327 
328   // If expected_inputs == inputs() and expected_outputs == output_types(),
329   // returns OK, else returns INVALID_ARGUMENT with an error message.
330   // Recommended for Ops with dynamic signatures.
331   Status MatchSignature(const DataTypeSlice expected_inputs,
332                         const DataTypeSlice expected_outputs);
333 
334   // For recording configuration errors during construction.
335   void SetStatus(const Status& status);
336   const Status& status() const { return *status_; }
337 
338   // Look up the attr with name attr_name and set *value to its value.  If no
339   // attr with attr_name is found in def(), or the attr does not have
340   // a matching type, a non-ok status will be returned.
341   template <class T>
342   Status GetAttr(StringPiece attr_name, T* value) const;
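  // Example (an illustrative sketch; "axis" is a hypothetical attr name
  // declared in the corresponding OpDef):
  //   int32 axis;
  //   OP_REQUIRES_OK(context, context->GetAttr("axis", &axis));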
343 
344   // Return true if the attr_name is defined in def().
345   bool HasAttr(StringPiece attr_name) const;
346 
347   // Return the device type.
348   const DeviceType& device_type() const { return device_type_; }
349 
350   // If not nullptr, the kernel can instantiate functions defined in
351   // the library. E.g.,
352   // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...).
353   FunctionLibraryRuntime* function_library() const { return flib_; }
354 
355   // The GraphDef version whose behavior we should follow.
356   int graph_def_version() const { return graph_def_version_; }
357 
358   // Helper routines for the OP_REQUIRES macros
359   void CtxFailure(const Status& s);
360   void CtxFailureWithWarning(const Status& s);
361   void CtxFailure(const char* file, int line, const Status& s);
362   void CtxFailureWithWarning(const char* file, int line, const Status& s);
363 
364   // Unrecommended functions: these are functions that have some
365   // current uses but are not recommended for use, and may go away at
366   // some future major version release.
367 
368   // May be used, e.g., to get GPU handles, etc.
369   //
370   // Currently only used to call MakeTensorFromProto() for
371   // implementing ConstantOp for every device.  See comments
372   // on Device::MakeTensorFromProto for longer-term replacement
373   // ideas.
374   DeviceBase* device() const { return device_; }
375 
376  private:
377   const DeviceType device_type_;
378   DeviceBase* const device_;
379   Allocator* allocator_;
380   const NodeDef* def_;
381   const OpDef* op_def_;
382   FunctionLibraryRuntime* flib_;
383   DataTypeSlice input_types_;
384   MemoryTypeSlice input_memory_types_;
385   DataTypeSlice output_types_;
386   MemoryTypeSlice output_memory_types_;
387   const int graph_def_version_;
388   Status* status_;
389 
390   // Allow OpKernel (but not its subclasses) to access op_def_.
391   // TODO(irving): Remove protos from this header entirely.
392   friend class OpKernel;
393 
394   TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction);
395 };
396 
397 // TODO(mrry): Consider converting to a random_access_iterator, and upgrading
398 // tensorflow::gtl::iterator_range to make the below container classes
399 // unnecessary.
400 template <typename ListType, typename ElementType>
401 class OpArgIterator {
402  public:
403   using iterator_category = std::forward_iterator_tag;
404   using value_type = ElementType;
405   using pointer = ElementType*;
406   using const_pointer = const ElementType*;
407   using reference = ElementType&;
408   using const_reference = const ElementType&;
409   using difference_type = ptrdiff_t;
410 
411   OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {}
412 
413   bool operator==(const OpArgIterator& rhs) {
414     DCHECK(list_ == rhs.list_);
415     return i_ == rhs.i_;
416   }
417 
418   bool operator!=(const OpArgIterator& rhs) {
419     DCHECK(list_ == rhs.list_);
420     return i_ != rhs.i_;
421   }
422 
423   OpArgIterator operator++() {  // prefix ++it
424     ++i_;
425     return *this;
426   }
427 
428   OpArgIterator operator++(int) {  // postfix it++
429     OpArgIterator old_value = *this;
430     ++i_;
431     return old_value;
432   }
433 
434   reference operator*() { return (*list_)[i_]; }
435   pointer operator->() { return &(*list_)[i_]; }
436 
437   const_reference operator*() const { return (*list_)[i_]; }
438   const_pointer operator->() const { return &(*list_)[i_]; }
439 
440  private:
441   const ListType* const list_;
442   int i_;
443 };
444 
445 // Utility class for representing a list of immutable input tensors
446 // that are passed to the op as a single named argument.
447 class OpInputList {
448  public:
449   typedef OpArgIterator<OpInputList, const Tensor> Iterator;
450   OpInputList() : ctx_(nullptr), start_(0), stop_(0) {}
451   OpInputList(OpKernelContext* ctx, int start, int stop)
452       : ctx_(ctx), start_(start), stop_(stop) {}
453   OpInputList& operator=(const OpInputList& other) = default;
454   const Tensor& operator[](int i) const;
455   int size() const { return stop_ - start_; }
456   Iterator begin() const { return Iterator(this, 0); }
457   Iterator end() const { return Iterator(this, size()); }
458 
459  private:
460   OpKernelContext* ctx_;  // not owned
461   int start_;
462   int stop_;
463 };
464 
465 // Utility class for representing a list of mutable ("ref") input tensors
466 // that are passed to the op as a single named argument.
467 class OpMutableInputList {
468  public:
469   typedef OpArgIterator<OpMutableInputList, Tensor*> Iterator;
470   OpMutableInputList(OpKernelContext* ctx, int start, int stop)
471       : ctx_(ctx), start_(start), stop_(stop) {}
472   OpMutableInputList() : ctx_(nullptr), start_(0), stop_(0) {}
473   OpMutableInputList& operator=(const OpMutableInputList& other) = default;
474   Tensor at(int i, bool lock_held);
475   mutex* ref_mutex(int i);
476   int size() const { return stop_ - start_; }
477   Iterator begin() const { return Iterator(this, 0); }
478   Iterator end() const { return Iterator(this, size()); }
479 
480  private:
481   OpKernelContext* ctx_;  // not owned
482   int start_;
483   int stop_;
484 };
485 
486 // Utility class for representing a list of output tensors that are
487 // grouped as a single named output.
488 class OpOutputList {
489  public:
490   typedef OpArgIterator<OpOutputList, const Tensor*> Iterator;
491   OpOutputList() : ctx_(nullptr), start_(0), stop_(0) {}
492   OpOutputList(OpKernelContext* ctx, int start, int stop)
493       : ctx_(ctx), start_(start), stop_(stop) {}
494   OpOutputList& operator=(const OpOutputList& other) = default;
495   Tensor* operator[](int i);
496   bool required(int i) const;
497   DataType expected_output_dtype(int i) const;
498   Status allocate(int i, const TensorShape& shape, Tensor** output);
499   void set(int i, const Tensor& tensor);
500   void set_ref(int i, mutex* mu, Tensor* tensor_for_ref);
501   int size() const { return stop_ - start_; }
502   Iterator begin() const { return Iterator(this, 0); }
503   Iterator end() const { return Iterator(this, size()); }
504 
505  private:
506   OpKernelContext* ctx_;  // not owned
507   int start_;
508   int stop_;
509 };
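// Example of allocating a list-valued output (an illustrative sketch;
// "values" is a hypothetical list output name declared in the OpDef):
//   OpOutputList outputs;
//   OP_REQUIRES_OK(context, context->output_list("values", &outputs));
//   for (int i = 0; i < outputs.size(); ++i) {
//     Tensor* out = nullptr;
//     OP_REQUIRES_OK(context, outputs.allocate(i, TensorShape({}), &out));
//   }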
510 
511 // Holds a tensor or tensor reference. For tensor references, we need
512 // a mutex to prevent concurrent access to the tensor.
513 struct TensorValue {
514   TensorValue() : mutex_if_ref(nullptr), tensor(nullptr) {}
515   TensorValue(Tensor* t)  // NOLINT(runtime/explicit)
516       : mutex_if_ref(nullptr), tensor(t) {}
517   TensorValue(mutex* mu, Tensor* t) : mutex_if_ref(mu), tensor(t) {}
518   Tensor* operator->() const { return tensor; }
519   bool is_ref() const { return mutex_if_ref != nullptr; }
520 
521   mutex* mutex_if_ref;  // nullptr if not a ref, != nullptr if a ref
522   Tensor* tensor;
523 };
524 
525 // Used to store partitioned graphs from function-calling ops.
526 struct GraphCollector {
527   mutex mu;
528   std::vector<GraphDef> partitioned_graphs GUARDED_BY(mu);
529   GraphDef raw_graph GUARDED_BY(mu);
530   GraphDef optimized_graph GUARDED_BY(mu);
531 
532   bool dirty GUARDED_BY(mu);
533 
534   GraphCollector() : dirty(false) {}
535 
536   void CollectRawGraph(const GraphDef& graph) {
537     mutex_lock ml(mu);
538     raw_graph.MergeFrom(graph);
539     dirty = true;
540   }
541 
542   void CollectOptimizedGraph(const GraphDef& graph) {
543     mutex_lock ml(mu);
544     optimized_graph.MergeFrom(graph);
545     dirty = true;
546   }
547 
548   void CollectPartitionedGraph(const GraphDef& graph) {
549     mutex_lock ml(mu);
550     partitioned_graphs.push_back(graph);
551     dirty = true;
552   }
553 
554   void ClearGraphs() EXCLUSIVE_LOCKS_REQUIRED(mu) {
555     raw_graph.Clear();
556     optimized_graph.Clear();
557     partitioned_graphs.clear();
558     dirty = false;
559   }
560 
561   bool HasUpdatedGraphs() {
562     mutex_lock ml(mu);
563     return dirty;
564   }
565 };
566 
567 class OpKernelContext {
568  public:
569   // The first element of a WrappedAllocator is a "base" Allocator and
570   // the second element is that Allocator wrapped by a
571   // TrackingAllocator
572   typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator;
573 
574   // TODO(zhifengc): Do some cleanup of Params.
575   // The Params struct is passed in to initialize an OpKernelContext,
576   // and must outlive the OpKernelContext.
577   struct Params {
578     ~Params() { delete eigen_gpu_device; }
579 
580     // The step being executed.
581     int64 step_id = 0;
582 
583     // The op kernel being computed.
584     OpKernel* op_kernel = nullptr;
585 
586     // The device on which the kernel is running.
587     DeviceBase* device = nullptr;
588 
589     // The Eigen GPU device wrapper, which may include a per-op
590     // wrapped allocator. The concrete type of this object depends on
591     // the type of this->device, so eigen_gpu_device can't be an
592     // inline member and must be heap allocated. However, we don't
593     // want to allocate a new eigen_gpu_device for every Op that is
594     // executed. Instead this member is allocated on first use using
595     // ensure_eigen_gpu_device, and then if the Params structure is
596     // re-used for subsequent Ops, the eigen_gpu_device is
597     // ReInitialized in the OpKernelContext constructor. Unlike the
598     // other pointers in Params, this one is owned by Params.
599     PerOpGpuDevice* eigen_gpu_device = nullptr;
600 
601     inline void ensure_eigen_gpu_device() {
602       DCHECK(device);
603       if (nullptr == eigen_gpu_device) {
604         // Surprisingly, MakeGpuDevice will return nullptr if the
605         // device is not a GPU device. This is ok, since those devices
606         // will never use eigen_gpu_device. It seems better to have
607         // ensure_eigen_gpu_device fall through and regenerate the
608         // nullptr every time an OpKernelContext is instantiated, than
609         // to do an unnecessary allocation of a dummy eigen GPU
610         // device for CPU device Ops.
611         eigen_gpu_device = device->MakeGpuDevice();
612       }
613     }
614 
615     bool track_allocations = false;
616     bool log_memory = false;
617     bool record_tensor_accesses = false;
618 
619     // Array indexed by output number for this node
620     const AllocatorAttributes* output_attr_array = nullptr;
621 
622     // Shared resources accessible by this op kernel invocation.
623     ResourceMgr* resource_manager = nullptr;
624 
625     // Per-step resources accessible by this op kernel invocation should be
626     // stored in this container.
627     ScopedStepContainer* step_container = nullptr;
628 
629     // Mechanism used by this op kernel invocation to communicate with
630     // computations running on other devices.
631     Rendezvous* rendezvous = nullptr;
632 
633     // Mechanism for executing a collective op that needs to coordinate
634     // with parallel instances running on other devices.
635     CollectiveExecutor* collective_executor = nullptr;
636 
637     // The session state for this op.
638     SessionState* session_state = nullptr;
639 
640     // Unique session identifier. Can be empty.
641     string session_handle;
642 
643     // The tensor store for this op.
644     TensorStore* tensor_store = nullptr;
645 
646     // Mechanism used by this op kernel invocation to register a callback
647     // for its cancellation.
648     CancellationManager* cancellation_manager = nullptr;
649 
650     // Inputs to this op kernel.
651     const gtl::InlinedVector<TensorValue, 4>* inputs = nullptr;
652     bool is_input_dead = false;
653 
654     const gtl::InlinedVector<AllocatorAttributes, 4>* input_alloc_attrs =
655         nullptr;
656 
657     // Device contexts.
658     const gtl::InlinedVector<DeviceContext*, 4>* input_device_contexts =
659         nullptr;
660     DeviceContext* op_device_context = nullptr;
661 
662     // Control-flow op supports.
663     FrameAndIter frame_iter;
664 
665     // Function call supports.
666     CallFrameInterface* call_frame = nullptr;
667     FunctionLibraryRuntime* function_library = nullptr;
668     std::function<void(std::function<void()>)>* runner = nullptr;
669     StepStatsCollectorInterface* stats_collector = nullptr;
670     GraphCollector* graph_collector = nullptr;
671 
672     // TensorSliceReaderCache support.
673     checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
674 
675     // Support for forwarding reservations (used by ScopedAllocator).
676     static const int kNeverForward = -2;
677     static const int kNoReservation = -1;
678     // Values in [0,...) represent reservations for the indexed output.
679     const int* forward_from_array = nullptr;
680 
681     // For tracking actively running deferred ops.
682     std::function<void()> inc_num_deferred_ops_function = []() {};
683     std::function<void()> dec_num_deferred_ops_function = []() {};
684   };
685 
686   // params must outlive the OpKernelContext.
687   explicit OpKernelContext(Params* params);
688   OpKernelContext(Params* params, int noutputs);
689   ~OpKernelContext();
690 
691   Env* env() const { return params_->device->env(); }
692 
693   int64 step_id() const { return params_->step_id; }
694 
695   const OpKernel& op_kernel() const { return *params_->op_kernel; }
696 
697   // Input/output signature.
698 
699   int num_inputs() const { return params_->inputs->size(); }
700   DataType input_dtype(int index) const;
701   Status input_dtype(StringPiece name, DataType* dtype) const;
702   MemoryType input_memory_type(int index) const;
703 
704   int num_outputs() const { return outputs_.size(); }
705   DataType expected_output_dtype(int index) const;
706   MemoryType output_memory_type(int index) const;
707 
708   // Input
709 
710   // Returns an immutable input tensor. May only be used for non-Ref
711   // inputs. For Ref inputs use mutable_input below.
712   // REQUIRES: !IsRefType(input_dtype(index))
713   // TODO(mrry): Convert this to return Status.
714   const Tensor& input(int index);
715 
716   // Returns the named immutable input tensor in "tensor", as defined
717   // in the OpDef. May only be used for non-Ref inputs. For Ref inputs
718   // use mutable_input below.
719   // REQUIRES: !IsRefType(input_dtype(index))
720   // REQUIRES: the named input must not be a list.
721   Status input(StringPiece name, const Tensor** tensor);
722 
723   // Returns the named list-valued immutable input in "list", as
724   // defined in the OpDef.  If the named output is not list-valued,
725   // returns a one-element list. May only be used for non-Ref
726   // inputs. For Ref inputs use mutable_input below.
727   // REQUIRES: !IsRefType(input_dtype(index))
728   Status input_list(StringPiece name, OpInputList* list);
729 
730   // For mutable inputs, use the following together to make sure there
731   // is no concurrent access to mutable_input(), e.g.:
732   // {
733   //   Tensor& t = context->mutable_input(index);
734   //   mutex_lock lock(*context->input_ref_mutex(index));
735   //   // modify the values in t
736   // }
737   // REQUIRES: IsRefType(input_dtype(index))
738   Status input_ref_mutex(StringPiece name, mutex** out_mutex);
739 
740   // Returns a mutable input tensor. Must be used to access Ref
741   // inputs.  REQUIRES: IsRefType(input_dtype(index)). The caller may
742   // modify the values stored in the Tensor buffer, and modifications
743   // will be visible to other Ops reading the same ref tensor. If
744   // !lock_held the input mutex will be acquired before returning the
745   // Tensor.
746   // TODO(mrry): Convert this to return Status.
747   Tensor mutable_input(int index, bool lock_held);
748 
749   // Returns the named mutable input tensor in "tensor", as defined in
750   // the OpDef. Must be used to access Ref inputs. The values stored
751   // in the Tensor buffer may be modified, and modifications will be
752   // visible to other Ops reading the same ref tensor. If !lock_held
753   // the input mutex will be acquired before returning the Tensor.
754   // REQUIRES: the named input must not be a list.
755   // REQUIRES: the named input must be a ref tensor.
756   Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held);
757 
758   // Returns the named list-valued mutable input in "list", as defined
759   // in the OpDef.  If the named input is not list-valued, returns a
760   // one-element list. Must be used to access Ref inputs. The values
761   // stored in the Tensor buffer may be modified, and modifications
762   // will be visible to other Ops reading the same ref tensor.
763   // REQUIRES: the named input must be a ref tensor.
764   Status mutable_input_list(StringPiece name, OpMutableInputList* list);
765 
766   // Replace the corresponding Ref Input to use the storage buffer
767   // used by tensor. If !lock_held the input mutex will be acquired
768   // before returning the Tensor.
769   // REQUIRES: IsRefType(input_dtype(index)).
770   void replace_ref_input(int index, const Tensor& tensor, bool lock_held);
771 
772   // Replace the corresponding named Ref Input to use the storage
773   // buffer used by tensor. If !lock_held the input mutex will be
774   // acquired before returning the Tensor.
775   // REQUIRES: IsRefType(input_dtype(index)).
776   Status replace_ref_input(StringPiece name, const Tensor& tensor,
777                            bool lock_held);
778 
779   // Deletes the Tensor object used as the Ref Input at
780   // input_index. This is not usually necessary and should be used
781   // with caution. If !lock_held the input mutex will be acquired
782   // before returning the Tensor.
783   // REQUIRES: IsRefType(input_dtype(input_index)).
784   void delete_ref_input(int input_index, bool lock_held);
785 
786   // Return true if there is input at the given index. An operator has no
787   // input at index if its tensor is null. This is primarily used by the
788   // merge operator.
789   // TODO(mrry): Convert this to return Status.
790   bool has_input(int index) const;
791 
792   // Returns true if all inputs are the same shape, otherwise sets the
793   // status to a non-OK value and returns false.
794   // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
795   bool ValidateInputsAreSameShape(OpKernel* op);
796 
797   // If non-null, kernels should populate with any partition subgraphs created.
798   GraphCollector* graph_collector() { return params_->graph_collector; }
799 
800   // Input to output forwarding.
801 
802   // Set the output Ref Tensor at output_index to be an alias of the
803   // input Ref Tensor at input_index.
804   // REQUIRES: IsRefType(input_dtype(input_index)).
805   // REQUIRES: IsRefType(output_dtype(output_index)).
806   void forward_ref_input_to_ref_output(int input_index, int output_index);
807 
808   // Returns true when an alias to input[input_index], reshaped to output_shape
809   // and safe to use for in-place computation, has been written to *output.
810   // Returns false if input[input_index] has a refcount greater than one, or if
811   // its type does not match the expected output type of output[output_index],
812   // or the number of elements in input[input_index] does not equal the number
813   // of elements in output_shape.
814   bool forward_input_to_output_with_shape(int input_index, int output_index,
815                                           const TensorShape& output_shape,
816                                           Tensor** output) TF_MUST_USE_RESULT;
817   Status forward_input_to_output_with_shape(StringPiece input_name,
818                                             StringPiece output_name,
819                                             const TensorShape& output_shape,
820                                             Tensor** output) TF_MUST_USE_RESULT;
821 
822   // Returns a pointer to a Tensor aliasing the underlying buffer backing
823   // input[input_index] iff
824   //   * input[input_index] is not a ref,
825   //   * the data type, shape, memory type, and allocator attributes of
826   //     input[input_index] are compatible with those given in dtype, shape,
827   //     memory_type, and attr,
828   //   * refcount on the underlying buffer is one.
829   //   * Either there is no forwarding reservation for either input_index
830   //     or output_index or the specified input is reserved for the specified
831   //     output. More precisely:
832   //
833   //     These cases mean neither input nor output has a reservation:
834   //        forward_from_array = nullptr
835   //     OR (input_index is not in forward_from_array AND
836   //         (output_index == kNoReservation OR
837   //          forward_from_array[output_index] == kNoReservation))
838   //
839   //     This case means that input_index is reserved for output_index:
840   //        forward_from_array[output_index] == input_index
841   //
842   //     This case means the output is reserved to always be allocated,
843   //     never assigned a forwarded input:
844   //        forward_from_array[output_index] == kNeverForward
845   //
846   // Otherwise returns nullptr.
847   // NOTE: For CUDA kernels that read inputs using the __ldg() intrinsic,
848   // forwarding is only safe if there are no reads via __ldg() after writes
849   // to the same address.
850   std::unique_ptr<Tensor> forward_input(
851       int input_index, int output_index, DataType output_dtype,
852       const TensorShape& output_shape, MemoryType output_memory_type,
853       const AllocatorAttributes& output_attr) TF_MUST_USE_RESULT;
854 
855   // Tries to forward one of the inputs given in input_indices to
856   // output[output_index]. If none of the given inputs can be forwarded, calls
857   // allocate_output() to allocate a new output buffer.
858   Status forward_input_or_allocate_output(
859       gtl::ArraySlice<int> candidate_input_indices, int output_index,
860       const TensorShape& output_shape, Tensor** output) TF_MUST_USE_RESULT;
861   Status forward_input_or_allocate_output(
862       gtl::ArraySlice<StringPiece> candidate_input_names,
863       StringPiece output_name, const TensorShape& output_shape,
864       Tensor** output) TF_MUST_USE_RESULT;
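  // Example (an illustrative sketch): reuse input 0's buffer for output 0
  // when possible, otherwise fall back to a fresh allocation.
  //   Tensor* output = nullptr;
  //   OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
  //                               {0}, 0, context->input(0).shape(), &output));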
865 
866   // Tries to reuse one of the inputs given in input_indices as a temporary.
867   // If none of the given inputs can be forwarded, calls
868   // allocate_temp() to allocate a new temporary buffer.
869   Status forward_input_or_allocate_temp(
870       gtl::ArraySlice<int> candidate_input_indices, DataType type,
871       const TensorShape& shape, const AllocatorAttributes& allocator_attr,
872       Tensor* out_temp) TF_MUST_USE_RESULT;
873 
874   Status forward_input_or_allocate_temp(
875       gtl::ArraySlice<int> candidate_input_indices, DataType type,
876       const TensorShape& shape, Tensor* out_temp) TF_MUST_USE_RESULT {
877     return forward_input_or_allocate_temp(candidate_input_indices, type, shape,
878                                           AllocatorAttributes(), out_temp);
879   }
880 
881   // Output
882 
883   // Returns the named list-valued output in "list", as defined in the OpDef.
884   // If the named output is not list-valued, returns a one-element list.
885   Status output_list(StringPiece name, OpOutputList* list);
886 
887   // If output_required(index) returns true, the OpKernel's Compute() method
888   // should call allocate_output(index, ...), set_output(index, ...),
889   // set_output_ref(index, ...), or set the status to a non-ok value.
890   // If it returns false, it may output, but is not required to do so.
891   // TODO(mrry): Convert this to return Status, and implement a string
892   // name version.
893   bool output_required(int index) const {
894     return true;  // TODO(josh11b): implement
895   }
896 
897   // Allocation of tensors during kernel execution inside the Compute
898   // method:
899   //
900   // There are three methods to allocate Tensors when an Op kernel
901   // executes.
902   //
903   // 1) allocate_persistent. This is only needed for Tensors that will
904   // be stored by the Op between invocations, and it *must* be used
905   // for those Tensors. The call returns a PersistentTensor, and that
906   // is the only object the Op is allowed to hold on to between
907   // invocations. When the Tensor is needed in a subsequent
908   // invocation, it can be retrieved from the PersistentTensor using
909   // the AccessTensor method. This ensures that the system is made
910   // aware of any use of the tensor's allocated memory, which is
911   // needed for correctness on asynchronous devices such as GPUs.
912   //
913   // 2) allocate_output. This should be used to allocate any tensor
914   // that is going to be used as an output from the Op at the end of
915   // the current execution. The caller indicates which output the
916   // Tensor will be assigned to, and the call returns the
917   // newly-allocated Tensor. The Tensor can subsequently be assigned
918   // to during kernel execution, and will be used as the designated
919   // output when the kernel execution completes.
920   //
921   // 3) allocate_temp. This should be used to allocate any scratch
922   // storage that is needed while the kernel is executing, and will
923   // not be retained by the Op.
924   //
925   // In some cases a Tensor needs to be used as an output even though
926   // it was previously allocated elsewhere. The Tensor may have been
927   // passed as an input, or stored in a PersistentTensor during a
928   // previous kernel execution, or allocated earlier in the kernel
929   // execution at a time when it was not known which output it would
930   // be assigned to. In this case the kernel can use set_output or
931   // set_output_ref to indicate that the tensor should be used as the
932   // designated output. It is legal to use any previously-allocated
933   // Tensor as an argument to set_output or set_output_ref, including
934   // Tensors allocated via allocate_temp. There may be a performance
935   // penalty to using a Tensor that was not allocated using
936   // allocate_output. This is because allocate_output uses the
937   // AllocatorAttributes stored in output_attr_array for the
938   // designated output. In some cases, using the wrong attributes may
939   // cause an extra copy of the Tensor's buffer.
940 
941   // Allocates output for the specified output index with shape.
942   // OpKernelContext retains ownership of the returned pointer. See
943   // comment above.
944   //
945   // If memory allocation fails, returns an error status.
946   //
947   // REQUIRES: !IsRefType(expected_output_dtype(index))
948   Status allocate_output(int index, const TensorShape& shape,
949                          Tensor** tensor) TF_MUST_USE_RESULT;
950   Status allocate_output(StringPiece name, const TensorShape& shape,
951                          Tensor** tensor) TF_MUST_USE_RESULT;
952   // The following methods use the supplied attributes instead of
953   // those in output_attr_array. The caller is responsible for
954   // ensuring that the attributes are "compatible" with the
955   // output_attr_array, e.g. the tensor is allocated on the correct
956   // device. See comment above.
957   Status allocate_output(int index, const TensorShape& shape, Tensor** tensor,
958                          AllocatorAttributes attr) TF_MUST_USE_RESULT;
959   Status allocate_output(StringPiece name, const TensorShape& shape,
960                          Tensor** tensor,
961                          AllocatorAttributes attr) TF_MUST_USE_RESULT;
962 
963   // Allocates a temporary Tensor of the specified type and
964   // shape. Devices such as GPUs that enqueue Ops for lazy execution
965   // may retain references to the temporary tensors after the Op's
966   // Compute method has run. See comment above.
967   Status allocate_temp(DataType type, const TensorShape& shape,
968                        Tensor* out_temp, AllocatorAttributes allocator_attr,
969                        const AllocationAttributes& allocation_attr);
970   Status allocate_temp(DataType type, const TensorShape& shape,
971                        Tensor* out_temp, AllocatorAttributes allocator_attr) {
972     return allocate_temp(type, shape, out_temp, allocator_attr,
973                          AllocationAttributes());
974   }
975   Status allocate_temp(DataType type, const TensorShape& shape,
976                        Tensor* out_temp) {
977     return allocate_temp(type, shape, out_temp, AllocatorAttributes());
978   }
979 
980   // Allocates a Tensor of the specified type and shape which the Op
981   // plans to maintain as persistent state. out_persistent holds the
982   // PersistentTensor which is the object the caller should store. For
983   // convenience, if out_tensor is non-null then it will be filled in
984   // with a Tensor* pointing to the newly-allocated tensor which the
985   // caller can use instead of calling
986   // out_persistent->AccessTensor. The caller does not own out_tensor
987   // and should not keep a copy of it. See comment above.
988   Status allocate_persistent(DataType type, const TensorShape& shape,
989                              PersistentTensor* out_persistent,
990                              Tensor** out_tensor, AllocatorAttributes attr);
991   Status allocate_persistent(DataType type, const TensorShape& shape,
992                              PersistentTensor* out_persistent,
993                              Tensor** out_tensor) {
994     return allocate_persistent(type, shape, out_persistent, out_tensor,
995                                AllocatorAttributes());
996   }
997 
998   // Copies a tensor (allocated by the caller) to the specified output
999   // index.  REQUIRES: !IsRefType(expected_output_dtype(index))
1000   // REQUIRES: 'tensor' must have the same MemoryType as
1001   // output_memory_types[index]. See comment above.
1002   Status set_output(StringPiece name, const Tensor& tensor);
1003 
1004   // To output a reference.  Caller retains ownership of mu and tensor_for_ref,
1005   // and they must outlive all uses within the step. See comment above.
1006   // REQUIRES: IsRefType(expected_output_dtype(index))
1007   Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref);
1008 
1009   // Returns nullptr if allocate_output() or set_output() have not been called.
1010   Status mutable_output(StringPiece name, Tensor** tensor);
1011 
1012   // Records device specific state about how the input tensors were
1013   // computed.
1014   //
1015   // If using the templated function, the type must be a subclass
1016   // of DeviceContext.
1017   //
1018   // Get the DeviceContext used for the index input.  Returns nullptr
1019   // if no DeviceContext was provided.
1020   template <typename T>
1021   T* input_device_context(int index);
1022   DeviceContext* input_device_context(int index);
1023 
1024   // Return the DeviceContext that should be used for this Op.
1025   //
1026   // If using the templated function, the type must be a subclass
1027   // of DeviceContext.
1028   //
1029   // Returns nullptr if the device did not provide one.
1030   template <typename T>
1031   T* op_device_context();
1032   DeviceContext* op_device_context() {
1033     DeviceContext* ret = params_->op_device_context;
1034     if (ret == nullptr) {
1035       auto* dev_info = device()->tensorflow_gpu_device_info();
1036       if (dev_info) ret = dev_info->default_context;
1037     }
1038     return ret;
1039   }
1040 
1041   AllocatorAttributes input_alloc_attr(int index) const {
1042     if (params_->input_alloc_attrs == nullptr) {
1043       return AllocatorAttributes();
1044     } else {
1045       DCHECK_GE(index, 0);
1046       DCHECK_LT(index, params_->input_alloc_attrs->size());
1047       return (*params_->input_alloc_attrs)[index];
1048     }
1049   }
1050 
1051   AllocatorAttributes output_alloc_attr(int index) const {
1052     return params_->output_attr_array[index];
1053   }
1054 
1055   gtl::InlinedVector<WrappedAllocator, 4> ConsumeWrappedAllocators() {
1056     mutex_lock lock(mu_);
1057     gtl::InlinedVector<WrappedAllocator, 4> retrieved;
1058     retrieved.swap(wrapped_allocators_);
1059     return retrieved;
1060   }
1061 
1062   // Communication.
1063   //
1064   // An op kernel communicates with outside environment through
1065   // Rendezvous Send() and Recv().
1066   Rendezvous* rendezvous() const { return params_->rendezvous; }
1067 
1068   CollectiveExecutor* collective_executor() const {
1069     return params_->collective_executor;
1070   }
1071 
1072   // An op kernel can access the session state it belongs to.
1073   SessionState* session_state() const { return params_->session_state; }
1074 
1075   // Unique identifier of the session it belongs to. Can be empty.
1076   string session_handle() const { return params_->session_handle; }
1077 
1078   // An op kernel can access the tensor store of the run it belongs to.
1079   TensorStore* tensor_store() const { return params_->tensor_store; }
1080 
1081   // Function call support.
1082   //
1083   // If this kernel invocation is within a function execution,
1084   // call_frame() returns the call frame for the function call.
1085   CallFrameInterface* call_frame() const { return params_->call_frame; }
1086 
1087   // If not nullptr, the kernel can invoke functions defined in the
1088   // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...).
1089   FunctionLibraryRuntime* function_library() const {
1090     return params_->function_library;
1091   }
1092 
1093   std::function<void(std::function<void()>)>* runner() const {
1094     return params_->runner;
1095   }
1096   StepStatsCollectorInterface* stats_collector() const {
1097     return params_->stats_collector;
1098   }
1099 
1100   // Shared resources accessible to this kernel.
1101   ResourceMgr* resource_manager() const { return params_->resource_manager; }
1102 
1103   checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const {
1104     return params_->slice_reader_cache;
1105   }
1106 
1107   // Execution.
1108   //
1109   // OpKernels can use these eigen devices to carry out their
1110   // numerical computation.
1111   const Eigen::ThreadPoolDevice& eigen_cpu_device() const {
1112     return *device()->eigen_cpu_device();
1113   }
1114   const Eigen::GpuDevice& eigen_gpu_device() const {
1115     return params_->eigen_gpu_device->device();
1116   }
1117 #ifdef TENSORFLOW_USE_SYCL
1118   const Eigen::SyclDevice& eigen_sycl_device() const {
1119     return *device()->eigen_sycl_device();
1120   }
1121 #endif
1122   template <typename EigenDeviceType>
1123   const EigenDeviceType& eigen_device() const;
1124 
1125   // Error handling.
1126 
1127   // If expected_inputs == inputs() and expected_outputs == output_types(),
1128   // returns OK, else returns INVALID_ARGUMENT with an error message.
1129   // Recommended for Ops with dynamic signatures, where validation can only
1130   // be performed at runtime.
1131   Status MatchSignature(const DataTypeSlice expected_inputs,
1132                         const DataTypeSlice expected_outputs);
1133 
1134   // An OpKernel should call SetStatus() if Compute() encounters an
1135   // error.
1136   void SetStatus(const Status& status);
1137   const Status& status() const { return status_; }
1138 
1139   // Cancellation.
1140   //
1141   // EXPERIMENTAL. See the implementation in tensorflow::FIFOQueue for an
1142   // example of how to use this API.
1143   CancellationManager* cancellation_manager() const {
1144     return params_->cancellation_manager;
1145   }
1146 
1147   // Other accessors.
1148 
1149   // For control flow.
1150   FrameAndIter frame_iter() const { return params_->frame_iter; }
1151   bool is_input_dead() const { return params_->is_input_dead; }
1152 
1153   // May be used, e.g., to get GPU handles, etc.
1154   // TODO(tucker): Add example usage.
1155   DeviceBase* device() const { return params_->device; }
1156 
1157   // Retrieve list of referenced tensors in out_vector. Once this is
1158   // called, it is not legal to reference any more tensors.  Should
1159   // not be called from Op kernels.
1160   void retrieve_accessed_tensors(TensorReferenceVector* out_vector);
1161 
1162   // Per-step container for use by white-listed internal ops.
1163   ScopedStepContainer* step_container() const {
1164     return params_->step_container;
1165   }
1166 
1167   // Helper routines for the OP_REQUIRES macros
1168   void CtxFailure(const Status& s);
1169   void CtxFailureWithWarning(const Status& s);
1170   void CtxFailure(const char* file, int line, const Status& s);
1171   void CtxFailureWithWarning(const char* file, int line, const Status& s);
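  // Example (an illustrative sketch): kernels normally reach these helpers
  // through the OP_REQUIRES family of macros rather than calling them
  // directly.
  //   OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
  //               errors::InvalidArgument("input must be a vector"));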
1172 
1173   // Unrecommended functions: these are functions that have some
1174   // current uses but are not recommended for use, and may go away at
1175   // some future major version release.
1176   //
1177   // The following functions all have versions that return Status
1178   // to capture error conditions, and are strongly preferred.
1179   Tensor* mutable_output(int index);
1180   void set_output(int index, const Tensor& tensor);
1181   mutex* input_ref_mutex(int index);
1182   void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref);
1183   TensorValue release_output(int index);
1184 
1185   bool track_allocations() const { return params_->track_allocations; }
1186 
1187   // Records temp memory allocation. Tensor object is recorded to identify the
1188   // case where temp memory is used as output memory.
1189   void record_temp_memory_allocation(int64 size, const Tensor& t)
1190       LOCKS_EXCLUDED(stats_mu_);
1191 
1192   // Returns the recorded size of temporary memory.
1193   int64 temp_memory_allocated() const LOCKS_EXCLUDED(stats_mu_);
1194 
1195   // Records persistent memory allocation; size can be negative, indicating
1196   // deallocation.
1197   void record_persistent_memory_allocation(int64 size, int64 alloc_id = -1)
1198       LOCKS_EXCLUDED(stats_mu_);
1199 
1200   // Returns recorded size and ids of persistent memory.
1201   int64 persistent_memory_allocated() const LOCKS_EXCLUDED(stats_mu_);
1202 
1203   std::vector<int64> persistent_alloc_ids() const LOCKS_EXCLUDED(stats_mu_);
1204 
1205   // Resets counters for temp and persistent memory and recorded ids.
1206   void clear_recorded_memory() LOCKS_EXCLUDED(stats_mu_);
1207 
1208   bool input_is_ref(int index) const;
1209 
1210   // Used by OpKernel implementations to track actively running deferred ops.
1211   //
1212   // A deferred op is one whose Compute method returns (or whose ComputeAsync
1213   // method invokes the callback) when work is scheduled onto a device. At that
1214   // point, we don't know when the work will actually complete (or if it has
1215   // already completed) on the device. These functions allow the executor to
1216   // track the status of deferred ops and act accordingly.
1217   //
1218   // Deferred OpKernel implementations must use these methods to obtain the
1219   // two functions, and must then call them in pairs: increment before the
1220   // device work is scheduled, decrement after it completes (see sketch below).
1221   TF_MUST_USE_RESULT std::function<void()> inc_num_deferred_ops_function() {
1222     return params_->inc_num_deferred_ops_function;
1223   }
1224   TF_MUST_USE_RESULT std::function<void()> dec_num_deferred_ops_function() {
1225     return params_->dec_num_deferred_ops_function;
1226   }
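
  //
  // A minimal sketch of the pairing described above, from inside a hypothetical
  // ComputeAsync(); EnqueueOnStream() stands in for however the kernel hands
  // work to the device:
  //
  //   auto inc = ctx->inc_num_deferred_ops_function();
  //   auto dec = ctx->dec_num_deferred_ops_function();
  //   inc();                          // before scheduling the device work
  //   EnqueueOnStream([dec, done]() {
  //     dec();                        // after the device work has completed
  //     done();
  //   });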
1227 
1228  private:
1229   Allocator* get_allocator(AllocatorAttributes attr);
1230 
1231   // Internal method to add a tensor's buffer to the list of buffers
1232   // referenced during the execution of the Op, so that GPUs may
1233   // accurately track the memory that may not be reused until the Op
1234   // execution completes.
1235   void record_tensor_reference(const Tensor& tensor);
1236   void really_record_tensor_reference(const Tensor& tensor);
1237 
1238   // Internal common method used when allocating tensor memory
1239   Status allocate_tensor(DataType type, const TensorShape& shape,
1240                          Tensor* out_tensor,
1241                          AllocatorAttributes allocator_attr) {
1242     return allocate_tensor(type, shape, out_tensor, allocator_attr,
1243                            AllocationAttributes());
1244   }
1245 
1246   Status allocate_tensor(DataType type, const TensorShape& shape,
1247                          Tensor* out_tensor, AllocatorAttributes allocator_attr,
1248                          const AllocationAttributes& allocation_attr);
1249 
1250   // This is called by PersistentTensor::AccessTensor whenever the
1251   // wrapped tensor is retrieved, to ensure the runtime knows that the
1252   // Tensor is being accessed within an Op. This is necessary for
1253   // memory safety of devices like GPUs that queue Ops for
1254   // asynchronous execution after the Compute() method completes.
1255   friend class PersistentTensor;
1256   void NotifyUseOfPersistentTensor(const Tensor& tensor);
1257 
1258   Status status_;
1259   friend class CollectiveExecutor;  // for access to params_
1260   Params* params_;                  // not owned
1261   mutable mutex mu_;  // mutable so const accessors can acquire the lock
1262   gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_);
1263   gtl::InlinedVector<TensorValue, 4> outputs_;
1264 
1265   // Constructed only if <params->record_tensor_accesses>.
1266   ManualConstructor<UniqueTensorReferences> referenced_tensors_ GUARDED_BY(mu_);
1267 
1268   // The following data members are only used when allocation tracking is
1269   // enabled.
1270   mutable mutex stats_mu_;
1271   int64 temp_memory_allocated_ GUARDED_BY(stats_mu_);
1272   int64 persistent_memory_allocated_ GUARDED_BY(stats_mu_);
1273   std::unique_ptr<gtl::InlinedVector<std::pair<const void*, int64>, 2>>
1274       temp_tensor_buffer_and_size_ GUARDED_BY(stats_mu_);
1275   std::unique_ptr<gtl::InlinedVector<int64, 2>> persistent_alloc_ids_
1276       GUARDED_BY(stats_mu_);
1277 
1278   TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext);
1279 };
1280 
1281 // Register your OpKernel by specifying the Op's name, the device the
1282 // kernel runs on, any type attr constraints for this kernel, any
1283 // host-memory args, and the class to instantiate.  Examples:
1284 //
1285 //  // A kernel that supports all types.
1286 //  REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp);
1287 //
1288 //  // The following are equivalent ways of specifying that the kernel only
1289 //  // works if the "T" type attr is set to DT_FLOAT.
1290 //  REGISTER_KERNEL_BUILDER(
1291 //      Name("Sub").Device(DEVICE_CPU).TypeConstraint<float>("T"),
1292 //      SubOp<float>);
1293 //  // (You would then repeat this for every type supported by "Sub".)
1294 //
1295 //  // This form allows you to specify a list of types as the constraint.
1296 //  REGISTER_KERNEL_BUILDER(Name("Sub")
1297 //                              .Device(DEVICE_CPU)
1298 //                              .TypeConstraint("T", {DT_FLOAT}),
1299 //                          SubOp<float>);
1300 //
1301 //  // A kernel that expects one of the input tensors in host memory.
1302 //  REGISTER_KERNEL_BUILDER(
1303 //      Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp);
1304 //
1305 // See kernel_def_builder for details.
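//
// For orientation, a minimal sketch of the kind of kernel class these macros
// register; the op name "MyAddOne" and the AddOneOp class are illustrative and
// not defined anywhere in TensorFlow:
//
//   class AddOneOp : public OpKernel {
//    public:
//     explicit AddOneOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
//     void Compute(OpKernelContext* ctx) override {
//       const Tensor& input = ctx->input(0);
//       Tensor* output = nullptr;
//       OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
//       auto in = input.flat<float>();
//       auto out = output->flat<float>();
//       for (int64 i = 0; i < in.size(); ++i) out(i) = in(i) + 1.0f;
//     }
//   };
//   REGISTER_KERNEL_BUILDER(Name("MyAddOne").Device(DEVICE_CPU), AddOneOp);
//
//   // (The "MyAddOne" op itself must also be registered, via REGISTER_OP.)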
1306 
1307 // Instantiates an OpKernel that has been registered.  Returns nullptr if no
1308 // kernel is registered for that device type / input signature combination
1309 // (and sets *status to NOT_FOUND), or if there is an error during construction
1310 // (and sets *status to INVALID_ARGUMENT).  Otherwise, the caller takes
1311 // ownership of the returned pointer.
1312 // EXPECTED USAGE: unique_ptr<OpKernel> op = CreateOpKernel(...);
1313 // REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
1314 std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
1315                                          DeviceBase* device,
1316                                          Allocator* allocator,
1317                                          const NodeDef& def,
1318                                          int graph_def_version, Status* status);
1319 Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
1320                       Allocator* allocator, FunctionLibraryRuntime* flib,
1321                       const NodeDef& def, int graph_def_version,
1322                       OpKernel** kernel);
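
// A minimal sketch of the expected usage noted above; `device` and `node_def`
// are assumed to come from the surrounding runtime or test harness, and
// TF_GRAPH_DEF_VERSION is the version macro from public/version.h:
//
//   Status status;
//   std::unique_ptr<OpKernel> op = CreateOpKernel(
//       DeviceType(DEVICE_CPU), device,
//       device->GetAllocator(AllocatorAttributes()), node_def,
//       TF_GRAPH_DEF_VERSION, &status);
//   if (!status.ok()) { /* handle the error */ }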
1323 
1324 // Returns into 'device_types' the subset of prioritized_types that this
1325 // binary has registered for the given NodeDef.
1326 //
1327 // REQUIRES: * 'device_types' is not nullptr.
1328 //           * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
1329 Status SupportedDeviceTypesForNode(
1330     const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
1331     PrioritizedDeviceTypeVector* device_types);
1332 
1333 // Returns a message with a description of the kernels registered for op
1334 // `op_name`.
1335 string KernelsRegisteredForOp(StringPiece op_name);
1336 
1337 // Call once after Op registration has completed.
1338 Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry);
1339 
1340 // -----------------------------------------------------------------------------
1341 // OpKernel registration implementation follows, please ignore.
1342 
1343 // Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax.
1344 namespace register_kernel {
1345 
1346 class Name : public KernelDefBuilder {
1347  public:
1348   // With selective registration, kernels whose implementation class is not used
1349   // by any op are disabled with the SHOULD_REGISTER_OP_KERNEL call in
1350   // REGISTER_KERNEL_BUILDER_UNIQ. However, an unused kernel that shares an
1351   // implementation class with a used kernel would get through that mechanism.
1352   //
1353   // This mechanism stops that registration by changing the name of the kernel
1354   // for the unused op to one that is ignored by
1355   // OpKernelRegistrar::InitInternal.  Note that this method alone is
1356   // not sufficient - the compiler can't evaluate the entire KernelDefBuilder at
1357   // compilation time, so this method doesn't actually reduce code size.
1358   explicit Name(const char* op)
1359       : KernelDefBuilder(SHOULD_REGISTER_OP(op) ? op : "_no_register") {}
1360 };
1361 
1362 namespace system {
1363 
1364 class Name : public KernelDefBuilder {
1365  public:
1366   // For system kernels, we ignore selective registration and
1367   // unconditionally register the kernel.
1368   explicit Name(const char* op) : KernelDefBuilder(op) {}
1369 };
1370 
1371 }  // namespace system
1372 
1373 }  // namespace register_kernel
1374 
1375 #define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
1376   REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)
1377 
1378 #define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
1379   REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)
1380 
1381 #define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)        \
1382   constexpr bool should_register_##ctr##__flag =                      \
1383       SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__);                        \
1384   static ::tensorflow::kernel_factory::OpKernelRegistrar              \
1385       registrar__body__##ctr##__object(                               \
1386           should_register_##ctr##__flag                               \
1387               ? ::tensorflow::register_kernel::kernel_builder.Build() \
1388               : nullptr,                                              \
1389           #__VA_ARGS__,                                               \
1390           [](::tensorflow::OpKernelConstruction* context)             \
1391               -> ::tensorflow::OpKernel* {                            \
1392             return new __VA_ARGS__(context);                          \
1393           });
1394 
1395 // The `REGISTER_SYSTEM_KERNEL_BUILDER()` macro acts as
1396 // `REGISTER_KERNEL_BUILDER()` except that the kernel is registered
1397 // unconditionally even when selective registration is used.
1398 #define REGISTER_SYSTEM_KERNEL_BUILDER(kernel_builder, ...)               \
1399   REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, \
1400                                              __VA_ARGS__)
1401 
1402 #define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
1403   REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)
1404 
1405 #define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)    \
1406   static ::tensorflow::kernel_factory::OpKernelRegistrar                 \
1407       registrar__body__##ctr##__object(                                  \
1408           ::tensorflow::register_kernel::system::kernel_builder.Build(), \
1409           #__VA_ARGS__,                                                  \
1410           [](::tensorflow::OpKernelConstruction* context)                \
1411               -> ::tensorflow::OpKernel* {                               \
1412             return new __VA_ARGS__(context);                             \
1413           });
1414 
1415 // Checks whether a given kernel is registered on device_type.
1416 bool KernelDefAvailable(const DeviceType& device_type, const NodeDef& node_def);
1417 
1418 // If node_def has a corresponding kernel registered on device_type,
1419 // returns OK and fills in the kernel def and kernel_class_name. <def> and
1420 // <kernel_class_name> may be null.
1421 Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
1422                      const KernelDef** def, string* kernel_class_name);
1423 
1424 // Writes a list of all registered kernels to LOG(INFO), to help users debug
1425 // missing kernel errors.
1426 void LogAllRegisteredKernels();
1427 
1428 // Gets a list of all registered kernels.
1429 KernelList GetAllRegisteredKernels();
1430 
1431 // Gets a list of all registered kernels for which the predicate returns true.
1432 KernelList GetFilteredRegisteredKernels(
1433     const std::function<bool(const KernelDef&)>& predicate);
1434 
1435 // Gets a list of all registered kernels for a given op.
1436 KernelList GetRegisteredKernelsForOp(StringPiece op_name);
1437 
1438 namespace kernel_factory {
1439 
1440 // OpKernelFactory is responsible for creating OpKernels when TensorFlow needs
1441 // them. You register factories with the TensorFlow core by constructing an
1442 // OpKernelRegistrar and passing the factory as a constructor parameter.
1443 class OpKernelFactory {
1444  public:
1445   virtual OpKernel* Create(OpKernelConstruction* context) = 0;
1446   virtual ~OpKernelFactory() = default;
1447 };
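
// A minimal sketch of a hand-written factory; MyKernel is a hypothetical
// OpKernel subclass. In practice the REGISTER_KERNEL_BUILDER macro generates
// an equivalent factory for you:
//
//   struct MyKernelFactory : OpKernelFactory {
//     OpKernel* Create(OpKernelConstruction* context) override {
//       return new MyKernel(context);
//     }
//   };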
1448 
1449 class OpKernelRegistrar {
1450  public:
1451   // Registers the given kernel factory with TensorFlow. TF will call the
1452   // factory Create() method when it determines that a kernel matching the given
1453   // KernelDef is required.
1454   OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
1455                     std::unique_ptr<OpKernelFactory> factory) {
1456     // Perform the check in the header to allow compile-time optimization
1457     // to a no-op, allowing the linker to remove the kernel symbols.
1458     if (kernel_def != nullptr) {
1459       InitInternal(kernel_def, kernel_class_name, std::move(factory));
1460     }
1461   }
1462 
1463   // Registers the given factory function with TensorFlow. This is equivalent
1464   // to registering a factory whose Create function invokes `create_fn`.
1465   OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
1466                     OpKernel* (*create_fn)(OpKernelConstruction*)) {
1467     // Perform the check in the header to allow compile-time optimization
1468     // to a no-op, allowing the linker to remove the kernel symbols.
1469     if (kernel_def != nullptr) {
1470       InitInternal(kernel_def, kernel_class_name,
1471                    absl::make_unique<PtrOpKernelFactory>(create_fn));
1472     }
1473   }
1474 
1475  private:
1476   struct PtrOpKernelFactory : public OpKernelFactory {
1477     explicit PtrOpKernelFactory(OpKernel* (*create_func)(OpKernelConstruction*))
1478         : create_func_(create_func) {}
1479 
1480     OpKernel* Create(OpKernelConstruction* context) override;
1481 
1482     OpKernel* (*create_func_)(OpKernelConstruction*);
1483   };
1484 
1485   void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
1486                     std::unique_ptr<OpKernelFactory> factory);
1487 };
1488 
1489 }  // namespace kernel_factory
1490 
1491 // -----------------------------------------------------------------------------
1492 // Template and inline method implementations, please ignore
1493 
1494 template <class T>
1495 Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const {
1496   return GetNodeAttr(def(), attr_name, value);
1497 }
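
// A typical use of GetAttr from a kernel constructor; the attr name "axis" and
// the axis_ member are illustrative:
//
//   explicit MyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
//     OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
//   }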
1498 
1499 inline DataType OpKernelContext::input_dtype(int index) const {
1500   DCHECK_GE(index, 0);
1501   DCHECK_LT(index, num_inputs());
1502   const TensorValue& value((*params_->inputs)[index]);
1503   if (value.is_ref()) {
1504     return MakeRefType(value->dtype());
1505   } else {
1506     return value->dtype();
1507   }
1508 }
1509 
1510 inline MemoryType OpKernelContext::input_memory_type(int index) const {
1511   DCHECK_GE(index, 0);
1512   DCHECK_LT(index, num_inputs());
1513   return op_kernel().input_memory_types()[index];
1514 }
1515 
1516 inline DataType OpKernelContext::expected_output_dtype(int index) const {
1517   DCHECK_GE(index, 0);
1518   DCHECK_LT(index, num_outputs());
1519   return params_->op_kernel->output_type(index);
1520 }
1521 
1522 inline MemoryType OpKernelContext::output_memory_type(int index) const {
1523   DCHECK_GE(index, 0);
1524   DCHECK_LT(index, num_outputs());
1525   return op_kernel().output_memory_types()[index];
1526 }
1527 
1528 inline bool OpKernelContext::input_is_ref(int index) const {
1529   const TensorValue& value((*params_->inputs)[index]);
1530   return value.is_ref();
1531 }
1532 
1533 inline void OpKernelContext::record_tensor_reference(const Tensor& tensor) {
1534   DCHECK_EQ(params_->device->RequiresRecordingAccessedTensors(),
1535             params_->record_tensor_accesses);
1536   if (params_->record_tensor_accesses) {
1537     really_record_tensor_reference(tensor);
1538   }
1539 }
1540 
1541 inline void OpKernelContext::retrieve_accessed_tensors(
1542     TensorReferenceVector* out_vector) {
1543   if (params_->record_tensor_accesses) {
1544     mutex_lock l(mu_);
1545     referenced_tensors_->FreezeAndReturnReferences(out_vector);
1546   }
1547 }
1548 
1549 // no input if tensor == nullptr.
1550 inline bool OpKernelContext::has_input(int index) const {
1551   DCHECK_GE(index, 0);
1552   DCHECK_LT(index, num_inputs());
1553   return (*params_->inputs)[index].tensor != nullptr;
1554 }
1555 
1556 inline mutex* OpKernelContext::input_ref_mutex(int index) {
1557   DCHECK_GE(index, 0);
1558   DCHECK_LT(index, num_inputs());
1559   DCHECK(input_is_ref(index));
1560   return (*params_->inputs)[index].mutex_if_ref;
1561 }
1562 
1563 inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) {
1564   if (t.IsInitialized()) {
1565     record_tensor_reference(t);
1566   }
1567 }
1568 
1569 inline Tensor* OpKernelContext::mutable_output(int index) {
1570   DCHECK_GE(index, 0);
1571   DCHECK_LT(index, num_outputs());
1572   // No need to record_tensor_reference since the output must already
1573   // have been set by a call that did so.
1574   return outputs_[index].tensor;
1575 }
1576 
1577 inline TensorValue OpKernelContext::release_output(int index) {
1578   DCHECK_GE(index, 0);
1579   DCHECK_LT(index, num_outputs());
1580   TensorValue value = outputs_[index];
1581   outputs_[index] = TensorValue();
1582   return value;
1583 }
1584 
1585 inline Status OpKernelContext::forward_input_or_allocate_output(
1586     gtl::ArraySlice<int> candidate_input_indices, int output_index,
1587     const TensorShape& output_shape, Tensor** output) {
1588   for (int input_index : candidate_input_indices) {
1589     if (forward_input_to_output_with_shape(input_index, output_index,
1590                                            output_shape, output)) {
1591       return Status::OK();
1592     }
1593   }
1594   return allocate_output(output_index, output_shape, output);
1595 }
1596 
1597 inline Status OpKernelContext::forward_input_or_allocate_output(
1598     gtl::ArraySlice<StringPiece> candidate_input_names, StringPiece output_name,
1599     const TensorShape& output_shape, Tensor** output) {
1600   for (const StringPiece& input_name : candidate_input_names) {
1601     if (forward_input_to_output_with_shape(input_name, output_name,
1602                                            output_shape, output)
1603             .ok()) {
1604       return Status::OK();
1605     }
1606   }
1607   return allocate_output(output_name, output_shape, output);
1608 }
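
// A minimal sketch of the typical call site inside Compute(): try to reuse the
// buffer of input 0 for output 0, falling back to a fresh allocation when the
// input buffer cannot be forwarded:
//
//   Tensor* output = nullptr;
//   OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
//                           {0}, 0, ctx->input(0).shape(), &output));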
1609 
1610 template <typename T>
1611 T* OpKernelContext::op_device_context() {
1612   static_assert(std::is_base_of<DeviceContext, T>::value,
1613                 "T is not a subclass of DeviceContext");
1614   return static_cast<T*>(op_device_context());
1615 }
1616 
1617 template <typename T>
1618 T* OpKernelContext::input_device_context(int index) {
1619   DCHECK_NE(params_->input_device_contexts, nullptr);
1620   DCHECK_GE(index, 0);
1621   DCHECK_LT(index, params_->input_device_contexts->size());
1622   static_assert(std::is_base_of<DeviceContext, T>::value,
1623                 "T is not a subclass of DeviceContext");
1624   return static_cast<T*>((*params_->input_device_contexts)[index]);
1625 }
1626 
1627 inline DeviceContext* OpKernelContext::input_device_context(int index) {
1628   DCHECK_NE(params_->input_device_contexts, nullptr);
1629   DCHECK_GE(index, 0);
1630   DCHECK_LT(index, params_->input_device_contexts->size());
1631   return (*params_->input_device_contexts)[index];
1632 }
1633 
1634 inline const Tensor& OpInputList::operator[](int i) const {
1635   DCHECK_GE(i, 0);
1636   DCHECK_LT(i, stop_ - start_);
1637   return ctx_->input(start_ + i);
1638 }
1639 
1640 inline mutex* OpMutableInputList::ref_mutex(int i) {
1641   DCHECK_GE(i, 0);
1642   DCHECK_LT(i, stop_ - start_);
1643   return ctx_->input_ref_mutex(start_ + i);
1644 }
1645 
1646 inline Tensor OpMutableInputList::at(int i, bool lock_held) {
1647   DCHECK_GE(i, 0);
1648   DCHECK_LT(i, stop_ - start_);
1649   return ctx_->mutable_input(start_ + i, lock_held);
1650 }
1651 
1652 inline Tensor* OpOutputList::operator[](int i) {
1653   DCHECK_GE(i, 0);
1654   DCHECK_LT(i, stop_ - start_);
1655   return ctx_->mutable_output(start_ + i);
1656 }
1657 
1658 inline bool OpOutputList::required(int i) const {
1659   DCHECK_GE(i, 0);
1660   DCHECK_LT(i, stop_ - start_);
1661   return ctx_->output_required(start_ + i);
1662 }
1663 
1664 inline DataType OpOutputList::expected_output_dtype(int i) const {
1665   DCHECK_GE(i, 0);
1666   DCHECK_LT(i, stop_ - start_);
1667   return ctx_->expected_output_dtype(start_ + i);
1668 }
1669 
1670 inline Status OpOutputList::allocate(int i, const TensorShape& shape,
1671                                      Tensor** output) {
1672   DCHECK_GE(i, 0);
1673   DCHECK_LT(i, stop_ - start_);
1674   return ctx_->allocate_output(start_ + i, shape, output);
1675 }
1676 
1677 inline void OpOutputList::set(int i, const Tensor& tensor) {
1678   DCHECK_GE(i, 0);
1679   DCHECK_LT(i, stop_ - start_);
1680   ctx_->set_output(start_ + i, tensor);
1681 }
1682 
1683 inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
1684   DCHECK_GE(i, 0);
1685   DCHECK_LT(i, stop_ - start_);
1686   ctx_->set_output_ref(start_ + i, mu, tensor_for_ref);
1687 }
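
// A minimal sketch of how these list types are typically obtained and used
// inside Compute(); the list names "values" and "output" are illustrative:
//
//   OpInputList values;
//   OP_REQUIRES_OK(ctx, ctx->input_list("values", &values));
//   OpOutputList outputs;
//   OP_REQUIRES_OK(ctx, ctx->output_list("output", &outputs));
//   for (int i = 0; i < values.size(); ++i) {
//     Tensor* out = nullptr;
//     OP_REQUIRES_OK(ctx, outputs.allocate(i, values[i].shape(), &out));
//   }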
1688 
1689 // Convenience macros for asserting and handling exceptional conditions.
1690 // Analogous to the CHECK* macros provided by logging.h.
1691 //
1692 // Example use:
1693 // void Compute(OpKernelContext* context) {
1694 //   OP_REQUIRES(context, context->num_inputs() == 2,
1695 //               errors::InvalidArgument("FooOp requires 2 arguments"));
1696 //   ...
1697 //   Status status = SomeUncertainMethod();
1698 //   OP_REQUIRES_OK(context, status);
1699 //   ...
1700 // }
1701 
1702 // Generate a fatal error if OP_REQUIRES or OP_REQUIRES_OK are used in
1703 // AsyncOpKernel implementations. If these macros are used and the condition
1704 // does not hold, the `done` callback will never be called and the system will
1705 // deadlock, so a crash failure is preferable. Since the OP_REQUIRES[_OK] macros
1706 // are legal to use in AsyncOpKernel constructors, we use overload resolution
1707 // to distinguish between OpKernelConstruction* and OpKernelContext* context
1708 // types.
1709 class XlaOpKernelContext;
1710 inline void CheckNotInComputeAsync(XlaOpKernelContext*, const char*) {}
1711 inline void CheckNotInComputeAsync(OpKernelConstruction*, const char*) {}
1712 void CheckNotInComputeAsync(OpKernelContext* ctx,
1713                             const char* correct_macro_name);
1714 
1715 #define OP_REQUIRES(CTX, EXP, STATUS)                     \
1716   do {                                                    \
1717     if (!TF_PREDICT_TRUE(EXP)) {                          \
1718       CheckNotInComputeAsync((CTX), "OP_REQUIRES_ASYNC"); \
1719       (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS));    \
1720       return;                                             \
1721     }                                                     \
1722   } while (0)
1723 
1724 #define OP_REQUIRES_OK(CTX, ...)                             \
1725   do {                                                       \
1726     ::tensorflow::Status _s(__VA_ARGS__);                    \
1727     if (!TF_PREDICT_TRUE(_s.ok())) {                         \
1728       CheckNotInComputeAsync((CTX), "OP_REQUIRES_OK_ASYNC"); \
1729       (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s);  \
1730       return;                                                \
1731     }                                                        \
1732   } while (0)
1733 
1734 #define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK)  \
1735   do {                                                 \
1736     if (!TF_PREDICT_TRUE(EXP)) {                       \
1737       (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \
1738       (CALLBACK)();                                    \
1739       return;                                          \
1740     }                                                  \
1741   } while (0)
1742 
1743 #define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK)         \
1744   do {                                                      \
1745     ::tensorflow::Status _s(STATUS);                        \
1746     if (!TF_PREDICT_TRUE(_s.ok())) {                        \
1747       (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
1748       (CALLBACK)();                                         \
1749       return;                                               \
1750     }                                                       \
1751   } while (0)
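
// Example use of the *_ASYNC variants inside an AsyncOpKernel (sketch only;
// SomeUncertainMethod is illustrative). Passing `done` as the callback keeps
// the contract that `done` still runs when the kernel bails out early:
//
// void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
//   OP_REQUIRES_ASYNC(ctx, ctx->num_inputs() == 2,
//                     errors::InvalidArgument("FooOp requires 2 arguments"),
//                     done);
//   OP_REQUIRES_OK_ASYNC(ctx, SomeUncertainMethod(), done);
//   ...
//   done();
// }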
1752 
1753 }  // namespace tensorflow
1754 
1755 #endif  // TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
1756