• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
18 
19 #include <functional>
20 #include <unordered_set>
21 #include <utility>
22 #include <vector>
23 
24 #include "tensorflow/core/framework/allocator.h"
25 #include "tensorflow/core/framework/cancellation.h"
26 #include "tensorflow/core/framework/control_flow.h"
27 #include "tensorflow/core/framework/device_base.h"
28 #include "tensorflow/core/framework/graph.pb.h"
29 #include "tensorflow/core/framework/kernel_def.pb.h"
30 #include "tensorflow/core/framework/kernel_def_builder.h"
31 #include "tensorflow/core/framework/node_def.pb.h"
32 #include "tensorflow/core/framework/node_def_util.h"
33 #include "tensorflow/core/framework/node_properties.h"
34 #include "tensorflow/core/framework/op.h"  // TODO(b/62899350): Remove
35 #include "tensorflow/core/framework/op_requires.h"
36 #include "tensorflow/core/framework/rendezvous.h"
37 #include "tensorflow/core/framework/selective_registration.h"
38 #include "tensorflow/core/framework/session_state.h"
39 #include "tensorflow/core/framework/tensor.h"
40 #include "tensorflow/core/framework/tensor_shape.h"
41 #include "tensorflow/core/framework/tensor_shape.pb.h"  // TODO(b/62899350): Remove
42 #include "tensorflow/core/framework/tracking_allocator.h"
43 #include "tensorflow/core/framework/types.h"
44 #include "tensorflow/core/framework/types.pb.h"
45 #include "tensorflow/core/lib/core/errors.h"
46 #include "tensorflow/core/lib/core/status.h"
47 #include "tensorflow/core/lib/gtl/array_slice.h"
48 #include "tensorflow/core/lib/gtl/manual_constructor.h"
49 #include "tensorflow/core/platform/env.h"
50 #include "tensorflow/core/platform/logging.h"
51 #include "tensorflow/core/platform/macros.h"
52 #include "tensorflow/core/platform/mutex.h"
53 #include "tensorflow/core/platform/profile_utils/cpu_utils.h"
54 #include "tensorflow/core/platform/thread_annotations.h"
55 #include "tensorflow/core/platform/types.h"
56 #include "tensorflow/core/protobuf/config.pb.h"
57 #include "tensorflow/core/util/managed_stack_trace.h"
58 
59 namespace Eigen {
60 struct ThreadPoolDevice;
61 struct GpuDevice;
62 }  // end namespace Eigen
63 
64 namespace tensorflow {
65 
66 namespace checkpoint {
67 class TensorSliceReaderCacheWrapper;
68 }  // namespace checkpoint
69 
70 class AsyncOpKernel;
71 class CallFrameInterface;
72 class DeviceMgr;
73 class FunctionLibraryRuntime;
74 class OpKernelConstruction;  // declared below
75 class OpKernelContext;       // declared below,
76 class OpRegistryInterface;
77 class ResourceMgr;
78 class ScopedStepContainer;
79 class CollectiveExecutor;
80 class StepStatsCollectorInterface;
81 
82 class OpKernel {
83  public:
84   // OpKernel won't be instantiated by the scheduler, so you may perform
85   // expensive initialization in the descendant's constructor.
86   explicit OpKernel(OpKernelConstruction* context);
87 
88   // Specialized constructor that allows a kernel implementation to mark itself
89   // as a "deferred" op. If true, the executor will provide access to the
90   // `OpKernelContext::inc_num_deferred_ops_function()` and
91   // `OpKernelContext::dec_num_deferred_ops_function()` methods at run-time.
92   OpKernel(OpKernelConstruction* context, bool is_deferred);
93 
94   // Specialized constructor that enables the descendant to provide a custom
95   // `NodeDef` value. For example, this constructor can be used to provide a
96   // stripped-down `NodeDef` that does not contain the full set of attrs (such
97   // as tensor values) if the descendant stores them in a different form.
98   OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
99            bool is_deferred);
100 
101   virtual ~OpKernel();
102 
103   // An OpKernel's computation can be either synchronous or
104   // asynchronous. All OpKernel Compute() methods must be thread-safe as they
105   // may be called concurrently (e.g. by multiple executions of the same graph
106   // concurrently).
107   //
108   // Most OpKernels should compute synchronously. They should
109   // subclass OpKernel and override the Compute() method and have it
110   // return after completing the supplied work.
111   //
112   // A synchronous OpKernel *MUST NOT* block the calling thread on a
113   // synchronization mechanism (condition variable, Notification, etc.) that
114   // will be unblocked by the execution of another OpKernel. Execution may
115   // deadlock in that case, because the executor may use a bounded number of
116   // threads.
117   //
118   // If an OpKernel must block on the execution of another OpKernel (e.g. a
119   // RecvOp, or a DequeueOp), the implementation *MUST* subclass AsyncOpKernel,
120   // and override `AsyncOpKernel::ComputeAsync()`. In addition, because the
121   // unblocking kernel may never run (due to an error or cancellation), in most
122   // cases the AsyncOpKernel should implement cancellation support via
123   // `ctx->cancellation_manager()`.
124   //
125   // In both cases, implementations of Compute() and ComputeAsync()
126   // get inputs and write outputs through the given OpKernelContext
127   // and returns a status via context->SetStatus(). They must be
128   // thread-safe.
129 
130   // Synchronous compute.
131   //
132   // "context" is guaranteed to be alive until Compute() returns.
133   virtual void Compute(OpKernelContext* context) = 0;
134 
135   // Returns nullptr iff this op kernel is synchronous.
AsAsync()136   virtual AsyncOpKernel* AsAsync() { return nullptr; }
137 
138   // Returns true iff this op kernel is considered "expensive". The
139   // runtime may use this flag to optimize graph execution for example
140   // to "inline" inexpensive kernels.
IsExpensive()141   virtual bool IsExpensive() { return expensive_; }
142 
143   // Returns a pointer to the tensor stored inside constant ops.
const_tensor()144   virtual const Tensor* const_tensor() const { return nullptr; }
145 
146   // Accessors.
def()147   const NodeDef& def() const { return props_->node_def; }
name()148   const std::string& name() const { return props_->node_def.name(); }
name_view()149   absl::string_view name_view() const { return name_view_; }
type_string()150   const std::string& type_string() const { return props_->node_def.op(); }
type_string_view()151   absl::string_view type_string_view() const { return type_string_view_; }
requested_input(int i)152   const std::string& requested_input(int i) const {
153     return props_->node_def.input(i);
154   }
requested_device()155   const std::string& requested_device() const {
156     return props_->node_def.device();
157   }
158 
num_inputs()159   int num_inputs() const { return props_->input_types.size(); }
input_type(int i)160   DataType input_type(int i) const { return props_->input_types[i]; }
input_types()161   const DataTypeVector& input_types() const { return props_->input_types; }
input_memory_types()162   const MemoryTypeVector& input_memory_types() const {
163     return input_memory_types_;
164   }
165 
num_outputs()166   int num_outputs() const { return props_->output_types.size(); }
output_type(int o)167   DataType output_type(int o) const { return props_->output_types[o]; }
output_types()168   const DataTypeVector& output_types() const { return props_->output_types; }
output_memory_types()169   const MemoryTypeVector& output_memory_types() const {
170     return output_memory_types_;
171   }
172 
173   Status InputRange(StringPiece input_name, int* start, int* stop) const;
174   Status OutputRange(StringPiece output_name, int* start, int* stop) const;
175 
176   // Returns `true` if and only if this kernel uses deferred execution.
is_deferred()177   bool is_deferred() const { return is_deferred_; }
178 
179   // Returns a trace string for current computation, op name/type and input
180   // tensor shape/dtype are encoded for profiler cost analysis. Most OpKernel
181   // should use the default implementation.
182   virtual std::string TraceString(const OpKernelContext& ctx,
183                                   bool verbose) const;
184 
185  protected:
186   std::string ShapeTraceString(const OpKernelContext& ctx) const;
187 
188  private:
189   const std::shared_ptr<const NodeProperties> props_;
190   const MemoryTypeVector input_memory_types_;
191   const MemoryTypeVector output_memory_types_;
192   NameRangeMap input_name_map_;
193   NameRangeMap output_name_map_;
194   const absl::string_view name_view_;
195   const absl::string_view type_string_view_;
196   const int graph_def_version_;
197   const bool is_deferred_;
198   bool expensive_;
199 
200   TF_DISALLOW_COPY_AND_ASSIGN(OpKernel);
201 };
202 
203 class AsyncOpKernel : public OpKernel {
204  public:
205   using OpKernel::OpKernel;  // Lift OpKernel constructors.
206 
207   // Asynchronous compute.
208   //
209   // Implementations of ComputeAsync() must ensure that `done` is (eventually)
210   // called exactly once to signal the completion of the computation. The
211   // implementation of ComputeAsync() must not block on the execution of another
212   // OpKernel. `done` may be called by the current thread, or by another thread.
213   // `context` is guaranteed to stay alive until the `done` callback starts.
214   //
215   // Since it is possible that the unblocking kernel may never run (due to an
216   // error or cancellation), in most cases the AsyncOpKernel should implement
217   // cancellation support via `context->cancellation_manager()`.
218   //
219   // WARNING: As soon as the `done` callback starts, `context` and `this` may be
220   // deleted. No code depending on these objects should execute after the call
221   // to `done`.
222   typedef std::function<void()> DoneCallback;
223   virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0;
224 
AsAsync()225   AsyncOpKernel* AsAsync() override { return this; }
226 
227   void Compute(OpKernelContext* context) override;
228 };
229 
230 // Wraps a tensor that is held by an Op across calls to Compute(). For memory
231 // safety when using asynchronous devices like GPUs, the system must be notified
232 // when a Tensor is used inside an Op execution. The wrapper ensures that all
233 // uses of the Tensor are tracked, because in order to retrieve the Tensor the
234 // caller must use AccessTensor which notifies the context.
235 class PersistentTensor {
236  public:
PersistentTensor()237   PersistentTensor() {}
PersistentTensor(const Tensor & tensor)238   explicit PersistentTensor(const Tensor& tensor) : tensor_(tensor) {}
239 
240   // Caller does not own the returned Tensor*.
241   Tensor* AccessTensor(OpKernelConstruction* context);
242   // Caller does not own the returned Tensor*.
243   Tensor* AccessTensor(OpKernelContext* context);
244 
245   // The check for initialization does not need to access the
246   // underlying tensor buffer.
IsInitialized()247   bool IsInitialized() const { return tensor_.IsInitialized(); }
248 
NumElements()249   int64 NumElements() const { return tensor_.NumElements(); }
250 
AllocatedBytes()251   int64 AllocatedBytes() const { return tensor_.AllocatedBytes(); }
252 
253  private:
254   Tensor tensor_;
255 };
256 
257 class OpKernelConstruction {
258  public:
259   OpKernelConstruction(DeviceType device_type, DeviceBase* device,
260                        Allocator* allocator, FunctionLibraryRuntime* flib,
261                        ResourceMgr* resource_mgr,
262                        const std::shared_ptr<const NodeProperties>& props,
263                        const MemoryTypeSlice& input_memory_types,
264                        const MemoryTypeSlice& output_memory_types,
265                        int graph_def_version, Status* status);
266 
env()267   Env* env() const { return device_->env(); }
268 
269   // Allocation of tensors during kernel construction:
270   //
271   // It is legal to temporarily allocate scratch tensor storage during
272   // Op kernel construction. Scratch tensors should be allocated using
273   // allocate_temp below. Some kernels need to keep tensors in between
274   // invocations. If such a Tensor is allocated during kernel
275   // construction this must be done using allocate_persistent, and the
276   // Op may only store the returned PersistentTensor object. When the
277   // Tensor is needed in a subsequent invocation, it can be retrieved
278   // from the PersistentTensor using the AccessTensor method. This
279   // ensures that the system is made aware of any use of the tensor's
280   // allocated memory, which is needed for correctness on asynchronous
281   // devices such as GPUs.
282 
283   // Allocates a temporary Tensor of the specified type and shape. The
284   // Tensor must not be used after kernel construction is
285   // complete. See comment above.
286   Status allocate_temp(DataType type, const TensorShape& shape,
287                        Tensor* out_temp);
288   Status allocate_temp(DataType type, const TensorShape& shape,
289                        Tensor* out_temp, AllocatorAttributes allocator_attr);
290 
291   // Allocates a Tensor of the specified type and shape which the Op
292   // plans to maintain as persistent state. out_persistent holds the
293   // PersistentTensor which is the object the caller should store. For
294   // convenience, if out_tensor is non-null then it will be filled in
295   // with a Tensor* pointing to the newly-allocated tensor which the
296   // caller can use instead of calling
297   // out_persistent->AccessTensor. The caller does not own out_tensor
298   // and should not keep a copy of it. See comment above.
299   Status allocate_persistent(DataType type, const TensorShape& shape,
300                              PersistentTensor* out_persistent,
301                              Tensor** out_tensor);
302 
303   // User-supplied configuration of this operation.
def()304   const NodeDef& def() const { return props_->node_def; }
305 
306   // For inspecting the inputs to this operation.
num_inputs()307   int num_inputs() const { return props_->input_types.size(); }
input_type(int i)308   DataType input_type(int i) const { return props_->input_types[i]; }
input_types()309   const DataTypeSlice& input_types() const { return props_->input_types_slice; }
input_memory_types()310   const MemoryTypeSlice& input_memory_types() const {
311     return input_memory_types_;
312   }
313 
314   // For inspecting the outputs expected from this operation.
num_outputs()315   int num_outputs() const { return props_->output_types.size(); }
output_type(int i)316   DataType output_type(int i) const { return props_->output_types[i]; }
output_types()317   const DataTypeSlice& output_types() const {
318     return props_->output_types_slice;
319   }
output_memory_types()320   const MemoryTypeSlice& output_memory_types() const {
321     return output_memory_types_;
322   }
323 
324   // If expected_inputs == inputs() and expected_outputs == output_types(),
325   // returns OK, else returns INVALID_ARGUMENT with an error message.
326   // Recommended for Ops with dynamic signatures.
327   Status MatchSignature(const DataTypeSlice expected_inputs,
328                         const DataTypeSlice expected_outputs);
329 
330   // For recording configuration errors during construction.
331   void SetStatus(const Status& status);
status()332   const Status& status() const { return *status_; }
333 
334   // Look up the attr with name attr_name and set *value to its value.  If no
335   // attr with attr_name is found in def(), or the attr does not have
336   // a matching type, a non-ok status will be returned.
337   template <class T>
338   Status GetAttr(StringPiece attr_name, T* value) const;
339 
340   // Return true if the attr_name is defined in def().
341   bool HasAttr(StringPiece attr_name) const;
342 
343   // Return the device type.
device_type()344   const DeviceType& device_type() const { return device_type_; }
345 
346   // If not nullptr, the kernel can instantiate functions defined in
347   // the library. E.g.,
348   // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...).
function_library()349   FunctionLibraryRuntime* function_library() const { return flib_; }
350 
351   // Shared resources accessible to this kernel.
resource_manager()352   ResourceMgr* resource_manager() const { return resource_mgr_; }
353 
354   // The GraphDef version whose behavior we should follow.
graph_def_version()355   int graph_def_version() const { return graph_def_version_; }
356 
357   // Helper routines for the OP_REQUIRES macros
358   void CtxFailure(const Status& s);
359   void CtxFailureWithWarning(const Status& s);
360   void CtxFailure(const char* file, int line, const Status& s);
361   void CtxFailureWithWarning(const char* file, int line, const Status& s);
362 
363   // Unrecommended functions: these are functions that have some
364   // current uses but are not recommended for use, and may go away at
365   // some future major version release.
366 
367   // May be used, e.g., to get GPU handles, etc.
368   //
369   // Currently only used to call MakeTensorFromProto() for
370   // implementing ConstantOp for every device.  See comments
371   // on Device::MakeTensorFromProto for longer-term replacement
372   // ideas.
device()373   DeviceBase* device() const { return device_; }
374 
375  private:
376   const DeviceType device_type_;
377   DeviceBase* const device_;
378   Allocator* allocator_;
379   FunctionLibraryRuntime* flib_;
380   ResourceMgr* const resource_mgr_;
381   std::shared_ptr<const NodeProperties> props_;
382   MemoryTypeSlice input_memory_types_;
383   MemoryTypeSlice output_memory_types_;
384   const int graph_def_version_;
385   Status* status_;
386 
387   // Allow access from OpKernel ctor.
388   friend class OpKernel;
389 
390   TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction);
391 };
392 
393 // TODO(mrry): Consider converting to a random_access_iterator, and upgrading
394 // tensorflow::gtl::iterator_range to make the below container classes
395 // unnecessary.
// Forward iterator over the elements of an op-argument list (OpInputList and
// friends). It indexes `*list_` by position and requires ListType to expose an
// `operator[](int)` usable on a const list.
template <typename ListType, typename ElementType>
class OpArgIterator {
 public:
  using iterator_category = std::forward_iterator_tag;
  using value_type = ElementType;
  using pointer = ElementType*;
  using const_pointer = const ElementType*;
  using reference = ElementType&;
  using const_reference = const ElementType&;
  using difference_type = ptrdiff_t;

  OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {}

  // Comparisons are only meaningful between iterators over the same list.
  // They are const so that they also work on const iterators.
  bool operator==(const OpArgIterator& rhs) const {
    DCHECK(list_ == rhs.list_);
    return i_ == rhs.i_;
  }

  bool operator!=(const OpArgIterator& rhs) const {
    DCHECK(list_ == rhs.list_);
    return i_ != rhs.i_;
  }

  // Prefix ++it. It returns *this by reference, as the iterator requirements
  // specify. The previous by-value return meant that `++(++it)` advanced a
  // temporary copy instead of `it`.
  OpArgIterator& operator++() {
    ++i_;
    return *this;
  }

  // Postfix it++: advances and returns the iterator's previous position.
  OpArgIterator operator++(int) {
    OpArgIterator old_value = *this;
    ++(*this);
    return old_value;
  }

  reference operator*() { return (*list_)[i_]; }
  pointer operator->() { return &(*list_)[i_]; }

  const_reference operator*() const { return (*list_)[i_]; }
  const_pointer operator->() const { return &(*list_)[i_]; }

 private:
  // The pointer itself is non-const so that iterators are copy-assignable, as
  // forward iterators must be. The pointee remains const.
  const ListType* list_;
  int i_;
};
440 
441 // Utility class for representing a list of immutable input tensors
442 // that are passed to the op as a single named argument.
443 class OpInputList {
444  public:
445   typedef OpArgIterator<OpInputList, const Tensor> Iterator;
OpInputList()446   OpInputList() : ctx_(nullptr), start_(0), stop_(0) {}
OpInputList(OpKernelContext * ctx,int start,int stop)447   OpInputList(OpKernelContext* ctx, int start, int stop)
448       : ctx_(ctx), start_(start), stop_(stop) {}
449   OpInputList& operator=(const OpInputList& other) = default;
450   const Tensor& operator[](int i) const;
size()451   int size() const { return stop_ - start_; }
begin()452   Iterator begin() const { return Iterator(this, 0); }
end()453   Iterator end() const { return Iterator(this, size()); }
454 
455  private:
456   OpKernelContext* ctx_;  // not owned
457   int start_;
458   int stop_;
459 };
460 
461 // Utility class for representing a list of mutable ("ref") input tensors
462 // that are passed to the op as a single named argument.
463 class OpMutableInputList {
464  public:
465   typedef OpArgIterator<OpMutableInputList, Tensor*> Iterator;
OpMutableInputList(OpKernelContext * ctx,int start,int stop)466   OpMutableInputList(OpKernelContext* ctx, int start, int stop)
467       : ctx_(ctx), start_(start), stop_(stop) {}
OpMutableInputList()468   OpMutableInputList() : ctx_(nullptr), start_(0), stop_(0) {}
469   OpMutableInputList& operator=(const OpMutableInputList& other) = default;
470   Tensor at(int i, bool lock_held);
471   mutex* ref_mutex(int i);
size()472   int size() const { return stop_ - start_; }
begin()473   Iterator begin() const { return Iterator(this, 0); }
end()474   Iterator end() const { return Iterator(this, size()); }
475 
476  private:
477   OpKernelContext* ctx_;  // not owned
478   int start_;
479   int stop_;
480 };
481 
482 // Utility class for representing a list of output tensors that are
483 // grouped as a single named output.
484 class OpOutputList {
485  public:
486   typedef OpArgIterator<OpOutputList, const Tensor*> Iterator;
OpOutputList()487   OpOutputList() : ctx_(nullptr), start_(0), stop_(0) {}
OpOutputList(OpKernelContext * ctx,int start,int stop)488   OpOutputList(OpKernelContext* ctx, int start, int stop)
489       : ctx_(ctx), start_(start), stop_(stop) {}
490   OpOutputList& operator=(const OpOutputList& other) = default;
491   Tensor* operator[](int i);
492   bool required(int i) const;
493   DataType expected_output_dtype(int i) const;
494   Status allocate(int i, const TensorShape& shape, Tensor** output);
495   void set(int i, const Tensor& tensor);
496   void set(int i, Tensor&& tensor);
497   void set_ref(int i, mutex* mu, Tensor* tensor_for_ref);
size()498   int size() const { return stop_ - start_; }
begin()499   Iterator begin() const { return Iterator(this, 0); }
end()500   Iterator end() const { return Iterator(this, size()); }
501 
502  private:
503   OpKernelContext* ctx_;  // not owned
504   int start_;
505   int stop_;
506 };
507 
508 // Holds a tensor or tensor reference. For tensor references, we need
509 // a mutex to prevent concurrent access to the tensor.
510 struct TensorValue {
TensorValueTensorValue511   TensorValue() : mutex_if_ref(nullptr), tensor(nullptr) {}
TensorValueTensorValue512   explicit TensorValue(Tensor* t) : mutex_if_ref(nullptr), tensor(t) {}
TensorValueTensorValue513   TensorValue(mutex* mu, Tensor* t) : mutex_if_ref(mu), tensor(t) {}
514   Tensor* operator->() const { return tensor; }
is_refTensorValue515   bool is_ref() const { return mutex_if_ref != nullptr; }
516 
517   // Return the dtype of the Tensor. For references, return the underlying type.
dtypeTensorValue518   DataType dtype() const {
519     if (is_ref()) {
520       return MakeRefType(tensor->dtype());
521     } else {
522       return tensor->dtype();
523     }
524   }
525 
526   // Return the dtype of the Tensor. For references, return the underlying type.
527   // This variation on the dtype() acquires the lock for references.
528   //
529   // TODO(b/133843385): Disallow dtype modifications
dtype_safeTensorValue530   DataType dtype_safe() const {
531     if (is_ref()) {
532       tf_shared_lock ml(*mutex_if_ref);
533       return MakeRefType(tensor->dtype());
534     } else {
535       return tensor->dtype();
536     }
537   }
538 
539   mutex* mutex_if_ref;  // nullptr if not a ref, != nullptr if a ref
540   Tensor* tensor;
541 };
542 
543 // Used to store partitioned graphs from function-calling ops.
544 struct GraphCollector {
545   mutex mu;
546   std::vector<GraphDef> partitioned_graphs TF_GUARDED_BY(mu);
547   GraphDef raw_graph TF_GUARDED_BY(mu);
548   GraphDef optimized_graph TF_GUARDED_BY(mu);
549 
550   bool dirty TF_GUARDED_BY(mu);
551 
GraphCollectorGraphCollector552   GraphCollector() : dirty(false) {}
553 
CollectRawGraphGraphCollector554   void CollectRawGraph(const GraphDef& graph) {
555     mutex_lock ml(mu);
556     raw_graph.MergeFrom(graph);
557     dirty = true;
558   }
559 
CollectOptimizedGraphGraphCollector560   void CollectOptimizedGraph(const GraphDef& graph) {
561     mutex_lock ml(mu);
562     optimized_graph.MergeFrom(graph);
563     dirty = true;
564   }
565 
CollectPartitionedGraphGraphCollector566   void CollectPartitionedGraph(const GraphDef& graph) {
567     mutex_lock ml(mu);
568     partitioned_graphs.push_back(graph);
569     dirty = true;
570   }
571 
ClearGraphsGraphCollector572   void ClearGraphs() TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
573     raw_graph.Clear();
574     optimized_graph.Clear();
575     partitioned_graphs.clear();
576     dirty = false;
577   }
578 
HasUpdatedGraphsGraphCollector579   bool HasUpdatedGraphs() {
580     mutex_lock ml(mu);
581     return dirty;
582   }
583 };
584 
585 class OpKernelContext {
586  public:
587   // The first element of a WrappedAllocator is a "base" Allocator and
588   // the second element is that Allocator wrapped by a
589   // TrackingAllocator
590   typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator;
591 
592   // TODO(zhifengc): Do some cleanup of Params.
593   // The Params struct is passed in to initialize an OpKernelContext,
594   // and must outlive the OpKernelContext.
595   struct Params {
~ParamsParams596     ~Params() { delete eigen_gpu_device; }
597 
598     // The step being executed.
599     int64 step_id = 0;
600 
601     // The op kernel being computed.
602     OpKernel* op_kernel = nullptr;
603 
604     // The device on which the kernel is running.
605     DeviceBase* device = nullptr;
606 
607     // The Eigen GPU device wrapper, which may include a per-op
608     // wrapped allocator. The concrete type of this object depends on
609     // the type of this->device, so eigen_gpu_device can't be an
610     // inline member and must be heap allocated. However, we don't
611     // want to allocate a new eigen_gpu_device for every Op that is
612     // executed. Instead this member is allocated on first use using
613     // ensure_eigen_gpu_device, and then if the Params structure is
614     // re-used for subsequent Ops, the eigen_gpu_device is
615     // ReInitialized in the OpKernelContext constructor. Unlike the
616     // other pointers in Params, this one is owned by Params.
617     PerOpGpuDevice* eigen_gpu_device = nullptr;
618 
ensure_eigen_gpu_deviceParams619     inline void ensure_eigen_gpu_device() {
620       DCHECK(device);
621       if (nullptr == eigen_gpu_device) {
622         // Surprisingly, MakeGpuDevice will return nullptr if the
623         // device is not a GPU device. This is ok, since those devices
624         // will never use eigen_gpu_device. It seems better to have
625         // ensure_eigen_gpu_device fall through and regenerate the
626         // nullptr every time an OpKernelContext is instantiated, than
627         // to do an unnecessary allocation of a dummy eigen GPU
628         // device for CPU device Ops.
629         eigen_gpu_device = device->MakeGpuDevice();
630       }
631     }
632 
633     bool track_allocations = false;
634     bool log_memory = false;
635 
636     // Array indexed by output number for this node
637     const AllocatorAttributes* output_attr_array = nullptr;
638 
639     // Shared resources accessible by this op kernel invocation.
640     ResourceMgr* resource_manager = nullptr;
641 
642     // Per-step resources accessible by this op kernel invocation should be
643     // stored in this container..
644     ScopedStepContainer* step_container = nullptr;
645 
646     // Mechanism used by this op kernel invocation to communicate with
647     // computations running on other devices.
648     RendezvousInterface* rendezvous = nullptr;
649 
650     // Mechanism for executing a collective op that needs to coordinate
651     // with parallel instances running on other devices.
652     CollectiveExecutor* collective_executor = nullptr;
653 
654     // The session state for this op.
655     SessionState* session_state = nullptr;
656 
657     // Unique session identifier. Can be empty.
658     std::string session_handle;
659 
660     // Metadata about the session. Can be nullptr.
661     const SessionMetadata* session_metadata = nullptr;
662 
663     // The tensor store for this op.
664     TensorStore* tensor_store = nullptr;
665 
666     // Mechanism used by this op kernel invocation to register a callback
667     // for its cancellation.
668     CancellationManager* cancellation_manager = nullptr;
669 
670     // Inputs to this op kernel.
671     const gtl::InlinedVector<TensorValue, 4>* inputs = nullptr;
672     bool is_input_dead = false;
673 
674     const gtl::InlinedVector<AllocatorAttributes, 4>* input_alloc_attrs =
675         nullptr;
676 
677     // Device context.
678     DeviceContext* op_device_context = nullptr;
679 
680     // Control-flow op supports.
681     FrameAndIter frame_iter;
682 
683     // Function call supports.
684     CallFrameInterface* call_frame = nullptr;
685     FunctionLibraryRuntime* function_library = nullptr;
686     std::function<void(std::function<void()>)>* runner = nullptr;
687     StepStatsCollectorInterface* stats_collector = nullptr;
688     GraphCollector* graph_collector = nullptr;
689     bool run_all_kernels_inline = false;
690     const std::string* executor_type = nullptr;
691 
692     // TensorSliceReaderCache support.
693     checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
694 
695     // Support for forwarding reservations (used by ScopedAllocator).
696     static constexpr int kNeverForward = -2;
697     static constexpr int kNoReservation = -1;
698     // Values in [0,...) represent reservations for the indexed output.
699     const int* forward_from_array = nullptr;
700 
701     // For tracking actively running deferred ops.
702     std::function<void()> inc_num_deferred_ops_function;
703     std::function<void()> dec_num_deferred_ops_function;
704 
705     absl::optional<ManagedStackTrace> stack_trace = {};
706 
707     // For implementing `OpKernelContext::output_required()`. If null, all
708     // outputs are required.
709     bool* outputs_required_array = nullptr;
710   };
711 
712   // params must outlive the OpKernelContext.
713   explicit OpKernelContext(Params* params);
714   OpKernelContext(Params* params, int num_outputs);
715   ~OpKernelContext();
716 
env()717   Env* env() const { return params_->device->env(); }
718 
step_id()719   int64 step_id() const { return params_->step_id; }
720 
op_kernel()721   const OpKernel& op_kernel() const { return *params_->op_kernel; }
722 
723   // Stack trace of where the op was defined (if defined in eager mode).
stack_trace()724   const absl::optional<ManagedStackTrace>& stack_trace() const {
725     return params_->stack_trace;
726   }
727 
728   // Input/output signature.
729 
num_inputs()730   int num_inputs() const { return params_->inputs->size(); }
731   DataType input_dtype(int index) const;
732   Status input_dtype(StringPiece name, DataType* dtype) const;
733   MemoryType input_memory_type(int index) const;
734 
num_outputs()735   int num_outputs() const { return outputs_.size(); }
736   DataType expected_output_dtype(int index) const;
737   MemoryType output_memory_type(int index) const;
738 
// Input

  // Returns an immutable input tensor. May only be used for non-Ref
  // inputs. For Ref inputs use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  // TODO(mrry): Convert this to return Status.
  const Tensor& input(int index) const;

  // Returns the named immutable input tensor in "tensor", as defined
  // in the OpDef. May only be used for non-Ref inputs. For Ref inputs
  // use mutable_input below.
  // REQUIRES: the named input must not be a ref tensor.
  // REQUIRES: the named input must not be a list.
  Status input(StringPiece name, const Tensor** tensor);

  // Returns the named list-valued immutable input in "list", as
  // defined in the OpDef.  If the named input is not list-valued,
  // returns a one-element list. May only be used for non-Ref
  // inputs. For Ref inputs use mutable_input below.
  // REQUIRES: the named input must not be a ref tensor.
  Status input_list(StringPiece name, OpInputList* list);

  // Returns, in *out_mutex, the mutex guarding the named Ref input.
  // For mutable inputs, hold this mutex while accessing the tensor so there
  // is no concurrent access to mutable_input(), e.g. (index-based form):
  // {
  //   mutex_lock lock(*context->input_ref_mutex(index));
  //   Tensor t = context->mutable_input(index, /*lock_held=*/true);
  //   // modify the values in t
  // }
  // REQUIRES: the named input must be a ref tensor.
  Status input_ref_mutex(StringPiece name, mutex** out_mutex);

  // Returns a mutable input tensor. Must be used to access Ref
  // inputs.  REQUIRES: IsRefType(input_dtype(index)). The caller may
  // modify the values stored in the Tensor buffer, and modifications
  // will be visible to other Ops reading the same ref tensor. If
  // !lock_held the input mutex will be acquired before returning the
  // Tensor.
  // TODO(mrry): Convert this to return Status.
  Tensor mutable_input(int index, bool lock_held);

  // Returns the named mutable input tensor in "tensor", as defined in
  // the OpDef. Must be used to access Ref inputs. The values stored
  // in the Tensor buffer may be modified, and modifications will be
  // visible to other Ops reading the same ref tensor. If !lock_held
  // the input mutex will be acquired before returning the Tensor.
  // REQUIRES: the named input must not be a list.
  // REQUIRES: the named input must be a ref tensor.
  Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held);

  // Returns the named list-valued mutable input in "list", as defined
  // in the OpDef.  If the named input is not list-valued, returns a
  // one-element list. Must be used to access Ref inputs. The values
  // stored in the Tensor buffer may be modified, and modifications
  // will be visible to other Ops reading the same ref tensor.
  // REQUIRES: the named input must be a ref tensor.
  Status mutable_input_list(StringPiece name, OpMutableInputList* list);

  // Replace the corresponding Ref Input to use the storage buffer
  // used by tensor. If !lock_held the input mutex will be acquired
  // for the replacement.
  // REQUIRES: IsRefType(input_dtype(index)).
  void replace_ref_input(int index, const Tensor& tensor, bool lock_held);

  // Replace the corresponding named Ref Input to use the storage
  // buffer used by tensor. If !lock_held the input mutex will be
  // acquired for the replacement.
  // REQUIRES: the named input must be a ref tensor.
  Status replace_ref_input(StringPiece name, const Tensor& tensor,
                           bool lock_held);

  // Deletes the Tensor object used as the Ref Input at
  // input_index. This is not usually necessary and should be used
  // with caution. If !lock_held the input mutex will be acquired
  // for the deletion.
  // REQUIRES: IsRefType(input_dtype(input_index)).
  void delete_ref_input(int input_index, bool lock_held);

  // Return true if there is input at the given index. An operator has no
  // input at index if its tensor is null. This is primarily used by the
  // merge operator.
  // TODO(mrry): Convert this to return Status.
  bool has_input(int index) const;

  // Returns true if all inputs are the same shape, otherwise sets the
  // status to a non-OK value and returns false.
  // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
  bool ValidateInputsAreSameShape(OpKernel* op);
827 
828   // If non-null, kernels should populate with any partition subgraphs created.
graph_collector()829   GraphCollector* graph_collector() { return params_->graph_collector; }
830 
831   // If True, hint that all kernels in functions called by this kernel, should
832   // be treated as "inexpensive", and hence executed on the scheduling thread.
run_all_kernels_inline()833   bool run_all_kernels_inline() const {
834     return params_->run_all_kernels_inline;
835   }
836 
// Executor identification.

  // Returns the registered name for the executor type that is executing the
  // current kernel. If empty, the default executor is used.
  const std::string& executor_type() const;

  // Input to output forwarding.

  // Set the output Ref Tensor at output_index to be an alias of the
  // input Ref Tensor at input_index.
  // REQUIRES: IsRefType(input_dtype(input_index)).
  // REQUIRES: IsRefType(output_dtype(output_index)).
  void forward_ref_input_to_ref_output(int input_index, int output_index);

  // Returns true when an alias to input[input_index], reshaped to output_shape,
  // which is safe to use for in-place computation was written to *output.
  // Returns false if input[input_index] has a refcount greater than one, or if
  // its type does not match the expected output type of output[output_index],
  // or the number of elements in input[input_index] does not equal the number
  // of elements in output_shape.
  bool forward_input_to_output_with_shape(int input_index, int output_index,
                                          const TensorShape& output_shape,
                                          Tensor** output) TF_MUST_USE_RESULT;
  // Variant of the above that identifies the input and output by OpDef name.
  Status forward_input_to_output_with_shape(StringPiece input_name,
                                            StringPiece output_name,
                                            const TensorShape& output_shape,
                                            Tensor** output) TF_MUST_USE_RESULT;

  // Returns a pointer to a Tensor aliasing the underlying buffer backing
  // input[input_index] iff
  //   * input[input_index] is not a ref,
  //   * the data type, shape, memory type, and allocator attributes of
  //     input[input_index] are compatible with those given in output_dtype,
  //     output_shape, output_memory_type, and output_attr,
  //   * refcount on the underlying buffer is one.
  //   * Either there is no forwarding reservation for either input_index
  //     or output_index or the specified input is reserved for the specified
  //     output. More precisely:
  //
  //     These cases mean neither input nor output has a reservation:
  //        forward_from_array = nullptr
  //     OR (input_index is not in forward_from_array AND
  //         (output_index == kNoReservation OR
  //          forward_from_array[output_index] == kNoReservation))
  //
  //     This case means that input_index is reserved for output_index:
  //        forward_from_array[output_index] == input_index
  //
  //     This case means the output is reserved to always be allocated,
  //     never assigned a forwarded input:
  //        forward_from_array[output_index] == kNeverForward
  //
  // Otherwise returns nullptr.
  // NOTE: For Cuda kernels that read inputs using the __ldg() intrinsic,
  // forwarding is only safe if there are no reads via __ldg() after writes
  // to the same address.
  std::unique_ptr<Tensor> forward_input(
      int input_index, int output_index, DataType output_dtype,
      const TensorShape& output_shape, MemoryType output_memory_type,
      const AllocatorAttributes& output_attr) TF_MUST_USE_RESULT;

  // Tries to forward one of the inputs given in candidate_input_indices to
  // output[output_index]. If none of the given inputs can be forwarded, calls
  // allocate_output() to allocate a new output buffer. The index of the
  // forwarded input will be assigned to the output argument forwarded_input
  // (if it's not nullptr). If no inputs are forwarded, forwarded_input will be
  // assigned -1.
  Status forward_input_or_allocate_output(
      gtl::ArraySlice<int> candidate_input_indices, int output_index,
      const TensorShape& output_shape, Tensor** output,
      int* forwarded_input = nullptr) TF_MUST_USE_RESULT;
  // Variant of the above that identifies the candidate inputs and the output
  // by OpDef name.
  Status forward_input_or_allocate_output(
      gtl::ArraySlice<StringPiece> candidate_input_names,
      StringPiece output_name, const TensorShape& output_shape,
      Tensor** output) TF_MUST_USE_RESULT;

  // Tries to reuse one of the inputs given in candidate_input_indices as a
  // temporary. If none of the given inputs can be forwarded, calls
  // allocate_temp() to allocate a new temporary buffer.
  Status forward_input_or_allocate_temp(
      gtl::ArraySlice<int> candidate_input_indices, DataType type,
      const TensorShape& shape, const AllocatorAttributes& allocator_attr,
      Tensor* out_temp) TF_MUST_USE_RESULT;
918 
forward_input_or_allocate_temp(gtl::ArraySlice<int> candidate_input_indices,DataType type,const TensorShape & shape,Tensor * out_temp)919   Status forward_input_or_allocate_temp(
920       gtl::ArraySlice<int> candidate_input_indices, DataType type,
921       const TensorShape& shape, Tensor* out_temp) TF_MUST_USE_RESULT {
922     return forward_input_or_allocate_temp(candidate_input_indices, type, shape,
923                                           AllocatorAttributes(), out_temp);
924   }
925 
926   // Output
927 
928   // Returns the named list-valued output in "list", as defined in the OpDef.
929   // If the named output is not list-valued, returns a one-element list.
930   Status output_list(StringPiece name, OpOutputList* list);
931 
932   // If output_required(index) returns true, the OpKernel's Compute() method
933   // should call allocate_output(index, ...), set_output(index, ...),
934   // set_output_ref(index, ...), or set the status to a non-ok value.
935   // If it returns false, it may output, but is not required to do so.
output_required(int index)936   bool output_required(int index) const {
937     return !params_->outputs_required_array ||
938            params_->outputs_required_array[index];
939   }
940 
// Allocation.

  // Allocation of tensors during kernel execution inside the Compute
  // method:
  //
  // There are three methods to allocate Tensors when an Op kernel
  // executes.
  //
  // 1) allocate_persistent. This is only needed for Tensors that will
  // be stored by the Op between invocations, and it *must* be used
  // for those Tensors. The call returns a PersistentTensor, and that
  // is the only object the Op is allowed to hold on to between
  // invocations. When the Tensor is needed in a subsequent
  // invocation, it can be retrieved from the PersistentTensor using
  // the AccessTensor method. This ensures that the system is made
  // aware of any use of the tensor's allocated memory, which is
  // needed for correctness on asynchronous devices such as GPUs.
  //
  // 2) allocate_output. This should be used to allocate any tensor
  // that is going to be used as an output from the Op at the end of
  // the current execution. The caller indicates which output the
  // Tensor will be assigned to, and the call returns the
  // newly-allocated Tensor. The Tensor can subsequently be assigned
  // to during kernel execution, and will be used as the designated
  // output when the kernel execution completes.
  //
  // 3) allocate_temp. This should be used to allocate any scratch
  // storage that is needed while the kernel is executing, and will
  // not be retained by the Op.
  //
  // In some cases a Tensor needs to be used as an output even though
  // it was previously allocated elsewhere. The Tensor may have been
  // passed as an input, or stored in a PersistentTensor during a
  // previous kernel execution, or allocated earlier in the kernel
  // execution at a time when it was not known which output it would
  // be assigned to. In this case the kernel can use set_output or
  // set_output_ref to indicate that the tensor should be used as the
  // designated output. It is legal to use any previously-allocated
  // Tensor as an argument to set_output or set_output_ref, including
  // Tensors allocated via allocate_temp. There may be a performance
  // penalty to using a Tensor that was not allocated using
  // allocate_output. This is because allocate_output uses the
  // AllocatorAttributes stored in output_attr_array for the
  // designated output. In some cases, using the wrong attributes may
  // cause an extra copy of the Tensor's buffer.

  // Allocates output for the specified output index with shape.
  // OpKernelContext retains ownership of the returned pointer. See
  // comment above.
  //
  // If memory allocation fails, returns an error status.
  //
  // REQUIRES: !IsRefType(expected_output_dtype(index))
  Status allocate_output(int index, const TensorShape& shape,
                         Tensor** tensor) TF_MUST_USE_RESULT;
  // Variant of the above that identifies the output by its OpDef name.
  Status allocate_output(StringPiece name, const TensorShape& shape,
                         Tensor** tensor) TF_MUST_USE_RESULT;
  // The following methods use the supplied attributes instead of
  // those in output_attr_array. The caller is responsible for
  // ensuring that the attributes are "compatible" with the
  // output_attr_array, e.g. the tensor is allocated on the correct
  // device. See comment above.
  Status allocate_output(int index, const TensorShape& shape, Tensor** tensor,
                         AllocatorAttributes attr) TF_MUST_USE_RESULT;
  Status allocate_output(StringPiece name, const TensorShape& shape,
                         Tensor** tensor,
                         AllocatorAttributes attr) TF_MUST_USE_RESULT;
1006 
1007   // Allocates a temporary Tensor of the specified type and
1008   // shape. Devices such as GPUs that enqueue Ops for lazy execution
1009   // may retain references to the temporary tensors after the Op's
1010   // Compute method has run. See comment above.
1011   Status allocate_temp(DataType type, const TensorShape& shape,
1012                        Tensor* out_temp, AllocatorAttributes allocator_attr,
1013                        const AllocationAttributes& allocation_attr);
allocate_temp(DataType type,const TensorShape & shape,Tensor * out_temp,AllocatorAttributes allocator_attr)1014   Status allocate_temp(DataType type, const TensorShape& shape,
1015                        Tensor* out_temp, AllocatorAttributes allocator_attr) {
1016     return allocate_temp(type, shape, out_temp, allocator_attr,
1017                          AllocationAttributes());
1018   }
allocate_temp(DataType type,const TensorShape & shape,Tensor * out_temp)1019   Status allocate_temp(DataType type, const TensorShape& shape,
1020                        Tensor* out_temp) {
1021     return allocate_temp(type, shape, out_temp, AllocatorAttributes());
1022   }
1023 
1024   // Allocates a Tensor of the specified type and shape which the Op
1025   // plans to maintain as persistent state. out_persistent holds the
1026   // PersistentTensor which is the object the caller should store. For
1027   // convenience, if out_tensor is non-null then it will be filled in
1028   // with a Tensor* pointing to the newly-allocated tensor which the
1029   // caller can use instead of calling
1030   // out_persistent->AccessTensor. The caller does not own out_tensor
1031   // and should not keep a copy of it. See comment above.
1032   Status allocate_persistent(DataType type, const TensorShape& shape,
1033                              PersistentTensor* out_persistent,
1034                              Tensor** out_tensor, AllocatorAttributes attr);
allocate_persistent(DataType type,const TensorShape & shape,PersistentTensor * out_persistent,Tensor ** out_tensor)1035   Status allocate_persistent(DataType type, const TensorShape& shape,
1036                              PersistentTensor* out_persistent,
1037                              Tensor** out_tensor) {
1038     return allocate_persistent(type, shape, out_persistent, out_tensor,
1039                                AllocatorAttributes());
1040   }
1041 
// Setting outputs.
  // Copies a tensor (allocated by the caller) to the specified output
  // index.  REQUIRES: !IsRefType(expected_output_dtype(index))
  // REQUIRES: 'tensor' must have the same MemoryType as
  // output_memory_types[index]. See comment above.
  // The Tensor&& overloads accept an rvalue, which allows the caller's Tensor
  // to be moved from instead of copied.
  Status set_output(StringPiece name, const Tensor& tensor);
  Status set_output(StringPiece name, Tensor&& tensor);
  void set_output(int index, const Tensor& tensor);
  void set_output(int index, Tensor&& tensor);

  // To output a reference.  Caller retains ownership of mu and tensor_for_ref,
  // and they must outlive all uses within the step. See comment above.
  // REQUIRES: IsRefType(expected_output_dtype(index))
  Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref);

  // Returns (in *tensor) nullptr if allocate_output() or set_output() have
  // not been called for the named output.
  Status mutable_output(StringPiece name, Tensor** tensor);
1058 
1059   // Return the DeviceContext that should be used for this Op.
1060   //
1061   // If using the templated function, the type must be a subclass
1062   // of DeviceContext.
1063   //
1064   // Returns nullptr if the device did not provide one.
1065   template <typename T>
1066   T* op_device_context();
op_device_context()1067   DeviceContext* op_device_context() {
1068     DeviceContext* ret = params_->op_device_context;
1069     if (ret == nullptr) {
1070       auto* dev_info = device()->tensorflow_gpu_device_info();
1071       if (dev_info) ret = dev_info->default_context;
1072     }
1073     return ret;
1074   }
1075 
input_alloc_attr(int index)1076   AllocatorAttributes input_alloc_attr(int index) const {
1077     if (params_->input_alloc_attrs == nullptr) {
1078       return AllocatorAttributes();
1079     } else {
1080       DCHECK_GE(index, 0);
1081       DCHECK_LT(index, params_->input_alloc_attrs->size());
1082       return (*params_->input_alloc_attrs)[index];
1083     }
1084   }
1085 
output_alloc_attr(int index)1086   AllocatorAttributes output_alloc_attr(int index) const {
1087     return params_->output_attr_array[index];
1088   }
1089 
ConsumeWrappedAllocators()1090   gtl::InlinedVector<WrappedAllocator, 4> ConsumeWrappedAllocators() {
1091     gtl::InlinedVector<WrappedAllocator, 4> retrieved;
1092     if (tracking_state_) {
1093       mutex_lock lock(tracking_state_->mu);
1094       retrieved.swap(tracking_state_->wrapped_allocators);
1095     }
1096     return retrieved;
1097   }
1098 
1099   // Communication.
1100   //
1101   // An op kernel communicates with outside environment through
1102   // Rendezvous Send() and Recv().
rendezvous()1103   RendezvousInterface* rendezvous() const { return params_->rendezvous; }
1104 
collective_executor()1105   CollectiveExecutor* collective_executor() const {
1106     return params_->collective_executor;
1107   }
1108 
1109   // An op kernel can access the session state it belongs to.
session_state()1110   SessionState* session_state() const { return params_->session_state; }
1111 
1112   // Unique identifier of the session it belongs to. Can be empty.
session_handle()1113   std::string session_handle() const { return params_->session_handle; }
1114 
1115   // Metadata about the session. Can be nullptr.
session_metadata()1116   const SessionMetadata* session_metadata() const {
1117     return params_->session_metadata;
1118   }
1119 
1120   // An op kernel can access the tensor store of the run it belongs to.
tensor_store()1121   TensorStore* tensor_store() const { return params_->tensor_store; }
1122 
1123   // Function call support.
1124   //
1125   // If this kernel invocation is within a function execution,
1126   // call_frame() returns the call frame for the function call.
call_frame()1127   CallFrameInterface* call_frame() const { return params_->call_frame; }
1128 
1129   // If not nullptr, the kernel invoke functions defined in the
1130   // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...).
function_library()1131   FunctionLibraryRuntime* function_library() const {
1132     return params_->function_library;
1133   }
1134 
runner()1135   std::function<void(std::function<void()>)>* runner() const {
1136     return params_->runner;
1137   }
stats_collector()1138   StepStatsCollectorInterface* stats_collector() const {
1139     return params_->stats_collector;
1140   }
1141 
1142   // Shared resources accessible to this kernel.
resource_manager()1143   ResourceMgr* resource_manager() const { return params_->resource_manager; }
1144 
slice_reader_cache()1145   checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const {
1146     return params_->slice_reader_cache;
1147   }
1148 
1149   // Execution.
1150   //
1151   // OpKernels can use these eigen devices to carry out their
1152   // numerical computation.
eigen_cpu_device()1153   const Eigen::ThreadPoolDevice& eigen_cpu_device() const {
1154     return *device()->eigen_cpu_device();
1155   }
eigen_gpu_device()1156   const Eigen::GpuDevice& eigen_gpu_device() const {
1157     return params_->eigen_gpu_device->device();
1158   }
1159   template <typename EigenDeviceType>
1160   const EigenDeviceType& eigen_device() const;
1161 
1162   // Error handling.
1163 
1164   // If expected_inputs == inputs() and expected_outputs == output_types(),
1165   // returns OK, else returns INVALID_ARGUMENT with an error message.
1166   // Recommended for Ops with dynamic signatures, where validation can only
1167   // be performed at runtime.
1168   Status MatchSignature(const DataTypeSlice expected_inputs,
1169                         const DataTypeSlice expected_outputs);
1170 
1171   // An OpKernel should call SetStatus() if Compute() encounters an
1172   // error.
1173   void SetStatus(const Status& status);
status()1174   const Status& status() const { return status_; }
1175 
1176   // Cancellation.
1177   //
1178   // EXPERIMENTAL. See the implementation in tensorflow::FIFOQueue for an
1179   // example of how to use this API.
cancellation_manager()1180   CancellationManager* cancellation_manager() const {
1181     return params_->cancellation_manager;
1182   }
1183 
1184   // Other accessors.
1185 
1186   // For control flow.
frame_iter()1187   FrameAndIter frame_iter() const { return params_->frame_iter; }
is_input_dead()1188   bool is_input_dead() const { return params_->is_input_dead; }
1189 
1190   // May be used, e.g., to get GPU handles, etc.
1191   // TODO(tucker): Add example usage.
device()1192   DeviceBase* device() const { return params_->device; }
1193 
1194   // Per-step container for use by white-listed internal ops.
step_container()1195   ScopedStepContainer* step_container() const {
1196     return params_->step_container;
1197   }
1198 
1199   // Helper routines for the OP_REQUIRES macros
1200   void CtxFailure(const Status& s);
1201   void CtxFailureWithWarning(const Status& s);
1202   void CtxFailure(const char* file, int line, const Status& s);
1203   void CtxFailureWithWarning(const char* file, int line, const Status& s);
1204 
1205   // Unrecommended functions: these are functions that have some
1206   // current uses but are not recommended for use, and may go away at
1207   // some future major version release.
1208   //
1209   // The following functions all have versions that return Status
1210   // to capture error conditions, and are strongly preferred.
1211   Tensor* mutable_output(int index);
1212   mutex* input_ref_mutex(int index);
1213   void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref);
1214   TensorValue release_output(int index);
1215 
track_allocations()1216   bool track_allocations() const { return params_->track_allocations; }
1217 
// Memory accounting.
  // Records temp memory allocation. Tensor object is recorded to identify the
  // case where temp memory is used as output memory.
  void record_temp_memory_allocation(int64 size, const Tensor& t)
      TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);

  // Returns recorded size of temporary memory.
  int64 temp_memory_allocated() const
      TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);

  // Records persistent memory allocation, size can be negative indicating
  // deallocation.
  void record_persistent_memory_allocation(int64 size, int64 alloc_id = -1)
      TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);

  // Returns recorded size of persistent memory.
  int64 persistent_memory_allocated() const
      TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);

  // Returns recorded ids of persistent memory allocations.
  std::vector<int64> persistent_alloc_ids() const
      TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);

  // Resets counters for temp and persistent memory and recorded ids.
  void clear_recorded_memory() TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);

  bool input_is_ref(int index) const;

  void set_record_memory_consumption(bool v);
1245 
1246   // Used by OpKernel implementations to track actively running deferred ops.
1247   //
1248   // A deferred op is one whose Compute method returns (or whose ComputeAsync
1249   // method invokes the callback) when work is scheduled onto a device. At that
1250   // point, we don't know when the work will actually complete (or if it has
1251   // already completed) on the device. These functions allow the executor to
1252   // track the status of deferred ops and act accordingly.
1253   //
1254   // Deferred OpKernel implementations must use these methods to get two
1255   // functions. It then must call these two functions in pairs, before and after
1256   // device execution, respectively.
inc_num_deferred_ops_function()1257   TF_MUST_USE_RESULT std::function<void()> inc_num_deferred_ops_function() {
1258     DCHECK(params_->op_kernel->is_deferred());
1259     return params_->inc_num_deferred_ops_function
1260                ? params_->inc_num_deferred_ops_function
1261                : []() {};
1262   }
dec_num_deferred_ops_function()1263   TF_MUST_USE_RESULT std::function<void()> dec_num_deferred_ops_function() {
1264     DCHECK(params_->op_kernel->is_deferred());
1265     return params_->dec_num_deferred_ops_function
1266                ? params_->dec_num_deferred_ops_function
1267                : []() {};
1268   }
1269 
1270   Allocator* get_allocator(AllocatorAttributes attr);
1271 
1272  private:
1273   bool record_memory_consumption_ = false;
1274 
1275   // Internal common method used when allocating tensor memory
allocate_tensor(DataType type,const TensorShape & shape,Tensor * out_tensor,AllocatorAttributes allocator_attr)1276   Status allocate_tensor(DataType type, const TensorShape& shape,
1277                          Tensor* out_tensor,
1278                          AllocatorAttributes allocator_attr) {
1279     return allocate_tensor(type, shape, out_tensor, allocator_attr,
1280                            AllocationAttributes());
1281   }
1282 
// Full form of allocate_tensor(), taking explicit AllocationAttributes.
  Status allocate_tensor(DataType type, const TensorShape& shape,
                         Tensor* out_tensor, AllocatorAttributes allocator_attr,
                         const AllocationAttributes& allocation_attr);

  // Helpers for `set_output()`.

  // Returns `true` if the tensor was copied into an allocated output.
  bool maybe_set_output_by_allocate_and_copy(int index, const Tensor& tensor);

  void maybe_track_allocations_for_set_output(const Tensor& tensor);

  // Resolve an OpDef input/output name to its index, stored in *out_index.
  Status get_input_index(StringPiece name, int* out_index) const;
  Status get_output_index(StringPiece name, int* out_index) const;

  // Initialize the allocated_scope_ids_ set the first time this method is
  // called.
  void maybe_initialize_scope_id_set();

  // Returned by status().
  Status status_;
  friend class CollectiveExecutor;  // for access to params_
  Params* params_;                  // not owned
  // Per-output slots; num_outputs() reports their count.
  gtl::InlinedVector<TensorValue, 4> outputs_;

  // Keep track of calls to ScopedAllocator.
  // TODO(ayushd): change to absl::flat_hash_set.
  std::unique_ptr<std::unordered_set<int32>> allocated_scope_ids_;

  // The following data members are only used when allocation tracking is
  // enabled, memory consumption is being recorded, or tensor access is being
  // recorded.
  struct TrackingState {
    // Guards wrapped_allocators.
    mutable mutex mu;
    gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators
        TF_GUARDED_BY(mu);

    // Guards the memory-accounting fields below.
    mutable mutex stats_mu;
    int64 temp_memory_allocated TF_GUARDED_BY(stats_mu) = 0;

    int64 persistent_memory_allocated TF_GUARDED_BY(stats_mu) = 0;
    gtl::InlinedVector<std::pair<const void*, int64>, 2>
        temp_tensor_buffer_and_size TF_GUARDED_BY(stats_mu);
    gtl::InlinedVector<int64, 2> persistent_alloc_ids TF_GUARDED_BY(stats_mu);
  };
  // May be null; ConsumeWrappedAllocators() checks for this before use.
  std::unique_ptr<TrackingState> tracking_state_;

  // For access to `params_->op_kernel`.
  friend void CheckNotInComputeAsync(OpKernelContext* ctx,
                                     const char* correct_macro_name);

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext);
1333 };
1334 
// Specialization of eigen_device() for the CPU thread-pool device.
template <>
const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const;

// Specialization of eigen_device() for the GPU device.
template <>
const Eigen::GpuDevice& OpKernelContext::eigen_device() const;
1340 
1341 
1342 // Register your OpKernel by specifying the Op's name, the device the
1343 // kernel runs on, any type attr constraints for this kernel, any
1344 // host-memory args, and the class to instantiate.  Examples:
1345 //
1346 //  // A kernel that supports all types.
1347 //  REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp);
1348 //
1349 //  // The following are equivalent ways of specifying that the kernel only
1350 //  // works if the "T" type attr is set to DT_FLOAT.
1351 //  REGISTER_KERNEL_BUILDER(
1352 //      Name("Sub").Device(DEVICE_CPU).TypeConstraint<float>("T"),
1353 //      SubOp<float>);
1354 //  // (You would then repeat this for every type supported by "Sub".)
1355 //
1356 //  // This form allows you to specify a list of types as the constraint.
1357 //  REGISTER_KERNEL_BUILDER(Name("Sub")
1358 //                              .Device(DEVICE_CPU)
1359 //                              .TypeConstraint("T", {DT_FLOAT}),
1360 //                          SubOp<float>);
1361 //
1362 //  // A kernel that expects one of the input tensors in host memory.
1363 //  REGISTER_KERNEL_BUILDER(
1364 //      Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp);
1365 //
1366 // See kernel_def_builder for details.
1367 
1368 // Instantiate an OpKernel that has been registered.  Returns nullptr
1369 // if no operation for that type of device / input signature combination
1370 // (and a NOT_FOUND *status), or there is an error in construction (and
1371 // an INVALID_ARGUMENT *status).  Otherwise, the caller takes ownership
1372 // of the returned pointer.
1373 // EXPECTED USAGE: unique_ptr<OpKernel> op = CreateOpKernel(...);
1374 // REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
1375 std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
1376                                          DeviceBase* device,
1377                                          Allocator* allocator,
1378                                          const NodeDef& node_def,
1379                                          int graph_def_version, Status* status);
1380 
1381 std::unique_ptr<OpKernel> CreateOpKernel(
1382     DeviceType device_type, DeviceBase* device, Allocator* allocator,
1383     const std::shared_ptr<const NodeProperties>& props, int graph_def_version,
1384     Status* status);
1385 
1386 Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
1387                       Allocator* allocator, FunctionLibraryRuntime* flib,
1388                       const std::shared_ptr<const NodeProperties>& props,
1389                       int graph_def_version, OpKernel** kernel);
1390 
1391 Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
1392                       Allocator* allocator, FunctionLibraryRuntime* flib,
1393                       ResourceMgr* resource_mgr,
1394                       const std::shared_ptr<const NodeProperties>& props,
1395                       int graph_def_version, OpKernel** kernel);
1396 
1397 // Returns into 'device_types' the subset of prioritized_types that this
1398 // binary has registered for the given NodeDef.
1399 //
1400 // REQUIRES: * 'device_types' is not nullptr.
1401 //           * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
1402 Status SupportedDeviceTypesForNode(
1403     const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
1404     PrioritizedDeviceTypeVector* device_types,
1405     const DeviceNameUtils::ParsedName* local_address_spec = nullptr);
1406 
1407 // Returns a message with a description of the kernels registered for op
1408 // `op_name`.
1409 std::string KernelsRegisteredForOp(StringPiece op_name);
1410 
1411 // Call once after Op registration has completed.
1412 Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry);
1413 
1414 // -----------------------------------------------------------------------------
1415 // OpKernel registration implementation follows, please ignore.
1416 
1417 // Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax.
1418 namespace register_kernel {
1419 
1420 class Name : public KernelDefBuilder {
1421  public:
Name(const char * op)1422   explicit Name(const char* op) : KernelDefBuilder(op) {}
1423 };
1424 
1425 }  // namespace register_kernel
1426 
1427 // Kernel registration appears as:
1428 //   REGISTER_KERNEL_BUILDER(Name("OpName").Device(DEVICE_CPU)..., OpImpl)
1429 // We'd like to have "OpName" as a constant-expression, without requiring that
1430 // of the overall KernelDefBuilder expression (beginning with the
1431 // register_kernel::Name constructor above).
1432 //
1433 // So, we pull the "OpName" part to a separate macro-level argument. This
1434 // involves treating Name("OpName") as a macro call, via token-pasting (e.g.
1435 // M_## =>  M_Name("OpName")), and having it expand to '"OpName",
1436 // Name("OpName")' which is then usable as two arguments.
// Target of the token paste in TF_EXTRACT_KERNEL_NAME: pasting
// `TF_EXTRACT_KERNEL_NAME_` onto a builder expression whose first token is the
// bare identifier `Name` produces a call to this macro. The call expands to the
// op-name string, a comma, and the equivalent register_kernel::Name(...)
// expression. Any trailing `.Device(...)...` calls stay attached to that
// second part.
1437 #define TF_EXTRACT_KERNEL_NAME_Name(name_str) \
1438   name_str, ::tensorflow::register_kernel::Name(name_str)
// Extra expansion step. By the time `m` is invoked here, the pasted
// TF_EXTRACT_KERNEL_NAME_Name(...) has already been expanded, so the comma it
// produced separates two of `m`'s arguments.
1439 #define TF_EXTRACT_KERNEL_NAME_IMPL(m, ...) m(__VA_ARGS__)
// Invokes `m(op_name_str, builder_expr, <remaining args>)` for a
// `kernel_builder` of the form Name("OpName")<rest>. A builder that does not
// start with the token `Name` pastes into some other identifier and
// presumably fails to compile.
1440 #define TF_EXTRACT_KERNEL_NAME(m, kernel_builder, ...)                    \
1441   TF_EXTRACT_KERNEL_NAME_IMPL(m, TF_EXTRACT_KERNEL_NAME_##kernel_builder, \
1442                               __VA_ARGS__)
1443 
1444 // REGISTER_KERNEL_BUILDER_IMPL_2, with a unique 'ctr' as the first argument.
1445 // TODO(dodgen): There are some uses of this macro inside functions, where
1446 // kernel_builder refers to (non-const) locals (they should be fixed). To
1447 // accommodate those, kernel_builder.Build() appears as an argument to an
1448 // immediately-called lambda (not in the lambda itself).
// Expansion summary, as read from the body below:
//  - Declares a `static` ::tensorflow::InitOnStartupMarker named
//    register_kernel_<ctr>. `ctr` exists only to make that name unique.
//  - The condition handed to TF_INIT_ON_STARTUP_IF is `is_system_kernel`, or
//    else selective registration accepting both the kernel class (stringized
//    `__VA_ARGS__`) and `op_name`. How that condition gates the lambda is
//    defined by TF_INIT_ON_STARTUP_IF, which lives elsewhere.
//  - The lambda constructs a kernel_factory::OpKernelRegistrar for the built
//    KernelDef. The kernel class name passed to it is the stringized class,
//    and its factory returns `new __VA_ARGS__(context)`. The kernel class must
//    therefore be constructible from an OpKernelConstruction*.
//  - The kernel class is the variadic tail, presumably so that template names
//    containing commas are accepted as one argument.
1449 #define REGISTER_KERNEL_BUILDER_IMPL_3(ctr, op_name, kernel_builder_expr,   \
1450                                        is_system_kernel, ...)               \
1451   static ::tensorflow::InitOnStartupMarker const register_kernel_##ctr      \
1452       TF_ATTRIBUTE_UNUSED =                                                 \
1453           TF_INIT_ON_STARTUP_IF(is_system_kernel ||                         \
1454                                 (SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__) && \
1455                                  SHOULD_REGISTER_OP(op_name)))              \
1456           << ([](::tensorflow::KernelDef const* kernel_def) {               \
1457                ::tensorflow::kernel_factory::OpKernelRegistrar registrar(   \
1458                    kernel_def, #__VA_ARGS__,                                \
1459                    [](::tensorflow::OpKernelConstruction* context)          \
1460                        -> ::tensorflow::OpKernel* {                         \
1461                      return new __VA_ARGS__(context);                       \
1462                    });                                                      \
1463                (void)registrar;                                             \
1464                return ::tensorflow::InitOnStartupMarker{};                  \
1465              })(kernel_builder_expr.Build());
1466 
1467 // REGISTER_KERNEL_BUILDER_IMPL, but with kernel_builder split to op_name,
1468 // kernel_builder_expr.
// TF_NEW_ID_FOR_INIT is defined elsewhere. Presumably it invokes
// REGISTER_KERNEL_BUILDER_IMPL_3 with a fresh unique id as its `ctr` argument,
// followed by the remaining arguments. Confirm against its definition.
1469 #define REGISTER_KERNEL_BUILDER_IMPL_2(op_name, kernel_builder_expr, \
1470                                        is_system_kernel, ...)        \
1471   TF_NEW_ID_FOR_INIT(REGISTER_KERNEL_BUILDER_IMPL_3, op_name,        \
1472                      kernel_builder_expr, is_system_kernel, __VA_ARGS__)
1473 
1474 // REGISTER_KERNEL_BUILDER, but with is_system_kernel bound.
// Uses TF_EXTRACT_KERNEL_NAME to split `kernel_builder` into the op-name
// string and the builder expression, then calls
// REGISTER_KERNEL_BUILDER_IMPL_2. `kernel_builder` must start with the bare
// token `Name`.
1475 #define REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, is_system_kernel, ...) \
1476   TF_EXTRACT_KERNEL_NAME(REGISTER_KERNEL_BUILDER_IMPL_2, kernel_builder,    \
1477                          is_system_kernel, __VA_ARGS__)
1478 
// Public registration macro. Usage examples appear in the large
// REGISTER_KERNEL_BUILDER comment earlier in this header. It tags the
// registration with TF_ATTRIBUTE_ANNOTATE("tf:kernel") and registers with
// is_system_kernel = false. The kernel is therefore subject to the
// selective-registration checks in REGISTER_KERNEL_BUILDER_IMPL_3.
1479 #define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
1480   TF_ATTRIBUTE_ANNOTATE("tf:kernel")                 \
1481   REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, false, __VA_ARGS__)
1482 
1483 // The `REGISTER_SYSTEM_KERNEL_BUILDER()` macro acts as
1484 // `REGISTER_KERNEL_BUILDER()` except that the kernel is registered
1485 // unconditionally even when selective registration is used.
// Concretely: it adds a second annotation, "tf:kernel:system", and passes
// is_system_kernel = true, which short-circuits the selective-registration
// condition.
1486 #define REGISTER_SYSTEM_KERNEL_BUILDER(kernel_builder, ...) \
1487   TF_ATTRIBUTE_ANNOTATE("tf:kernel")                        \
1488   TF_ATTRIBUTE_ANNOTATE("tf:kernel:system")                 \
1489   REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, true, __VA_ARGS__)
1490 
1491 // Checks whether a given kernel is registered on device_type.
1492 bool KernelDefAvailable(const DeviceType& device_type, const NodeDef& node_def);
1493 
1494 // If node of node_name, experimental_debug_info, node_op, node_device and
1495 // node_attrs has a corresponding kernel registered on device_type, returns OK
1496 // and fill in the kernel def and kernel_class_name. <def> and
1497 // <kernel_class_name> may be null.
1498 Status FindKernelDef(
1499     const DeviceType& device_type, StringPiece node_name,
1500     bool has_experimental_debug_info,
1501     const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
1502     StringPiece node_op, StringPiece node_device, AttrSlice node_attrs,
1503     const KernelDef** def, std::string* kernel_class_name);
1504 
1505 // If node_def has a corresponding kernel registered on device_type,
1506 // returns OK and fill in the kernel def and kernel_class_name. <def> and
1507 // <kernel_class_name> may be null.
1508 Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
1509                      const KernelDef** def, std::string* kernel_class_name);
1510 
1511 // Writes a list of all registered kernels to LOG(INFO), to help users debug
1512 // missing kernel errors.
1513 void LogAllRegisteredKernels();
1514 
1515 // Gets a list of all registered kernels.
1516 KernelList GetAllRegisteredKernels();
1517 
1518 // Gets a list of all registered kernels for which predicate returns true
1519 KernelList GetFilteredRegisteredKernels(
1520     const std::function<bool(const KernelDef&)>& predicate);
1521 
1522 // Gets a list of all registered kernels for a given op
1523 KernelList GetRegisteredKernelsForOp(StringPiece op_name);
1524 
1525 namespace kernel_factory {
1526 
1527 // OpKernelFactory is responsible for creating OpKernels when TensorFlow needs
1528 // them. You register factories with the TensorFlow core by constructing an
1529 // OpKernelRegistrar and passing the factory as a constructor parameter.
1530 class OpKernelFactory {
1531  public:
1532   virtual OpKernel* Create(OpKernelConstruction* context) = 0;
1533   virtual ~OpKernelFactory() = default;
1534 };
1535 
1536 class OpKernelRegistrar {
1537  public:
1538   // Registers the given kernel factory with TensorFlow. TF will call the
1539   // factory Create() method when it determines that a kernel matching the given
1540   // KernelDef is required.
OpKernelRegistrar(const KernelDef * kernel_def,StringPiece kernel_class_name,std::unique_ptr<OpKernelFactory> factory)1541   OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
1542                     std::unique_ptr<OpKernelFactory> factory) {
1543     InitInternal(kernel_def, kernel_class_name, std::move(factory));
1544   }
1545 
1546   // Registers the given factory function with TensorFlow. This is equivalent
1547   // to registering a factory whose Create function invokes `create_fn`.
OpKernelRegistrar(const KernelDef * kernel_def,StringPiece kernel_class_name,OpKernel * (* create_fn)(OpKernelConstruction *))1548   OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
1549                     OpKernel* (*create_fn)(OpKernelConstruction*)) {
1550     InitInternal(kernel_def, kernel_class_name,
1551                  absl::make_unique<PtrOpKernelFactory>(create_fn));
1552   }
1553 
1554  private:
1555   struct PtrOpKernelFactory : public OpKernelFactory {
PtrOpKernelFactoryPtrOpKernelFactory1556     explicit PtrOpKernelFactory(OpKernel* (*create_func)(OpKernelConstruction*))
1557         : create_func_(create_func) {}
1558 
1559     OpKernel* Create(OpKernelConstruction* context) override;
1560 
1561     OpKernel* (*create_func_)(OpKernelConstruction*);
1562   };
1563 
1564   void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
1565                     std::unique_ptr<OpKernelFactory> factory);
1566 };
1567 
1568 }  // namespace kernel_factory
1569 
1570 // -----------------------------------------------------------------------------
1571 // Template and inline method implementations, please ignore
1572 
1573 template <class T>
GetAttr(StringPiece attr_name,T * value)1574 Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const {
1575   return GetNodeAttr(def(), attr_name, value);
1576 }
1577 
input_dtype(int index)1578 inline DataType OpKernelContext::input_dtype(int index) const {
1579   DCHECK_GE(index, 0);
1580   DCHECK_LT(index, num_inputs());
1581   const TensorValue& value((*params_->inputs)[index]);
1582   return value.dtype();
1583 }
1584 
input_memory_type(int index)1585 inline MemoryType OpKernelContext::input_memory_type(int index) const {
1586   DCHECK_GE(index, 0);
1587   DCHECK_LT(index, num_inputs());
1588   return op_kernel().input_memory_types()[index];
1589 }
1590 
expected_output_dtype(int index)1591 inline DataType OpKernelContext::expected_output_dtype(int index) const {
1592   DCHECK_GE(index, 0);
1593   DCHECK_LT(index, num_outputs());
1594   return params_->op_kernel->output_type(index);
1595 }
1596 
output_memory_type(int index)1597 inline MemoryType OpKernelContext::output_memory_type(int index) const {
1598   DCHECK_GE(index, 0);
1599   DCHECK_LT(index, num_outputs());
1600   return op_kernel().output_memory_types()[index];
1601 }
1602 
input_is_ref(int index)1603 inline bool OpKernelContext::input_is_ref(int index) const {
1604   const TensorValue& value((*params_->inputs)[index]);
1605   return value.is_ref();
1606 }
1607 
1608 // no input if tensor == nullptr.
has_input(int index)1609 inline bool OpKernelContext::has_input(int index) const {
1610   DCHECK_GE(index, 0);
1611   DCHECK_LT(index, num_inputs());
1612   return (*params_->inputs)[index].tensor != nullptr;
1613 }
1614 
input_ref_mutex(int index)1615 inline mutex* OpKernelContext::input_ref_mutex(int index) {
1616   DCHECK_GE(index, 0);
1617   DCHECK_LT(index, num_inputs());
1618   DCHECK(input_is_ref(index));
1619   return (*params_->inputs)[index].mutex_if_ref;
1620 }
1621 
mutable_output(int index)1622 inline Tensor* OpKernelContext::mutable_output(int index) {
1623   DCHECK_GE(index, 0);
1624   DCHECK_LT(index, num_outputs());
1625   return outputs_[index].tensor;
1626 }
1627 
release_output(int index)1628 inline TensorValue OpKernelContext::release_output(int index) {
1629   DCHECK_GE(index, 0);
1630   DCHECK_LT(index, num_outputs());
1631   TensorValue value = outputs_[index];
1632   outputs_[index] = TensorValue();
1633   return value;
1634 }
1635 
forward_input_or_allocate_output(gtl::ArraySlice<int> candidate_input_indices,int output_index,const TensorShape & output_shape,Tensor ** output,int * forwarded_input)1636 inline Status OpKernelContext::forward_input_or_allocate_output(
1637     gtl::ArraySlice<int> candidate_input_indices, int output_index,
1638     const TensorShape& output_shape, Tensor** output, int* forwarded_input) {
1639   for (int input_index : candidate_input_indices) {
1640     if (forward_input_to_output_with_shape(input_index, output_index,
1641                                            output_shape, output)) {
1642       if (forwarded_input != nullptr) {
1643         *forwarded_input = input_index;
1644       }
1645       return Status::OK();
1646     }
1647   }
1648   if (forwarded_input != nullptr) {
1649     *forwarded_input = -1;
1650   }
1651   return allocate_output(output_index, output_shape, output);
1652 }
1653 
forward_input_or_allocate_output(gtl::ArraySlice<StringPiece> candidate_input_names,StringPiece output_name,const TensorShape & output_shape,Tensor ** output)1654 inline Status OpKernelContext::forward_input_or_allocate_output(
1655     gtl::ArraySlice<StringPiece> candidate_input_names, StringPiece output_name,
1656     const TensorShape& output_shape, Tensor** output) {
1657   for (const StringPiece& input_name : candidate_input_names) {
1658     if (forward_input_to_output_with_shape(input_name, output_name,
1659                                            output_shape, output)
1660             .ok()) {
1661       return Status::OK();
1662     }
1663   }
1664   return allocate_output(output_name, output_shape, output);
1665 }
1666 
1667 template <typename T>
op_device_context()1668 T* OpKernelContext::op_device_context() {
1669   static_assert(std::is_base_of<DeviceContext, T>::value,
1670                 "T is not a subclass of DeviceContext");
1671   return static_cast<T*>(op_device_context());
1672 }
1673 
1674 inline const Tensor& OpInputList::operator[](int i) const {
1675   DCHECK_GE(i, 0);
1676   DCHECK_LT(i, stop_ - start_);
1677   return ctx_->input(start_ + i);
1678 }
1679 
ref_mutex(int i)1680 inline mutex* OpMutableInputList::ref_mutex(int i) {
1681   DCHECK_GE(i, 0);
1682   DCHECK_LT(i, stop_ - start_);
1683   return ctx_->input_ref_mutex(start_ + i);
1684 }
1685 
at(int i,bool lock_held)1686 inline Tensor OpMutableInputList::at(int i, bool lock_held) {
1687   DCHECK_GE(i, 0);
1688   DCHECK_LT(i, stop_ - start_);
1689   return ctx_->mutable_input(start_ + i, lock_held);
1690 }
1691 
1692 inline Tensor* OpOutputList::operator[](int i) {
1693   DCHECK_GE(i, 0);
1694   DCHECK_LT(i, stop_ - start_);
1695   return ctx_->mutable_output(start_ + i);
1696 }
1697 
required(int i)1698 inline bool OpOutputList::required(int i) const {
1699   DCHECK_GE(i, 0);
1700   DCHECK_LT(i, stop_ - start_);
1701   return ctx_->output_required(start_ + i);
1702 }
1703 
expected_output_dtype(int i)1704 inline DataType OpOutputList::expected_output_dtype(int i) const {
1705   DCHECK_GE(i, 0);
1706   DCHECK_LT(i, stop_ - start_);
1707   return ctx_->expected_output_dtype(start_ + i);
1708 }
1709 
allocate(int i,const TensorShape & shape,Tensor ** output)1710 inline Status OpOutputList::allocate(int i, const TensorShape& shape,
1711                                      Tensor** output) {
1712   DCHECK_GE(i, 0);
1713   DCHECK_LT(i, stop_ - start_);
1714   return ctx_->allocate_output(start_ + i, shape, output);
1715 }
1716 
set(int i,const Tensor & tensor)1717 inline void OpOutputList::set(int i, const Tensor& tensor) {
1718   DCHECK_GE(i, 0);
1719   DCHECK_LT(i, stop_ - start_);
1720   ctx_->set_output(start_ + i, tensor);
1721 }
1722 
set(int i,Tensor && tensor)1723 inline void OpOutputList::set(int i, Tensor&& tensor) {
1724   DCHECK_GE(i, 0);
1725   DCHECK_LT(i, stop_ - start_);
1726   ctx_->set_output(start_ + i, std::move(tensor));
1727 }
1728 
set_ref(int i,mutex * mu,Tensor * tensor_for_ref)1729 inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
1730   DCHECK_GE(i, 0);
1731   DCHECK_LT(i, stop_ - start_);
1732   ctx_->set_output_ref(i, mu, tensor_for_ref);
1733 }
1734 
1735 // Generate a fatal error if OP_REQUIRES or OP_REQUIRES_OK are used in
1736 // AsyncOpKernel implementations. If these macros are used and the condition
1737 // does not hold, the `done` callback will never be called and the system will
1738 // deadlock, so a crash failure is preferable. Since the OP_REQUIRES[_OK] macros
1739 // are legal to use in AsyncOpKernel constructors, we use overload resolution
1740 // to distinguish between OpKernelConstruction* and OpKernelContext* context
1741 // types.
1742 class XlaOpKernelContext;
CheckNotInComputeAsync(XlaOpKernelContext *,const char *)1743 inline void CheckNotInComputeAsync(XlaOpKernelContext*, const char*) {}
CheckNotInComputeAsync(OpKernelConstruction *,const char *)1744 inline void CheckNotInComputeAsync(OpKernelConstruction*, const char*) {}
1745 void CheckNotInComputeAsync(OpKernelContext* ctx,
1746                             const char* correct_macro_name);
1747 
1748 }  // namespace tensorflow
1749 
1750 #endif  // TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
1751