/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
#define TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_

#include <functional>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/node_properties.h"
#include "tensorflow/core/framework/op.h"  // TODO(b/62899350): Remove
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/registration/registration.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"  // TODO(b/62899350): Remove
#include "tensorflow/core/framework/tracking_allocator.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/manual_constructor.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/managed_stack_trace.h"

namespace Eigen {
struct ThreadPoolDevice;
struct GpuDevice;
}  // end namespace Eigen

namespace tensorflow {

namespace checkpoint {
class TensorSliceReaderCacheWrapper;
}  // namespace checkpoint

class AsyncOpKernel;
class CallFrameInterface;
class DeviceMgr;
class FunctionLibraryRuntime;
class OpKernelConstruction;  // declared below
class OpKernelContext;       // declared below
class OpRegistryInterface;
class ResourceMgr;
class ScopedStepContainer;
class CollectiveExecutor;
class StepStatsCollectorInterface;
class CoordinationServiceAgent;

// A label that is added to kernels that are JIT compiled. These labels will be
// removed before kernels are looked up, so they can be used without specifying
// the label. This label is a temporary measure to allow JIT kernels to be
// disabled if needed.
extern const char* kJitKernelLabel;
extern const char* kDisableJitKernelsEnvVar;

class OpKernel {
 public:
  // OpKernel won't be instantiated by the scheduler, so you may perform
  // expensive initialization in the descendant's constructor.
  explicit OpKernel(OpKernelConstruction* context);

  // Specialized constructor that allows a kernel implementation to mark itself
  // as a "deferred" op. If true, the executor will provide access to the
  // `OpKernelContext::inc_num_deferred_ops_function()` and
  // `OpKernelContext::dec_num_deferred_ops_function()` methods at run-time.
  OpKernel(OpKernelConstruction* context, bool is_deferred);

  // Specialized constructor that enables the descendant to provide a custom
  // `NodeDef` value. For example, this constructor can be used to provide a
  // stripped-down `NodeDef` that does not contain the full set of attrs (such
  // as tensor values) if the descendant stores them in a different form.
  OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
           bool is_deferred);

  virtual ~OpKernel();

  // An OpKernel's computation can be either synchronous or
  // asynchronous. All OpKernel Compute() methods must be thread-safe as they
  // may be called concurrently (e.g. by multiple executions of the same graph
  // concurrently).
  //
  // Most OpKernels should compute synchronously. They should
  // subclass OpKernel and override the Compute() method and have it
  // return after completing the supplied work.
  //
  // A synchronous OpKernel *MUST NOT* block the calling thread on a
  // synchronization mechanism (condition variable, Notification, etc.) that
  // will be unblocked by the execution of another OpKernel. Execution may
  // deadlock in that case, because the executor may use a bounded number of
  // threads.
  //
  // If an OpKernel must block on the execution of another OpKernel (e.g. a
  // RecvOp, or a DequeueOp), the implementation *MUST* subclass AsyncOpKernel,
  // and override `AsyncOpKernel::ComputeAsync()`. In addition, because the
  // unblocking kernel may never run (due to an error or cancellation), in most
  // cases the AsyncOpKernel should implement cancellation support via
  // `ctx->cancellation_manager()`.
  //
  // In both cases, implementations of Compute() and ComputeAsync()
  // get inputs and write outputs through the given OpKernelContext
  // and return a status via context->SetStatus(). They must be
  // thread-safe.

  // Synchronous compute.
  //
  // "context" is guaranteed to be alive until Compute() returns.
  virtual void Compute(OpKernelContext* context) = 0;

  // Returns nullptr iff this op kernel is synchronous.
  virtual AsyncOpKernel* AsAsync() { return nullptr; }

  // Returns true iff this op kernel is considered "expensive". The
  // runtime may use this flag to optimize graph execution, for example
  // to "inline" inexpensive kernels.
  virtual bool IsExpensive() { return expensive_; }

  // Returns a pointer to the tensor stored inside constant ops.
  virtual const Tensor* const_tensor() const { return nullptr; }

  // Accessors.
  const NodeDef& def() const { return props_->node_def; }
  const std::string& name() const { return props_->node_def.name(); }
  absl::string_view name_view() const { return name_view_; }
  const std::string& type_string() const { return props_->node_def.op(); }
  absl::string_view type_string_view() const { return type_string_view_; }
  const std::string& requested_input(int i) const {
    return props_->node_def.input(i);
  }
  const std::string& requested_device() const {
    return props_->node_def.device();
  }

  int num_inputs() const { return props_->input_types.size(); }
  DataType input_type(int i) const { return props_->input_types[i]; }
  const DataTypeVector& input_types() const { return props_->input_types; }
  const MemoryTypeVector& input_memory_types() const {
    return input_memory_types_;
  }

  int num_outputs() const { return props_->output_types.size(); }
  DataType output_type(int o) const { return props_->output_types[o]; }
  const DataTypeVector& output_types() const { return props_->output_types; }
  const MemoryTypeVector& output_memory_types() const {
    return output_memory_types_;
  }

  Status InputRange(StringPiece input_name, int* start, int* stop) const;
  Status OutputRange(StringPiece output_name, int* start, int* stop) const;

  // Returns `true` if and only if this kernel uses deferred execution.
  bool is_deferred() const { return is_deferred_; }

  // Returns a trace string for the current computation; the op name/type and
  // input tensor shape/dtype are encoded for profiler cost analysis. Most
  // OpKernels should use the default implementation.
  virtual std::string TraceString(const OpKernelContext& ctx,
                                  bool verbose) const;

 protected:
  std::string ShapeTraceString(const OpKernelContext& ctx) const;

 private:
  const std::shared_ptr<const NodeProperties> props_;
  const MemoryTypeVector input_memory_types_;
  const MemoryTypeVector output_memory_types_;
  NameRangeMap input_name_map_;
  NameRangeMap output_name_map_;
  const absl::string_view name_view_;
  const absl::string_view type_string_view_;
  const int graph_def_version_;
  const bool is_deferred_;
  bool expensive_;

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernel);
};
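
// Example (illustrative sketch only, not part of this header): a minimal
// synchronous kernel. The op name "ExampleAddOne" and its registration are
// hypothetical; a real kernel also needs a matching REGISTER_OP() definition.
//
//   class ExampleAddOneOp : public OpKernel {
//    public:
//     explicit ExampleAddOneOp(OpKernelConstruction* context)
//         : OpKernel(context) {}
//
//     void Compute(OpKernelContext* context) override {
//       const Tensor& input = context->input(0);
//       Tensor* output = nullptr;
//       OP_REQUIRES_OK(context,
//                      context->allocate_output(0, input.shape(), &output));
//       auto in = input.flat<float>();
//       auto out = output->flat<float>();
//       for (int64_t i = 0; i < in.size(); ++i) out(i) = in(i) + 1.0f;
//     }
//   };
//
//   REGISTER_KERNEL_BUILDER(Name("ExampleAddOne").Device(DEVICE_CPU),
//                           ExampleAddOneOp);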

class AsyncOpKernel : public OpKernel {
 public:
  using OpKernel::OpKernel;  // Lift OpKernel constructors.

  // Asynchronous compute.
  //
  // Implementations of ComputeAsync() must ensure that `done` is (eventually)
  // called exactly once to signal the completion of the computation. The
  // implementation of ComputeAsync() must not block on the execution of another
  // OpKernel. `done` may be called by the current thread, or by another thread.
  // `context` is guaranteed to stay alive until the `done` callback starts.
  //
  // Since it is possible that the unblocking kernel may never run (due to an
  // error or cancellation), in most cases the AsyncOpKernel should implement
  // cancellation support via `context->cancellation_manager()`.
  //
  // WARNING: As soon as the `done` callback starts, `context` and `this` may be
  // deleted. No code depending on these objects should execute after the call
  // to `done`.
  typedef std::function<void()> DoneCallback;
  virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0;

  AsyncOpKernel* AsAsync() override { return this; }

  void Compute(OpKernelContext* context) override;
};
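
// Example (illustrative sketch only): an asynchronous kernel with cancellation
// support. `ScheduleWork` is a hypothetical stand-in for whatever mechanism
// actually enqueues the work; the point is that `done` runs exactly once on
// every path, including the already-cancelled path.
//
//   class ExampleWaitOp : public AsyncOpKernel {
//    public:
//     explicit ExampleWaitOp(OpKernelConstruction* context)
//         : AsyncOpKernel(context) {}
//
//     void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
//       CancellationManager* cm = context->cancellation_manager();
//       CancellationToken token = cm->get_cancellation_token();
//       const bool already_cancelled = !cm->RegisterCallback(token, []() {
//         // Unblock any pending wait so the completion callback can run.
//       });
//       if (already_cancelled) {
//         context->SetStatus(errors::Cancelled("ExampleWait was cancelled"));
//         done();
//         return;
//       }
//       // Schedule the real work; call `done` from the completion callback,
//       // after deregistering the cancellation callback.
//       ScheduleWork([cm, token, done]() {
//         cm->TryDeregisterCallback(token);
//         done();
//       });
//     }
//   };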

class OpKernelConstruction {
 public:
  OpKernelConstruction(DeviceType device_type, DeviceBase* device,
                       Allocator* allocator, FunctionLibraryRuntime* flib,
                       ResourceMgr* resource_mgr,
                       const std::shared_ptr<const NodeProperties>& props,
                       const MemoryTypeSlice& input_memory_types,
                       const MemoryTypeSlice& output_memory_types,
                       int graph_def_version, Status* status);

  Env* env() const { return device_->env(); }

  // Allocation of tensors during kernel construction:
  //
  // It is legal to temporarily allocate scratch tensor storage during
  // Op kernel construction. Scratch tensors should be allocated using
  // allocate_temp below. Some kernels need to keep tensors in between
  // invocations. If such a Tensor is allocated during kernel
  // construction this also must be done using allocate_temp, and the
  // Op may only store the returned Tensor object.

  // Allocates a temporary Tensor of the specified type and shape. The
  // Tensor must not be used after kernel construction is
  // complete. See comment above.
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp);
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp, AllocatorAttributes allocator_attr);

  // User-supplied configuration of this operation.
  const NodeDef& def() const { return props_->node_def; }

  // For inspecting the inputs to this operation.
  int num_inputs() const { return props_->input_types.size(); }
  DataType input_type(int i) const { return props_->input_types[i]; }
  const DataTypeSlice& input_types() const { return props_->input_types_slice; }
  const MemoryTypeSlice& input_memory_types() const {
    return input_memory_types_;
  }

  // For inspecting the outputs expected from this operation.
  int num_outputs() const { return props_->output_types.size(); }
  DataType output_type(int i) const { return props_->output_types[i]; }
  const DataTypeSlice& output_types() const {
    return props_->output_types_slice;
  }
  const MemoryTypeSlice& output_memory_types() const {
    return output_memory_types_;
  }

  // If expected_inputs == inputs() and expected_outputs == output_types(),
  // returns OK, else returns INVALID_ARGUMENT with an error message.
  // Recommended for Ops with dynamic signatures.
  Status MatchSignature(const DataTypeSlice expected_inputs,
                        const DataTypeSlice expected_outputs);

  // For recording configuration errors during construction.
  void SetStatus(const Status& status);
  const Status& status() const { return *status_; }

  // Look up the attr with name attr_name and set *value to its value.  If no
  // attr with attr_name is found in def(), or the attr does not have
  // a matching type, a non-ok status will be returned.
  template <class T>
  Status GetAttr(StringPiece attr_name, T* value) const TF_ATTRIBUTE_NOINLINE;

  // Return true if the attr_name is defined in def().
  bool HasAttr(StringPiece attr_name) const;

  // Return the device type.
  const DeviceType& device_type() const { return device_type_; }

  // If not nullptr, the kernel can instantiate functions defined in
  // the library. E.g.,
  // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...).
  FunctionLibraryRuntime* function_library() const { return flib_; }

  // Shared resources accessible to this kernel.
  ResourceMgr* resource_manager() const { return resource_mgr_; }

  // The GraphDef version whose behavior we should follow.
  int graph_def_version() const { return graph_def_version_; }

  // Helper routines for the OP_REQUIRES macros
  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  // Unrecommended functions: these are functions that have some
  // current uses but are not recommended for use, and may go away at
  // some future major version release.

  // May be used, e.g., to get GPU handles, etc.
  //
  // Currently only used to call MakeTensorFromProto() for
  // implementing ConstantOp for every device.  See comments
  // on Device::MakeTensorFromProto for longer-term replacement
  // ideas.
  DeviceBase* device() const { return device_; }

 private:
  const DeviceType device_type_;
  DeviceBase* const device_;
  Allocator* allocator_;
  FunctionLibraryRuntime* flib_;
  ResourceMgr* const resource_mgr_;
  std::shared_ptr<const NodeProperties> props_;
  MemoryTypeSlice input_memory_types_;
  MemoryTypeSlice output_memory_types_;
  const int graph_def_version_;
  Status* status_;

  // Allow access from OpKernel ctor.
  friend class OpKernel;

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction);
};
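
// Example (illustrative sketch only): reading and validating an attr in a
// kernel constructor via GetAttr(). The attr name "axis" and the member
// `axis_` are hypothetical.
//
//   explicit ExampleOp(OpKernelConstruction* context) : OpKernel(context) {
//     OP_REQUIRES_OK(context, context->GetAttr("axis", &axis_));
//     OP_REQUIRES(context, axis_ >= 0,
//                 errors::InvalidArgument("axis must be non-negative, got ",
//                                         axis_));
//   }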

// TODO(mrry): Consider converting to a random_access_iterator, and upgrading
// tensorflow::gtl::iterator_range to make the below container classes
// unnecessary.
template <typename ListType, typename ElementType>
class OpArgIterator {
 public:
  using iterator_category = std::forward_iterator_tag;
  using value_type = ElementType;
  using pointer = ElementType*;
  using const_pointer = const ElementType*;
  using reference = ElementType&;
  using const_reference = const ElementType&;
  using difference_type = ptrdiff_t;

  OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {}

  bool operator==(const OpArgIterator& rhs) {
    DCHECK(list_ == rhs.list_);
    return i_ == rhs.i_;
  }

  bool operator!=(const OpArgIterator& rhs) {
    DCHECK(list_ == rhs.list_);
    return i_ != rhs.i_;
  }

  OpArgIterator operator++() {  // prefix ++it
    ++i_;
    return *this;
  }

  OpArgIterator operator++(int) {  // postfix it++
    OpArgIterator old_value = *this;
    ++i_;
    return old_value;
  }

  reference operator*() { return (*list_)[i_]; }
  pointer operator->() { return &(*list_)[i_]; }

  const_reference operator*() const { return (*list_)[i_]; }
  const_pointer operator->() const { return &(*list_)[i_]; }

 private:
  const ListType* const list_;
  int i_;
};

// Utility class for representing a list of immutable input tensors
// that are passed to the op as a single named argument.
class OpInputList {
 public:
  typedef OpArgIterator<OpInputList, const Tensor> Iterator;
  OpInputList() : ctx_(nullptr), start_(0), stop_(0) {}
  OpInputList(OpKernelContext* ctx, int start, int stop)
      : ctx_(ctx), start_(start), stop_(stop) {}
  OpInputList& operator=(const OpInputList& other) = default;
  const Tensor& operator[](int i) const;
  int size() const { return stop_ - start_; }
  Iterator begin() const { return Iterator(this, 0); }
  Iterator end() const { return Iterator(this, size()); }

 private:
  OpKernelContext* ctx_;  // not owned
  int start_;
  int stop_;
};
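
// Example (illustrative sketch only): reading a list-valued input inside
// Compute(). The input name "values" is hypothetical and must match an OpDef
// argument declared as a list.
//
//   OpInputList values;
//   OP_REQUIRES_OK(context, context->input_list("values", &values));
//   for (int i = 0; i < values.size(); ++i) {
//     const Tensor& t = values[i];
//     // ... use t ...
//   }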

// Utility class for representing a list of mutable ("ref") input tensors
// that are passed to the op as a single named argument.
class OpMutableInputList {
 public:
  typedef OpArgIterator<OpMutableInputList, Tensor*> Iterator;
  OpMutableInputList(OpKernelContext* ctx, int start, int stop)
      : ctx_(ctx), start_(start), stop_(stop) {}
  OpMutableInputList() : ctx_(nullptr), start_(0), stop_(0) {}
  OpMutableInputList& operator=(const OpMutableInputList& other) = default;
  Tensor at(int i, bool lock_held);
  mutex* ref_mutex(int i);
  int size() const { return stop_ - start_; }
  Iterator begin() const { return Iterator(this, 0); }
  Iterator end() const { return Iterator(this, size()); }

 private:
  OpKernelContext* ctx_;  // not owned
  int start_;
  int stop_;
};

// Utility class for representing a list of output tensors that are
// grouped as a single named output.
class OpOutputList {
 public:
  typedef OpArgIterator<OpOutputList, const Tensor*> Iterator;
  OpOutputList() : ctx_(nullptr), start_(0), stop_(0) {}
  OpOutputList(OpKernelContext* ctx, int start, int stop)
      : ctx_(ctx), start_(start), stop_(stop) {}
  OpOutputList& operator=(const OpOutputList& other) = default;
  Tensor* operator[](int i);
  bool required(int i) const;
  DataType expected_output_dtype(int i) const;
  Status allocate(int i, const TensorShape& shape, Tensor** output);
  void set(int i, const Tensor& tensor);
  void set(int i, Tensor&& tensor);
  void set_ref(int i, mutex* mu, Tensor* tensor_for_ref);
  int size() const { return stop_ - start_; }
  Iterator begin() const { return Iterator(this, 0); }
  Iterator end() const { return Iterator(this, size()); }

 private:
  OpKernelContext* ctx_;  // not owned
  int start_;
  int stop_;
};
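
// Example (illustrative sketch only): allocating each element of a list-valued
// output inside Compute(). The output name "outputs" and the shapes involved
// are hypothetical.
//
//   OpOutputList outputs;
//   OP_REQUIRES_OK(context, context->output_list("outputs", &outputs));
//   for (int i = 0; i < outputs.size(); ++i) {
//     Tensor* out = nullptr;
//     OP_REQUIRES_OK(context, outputs.allocate(i, TensorShape({}), &out));
//     // ... fill *out ...
//   }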

// Holds a tensor or tensor reference. For tensor references, we need
// a mutex to prevent concurrent access to the tensor.
struct TensorValue {
  TensorValue() : mutex_if_ref(nullptr), tensor(nullptr) {}
  explicit TensorValue(Tensor* t) : mutex_if_ref(nullptr), tensor(t) {}
  TensorValue(mutex* mu, Tensor* t) : mutex_if_ref(mu), tensor(t) {}
  Tensor* operator->() const { return tensor; }
  bool is_ref() const { return mutex_if_ref != nullptr; }

  // Return the dtype of the Tensor. For references, return the underlying type.
  DataType dtype() const {
    if (is_ref()) {
      return MakeRefType(tensor->dtype());
    } else {
      return tensor->dtype();
    }
  }

  // Return the dtype of the Tensor. For references, return the underlying type.
  // This variation on dtype() acquires the lock for references.
  //
  // TODO(b/133843385): Disallow dtype modifications
  DataType dtype_safe() const {
    if (is_ref()) {
      tf_shared_lock ml(*mutex_if_ref);
      return MakeRefType(tensor->dtype());
    } else {
      return tensor->dtype();
    }
  }

  mutex* mutex_if_ref;  // nullptr if not a ref, != nullptr if a ref
  Tensor* tensor;
};

// Used to store partitioned graphs from function-calling ops.
struct GraphCollector {
  mutex mu;
  std::vector<GraphDef> partitioned_graphs TF_GUARDED_BY(mu);
  GraphDef raw_graph TF_GUARDED_BY(mu);
  GraphDef optimized_graph TF_GUARDED_BY(mu);

  bool dirty TF_GUARDED_BY(mu);

  GraphCollector() : dirty(false) {}

  void CollectRawGraph(const GraphDef& graph) {
    mutex_lock ml(mu);
    raw_graph.MergeFrom(graph);
    dirty = true;
  }

  void CollectOptimizedGraph(const GraphDef& graph) {
    mutex_lock ml(mu);
    optimized_graph.MergeFrom(graph);
    dirty = true;
  }

  void CollectPartitionedGraph(const GraphDef& graph) {
    mutex_lock ml(mu);
    partitioned_graphs.push_back(graph);
    dirty = true;
  }

  void ClearGraphs() TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
    raw_graph.Clear();
    optimized_graph.Clear();
    partitioned_graphs.clear();
    dirty = false;
  }

  bool HasUpdatedGraphs() {
    mutex_lock ml(mu);
    return dirty;
  }
};

class OpKernelContext {
 public:
  // The first element of a WrappedAllocator is a "base" Allocator and
  // the second element is that Allocator wrapped by a
  // TrackingAllocator
  typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator;

  // TODO(zhifengc): Do some cleanup of Params.
  // The Params struct is passed in to initialize an OpKernelContext,
  // and must outlive the OpKernelContext.
  struct Params {
    ~Params() { delete eigen_gpu_device; }

    // The step being executed.
    int64_t step_id = 0;

    // Timestamp for the start of graph execution. Used for latency metrics.
    int64_t start_time_usecs = 0;

    // The deadline for the session to complete by. Empty if unspecified.
    absl::optional<absl::Time> deadline;

    // The op kernel being computed.
    OpKernel* op_kernel = nullptr;

    // The device on which the kernel is running.
    DeviceBase* device = nullptr;

    // The Eigen GPU device wrapper, which may include a per-op
    // wrapped allocator. The concrete type of this object depends on
    // the type of this->device, so eigen_gpu_device can't be an
    // inline member and must be heap allocated. However, we don't
    // want to allocate a new eigen_gpu_device for every Op that is
    // executed. Instead this member is allocated on first use using
    // ensure_eigen_gpu_device, and then if the Params structure is
    // re-used for subsequent Ops, the eigen_gpu_device is
    // ReInitialized in the OpKernelContext constructor. Unlike the
    // other pointers in Params, this one is owned by Params.
    PerOpGpuDevice* eigen_gpu_device = nullptr;

    inline void ensure_eigen_gpu_device() {
      DCHECK(device);
      if (nullptr == eigen_gpu_device) {
        // Surprisingly, MakeGpuDevice will return nullptr if the
        // device is not a GPU device. This is ok, since those devices
        // will never use eigen_gpu_device. It seems better to have
        // ensure_eigen_gpu_device fall through and regenerate the
        // nullptr every time an OpKernelContext is instantiated, than
        // to do an unnecessary allocation of a dummy eigen GPU
        // device for CPU device Ops.
        eigen_gpu_device = device->MakeGpuDevice();
      }
    }

    bool track_allocations = false;
    bool log_memory = false;

    // Array indexed by output number for this node
    const AllocatorAttributes* output_attr_array = nullptr;

    // Shared resources accessible by this op kernel invocation.
    ResourceMgr* resource_manager = nullptr;

    // Per-step resources accessible by this op kernel invocation should be
    // stored in this container.
    ScopedStepContainer* step_container = nullptr;

    // Mechanism used by this op kernel invocation to communicate with
    // computations running on other devices.
    RendezvousInterface* rendezvous = nullptr;

    // Mechanism for executing a collective op that needs to coordinate
    // with parallel instances running on other devices.
    CollectiveExecutor* collective_executor = nullptr;

    // The session state for this op.
    SessionState* session_state = nullptr;

    // Unique session identifier. Can be empty.
    std::string session_handle;

    // Metadata about the session. Can be nullptr.
    const SessionMetadata* session_metadata = nullptr;

    // The tensor store for this op.
    TensorStore* tensor_store = nullptr;

    // Mechanism used by this op kernel invocation to register a callback
    // for its cancellation.
    CancellationManager* cancellation_manager = nullptr;

    // Inputs to this op kernel.
    absl::Span<const TensorValue> inputs;
    bool is_input_dead = false;

    absl::Span<const AllocatorAttributes> input_alloc_attrs;

    // Device context.
    DeviceContext* op_device_context = nullptr;

    // Control-flow op supports.
    FrameAndIter frame_iter;

    // Function call supports.
    CallFrameInterface* call_frame = nullptr;
    FunctionLibraryRuntime* function_library = nullptr;
    std::function<void(std::function<void()>)>* runner = nullptr;
    StepStatsCollectorInterface* stats_collector = nullptr;
    GraphCollector* graph_collector = nullptr;
    bool run_all_kernels_inline = false;
    const std::string* executor_type = nullptr;

    // TensorSliceReaderCache support.
    checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;

    // Support for forwarding reservations (used by ScopedAllocator).
    static constexpr int kNeverForward = -2;
    static constexpr int kNoReservation = -1;
    // Values in [0,...) represent reservations for the indexed output.
    const int* forward_from_array = nullptr;

    // For tracking actively running deferred ops.
    std::function<void()> inc_num_deferred_ops_function;
    std::function<void()> dec_num_deferred_ops_function;

    absl::optional<ManagedStackTrace> stack_trace = {};

    // For implementing `OpKernelContext::output_required()`. If null, all
    // outputs are required.
    bool* outputs_required_array = nullptr;

    // For access to distributed coordination service.
    CoordinationServiceAgent* coordination_service_agent = nullptr;
  };

  // params must outlive the OpKernelContext.
  explicit OpKernelContext(Params* params);
  OpKernelContext(Params* params, int num_outputs);
  ~OpKernelContext();

  Env* env() const { return params_->device->env(); }

  int64_t step_id() const { return params_->step_id; }

  int64_t start_time_usecs() const { return params_->start_time_usecs; }

  // The deadline for the session to complete by. Empty if unspecified in
  // RunOptions.
  absl::optional<absl::Time> deadline() const { return params_->deadline; }

  const OpKernel& op_kernel() const { return *params_->op_kernel; }

  // Stack trace of where the op was defined (if defined in eager mode).
  const absl::optional<ManagedStackTrace>& stack_trace() const {
    return params_->stack_trace;
  }

  // Input/output signature.

  int num_inputs() const { return params_->inputs.size(); }
  DataType input_dtype(int index) const;
  Status input_dtype(StringPiece name, DataType* dtype) const;
  MemoryType input_memory_type(int index) const;

  int num_outputs() const { return outputs_.size(); }
  DataType expected_output_dtype(int index) const;
  MemoryType output_memory_type(int index) const;

  // Input

  // Returns an immutable input tensor. May only be used for non-Ref
  // inputs. For Ref inputs use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  // TODO(mrry): Convert this to return Status.
  const Tensor& input(int index) const;

  // Returns the named immutable input tensor in "tensor", as defined
  // in the OpDef. May only be used for non-Ref inputs. For Ref inputs
  // use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  // REQUIRES: the named input must not be a list.
  Status input(StringPiece name, const Tensor** tensor);

  // Returns the named list-valued immutable input in "list", as
  // defined in the OpDef.  If the named output is not list-valued,
  // returns a one-element list. May only be used for non-Ref
  // inputs. For Ref inputs use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  Status input_list(StringPiece name, OpInputList* list);

  // For mutable inputs, use the following together to make sure there
  // is no concurrent access to mutable_input(), e.g.:
  // {
  //   Tensor& t = context->mutable_input(index);
  //   mutex_lock lock(*context->input_ref_mutex(index));
  //   // modify the values in t
  // }
  // REQUIRES: IsRefType(input_dtype(index))
  Status input_ref_mutex(StringPiece name, mutex** out_mutex);

  // Returns a mutable input tensor. Must be used to access Ref
  // inputs.  REQUIRES: IsRefType(input_dtype(index)). The caller may
  // modify the values stored in the Tensor buffer, and modifications
  // will be visible to other Ops reading the same ref tensor. If
  // !lock_held the input mutex will be acquired before returning the
  // Tensor.
  // TODO(mrry): Convert this to return Status.
  Tensor mutable_input(int index, bool lock_held);

  // Returns the named mutable input tensor in "tensor", as defined in
  // the OpDef. Must be used to access Ref inputs. The values stored
  // in the Tensor buffer may be modified, and modifications will be
  // visible to other Ops reading the same ref tensor. If !lock_held
  // the input mutex will be acquired before returning the Tensor.
  // REQUIRES: the named input must not be a list.
  // REQUIRES: the named input must be a ref tensor.
  Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held);

  // Returns the named list-valued mutable input in "list", as defined
  // in the OpDef.  If the named input is not list-valued, returns a
  // one-element list. Must be used to access Ref inputs. The values
  // stored in the Tensor buffer may be modified, and modifications
  // will be visible to other Ops reading the same ref tensor.
  // REQUIRES: the named input must be a ref tensor.
  Status mutable_input_list(StringPiece name, OpMutableInputList* list);

  // Replace the corresponding Ref Input to use the storage buffer
  // used by tensor. If !lock_held the input mutex will be acquired
  // before returning the Tensor.
  // REQUIRES: IsRefType(input_dtype(index)).
  void replace_ref_input(int index, const Tensor& tensor, bool lock_held);

  // Replace the corresponding named Ref Input to use the storage
  // buffer used by tensor. If !lock_held the input mutex will be
  // acquired before returning the Tensor.
  // REQUIRES: IsRefType(input_dtype(index)).
  Status replace_ref_input(StringPiece name, const Tensor& tensor,
                           bool lock_held);

  // Deletes the Tensor object used as the Ref Input at
  // input_index. This is not usually necessary and should be used
  // with caution. If !lock_held the input mutex will be acquired
  // before returning the Tensor.
  // REQUIRES: IsRefType(input_dtype(input_index)).
  void delete_ref_input(int input_index, bool lock_held);

  // Return true if there is input at the given index. An operator has no
  // input at index if its tensor is null. This is primarily used by the
  // merge operator.
  // TODO(mrry): Convert this to return Status.
  bool has_input(int index) const;

  // Returns true if all inputs are the same shape, otherwise sets the
  // status to a non-OK value and returns false.
  // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
  bool ValidateInputsAreSameShape(OpKernel* op);

  // If non-null, kernels should populate with any partition subgraphs created.
  GraphCollector* graph_collector() { return params_->graph_collector; }

  // If true, hint that all kernels in functions called by this kernel should
  // be treated as "inexpensive", and hence executed on the scheduling thread.
  bool run_all_kernels_inline() const {
    return params_->run_all_kernels_inline;
  }

  // Returns the registered name for the executor type that is executing the
  // current kernel. If empty, the default executor is used.
  const std::string& executor_type() const;

  // Input to output forwarding.

  // Set the output Ref Tensor at output_index to be an alias of the
  // input Ref Tensor at input_index.
  // REQUIRES: IsRefType(input_dtype(input_index)).
  // REQUIRES: IsRefType(output_dtype(output_index)).
  void forward_ref_input_to_ref_output(int input_index, int output_index);

  // Returns true when an alias to input[input_index] (reshaped to
  // output_shape, and safe to use for in-place computation) was written to
  // *output. Returns false if input[input_index] has a refcount greater than
  // one, if its type does not match the expected output type of
  // output[output_index], or if the number of elements in input[input_index]
  // does not equal the number of elements in output_shape.
  bool forward_input_to_output_with_shape(int input_index, int output_index,
                                          const TensorShape& output_shape,
                                          Tensor** output) TF_MUST_USE_RESULT;
  Status forward_input_to_output_with_shape(StringPiece input_name,
                                            StringPiece output_name,
                                            const TensorShape& output_shape,
                                            Tensor** output) TF_MUST_USE_RESULT;

  // Returns a pointer to a Tensor aliasing the underlying buffer backing
  // input[input_index] iff
  //   * input[input_index] is not a ref,
  //   * the data type, shape, memory type, and allocator attributes of
  //     input[input_index] are compatible with those given in dtype, shape,
  //     memory_type, and attr,
  //   * refcount on the underlying buffer is one.
  //   * Either there is no forwarding reservation for either input_index
  //     or output_index or the specified input is reserved for the specified
  //     output. More precisely:
  //
  //     These cases mean neither input nor output has a reservation:
  //        forward_from_array = nullptr
  //     OR (input_index is not in forward_from_array AND
  //         (output_index == kNoReservation OR
  //          forward_from_array[output_index] == kNoReservation))
  //
  //     This case means that input_index is reserved for output_index:
  //        forward_from_array[output_index] == input_index
  //
  //     This case means the output is reserved to always be allocated,
  //     never assigned a forwarded input:
  //        forward_from_array[output_index] == kNeverForward
  //
  // Otherwise returns nullptr.
  // NOTE: For Cuda kernels that read inputs using the __ldg() intrinsic,
  // forwarding is only safe if there are no reads via __ldg() after writes
  // to the same address.
  std::unique_ptr<Tensor> forward_input(
      int input_index, int output_index, DataType output_dtype,
      const TensorShape& output_shape, MemoryType output_memory_type,
      const AllocatorAttributes& output_attr) TF_MUST_USE_RESULT;

  // Tries to forward one of the inputs given in input_indices to
  // output[output_index]. If none of the given inputs can be forwarded, calls
  // allocate_output() to allocate a new output buffer. The index of the
  // forwarded input will be assigned to the output argument forwarded_input
  // (if it is not nullptr). If no inputs are forwarded, forwarded_input will
  // be assigned -1.
  Status forward_input_or_allocate_output(
      gtl::ArraySlice<int> candidate_input_indices, int output_index,
      const TensorShape& output_shape, Tensor** output,
      int* forwarded_input = nullptr) TF_MUST_USE_RESULT;
  Status forward_input_or_allocate_output(
      gtl::ArraySlice<StringPiece> candidate_input_names,
      StringPiece output_name, const TensorShape& output_shape,
      Tensor** output) TF_MUST_USE_RESULT;

  // Tries to reuse one of the inputs given in input_indices as a temporary.
  // If none of the given inputs can be forwarded, calls
  // allocate_temp() to allocate a new temporary buffer.
  Status forward_input_or_allocate_temp(
      gtl::ArraySlice<int> candidate_input_indices, DataType type,
      const TensorShape& shape, const AllocatorAttributes& allocator_attr,
      Tensor* out_temp) TF_MUST_USE_RESULT;

  Status forward_input_or_allocate_temp(
      gtl::ArraySlice<int> candidate_input_indices, DataType type,
      const TensorShape& shape, Tensor* out_temp) TF_MUST_USE_RESULT {
    return forward_input_or_allocate_temp(candidate_input_indices, type, shape,
                                          AllocatorAttributes(), out_temp);
  }
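
  // Example (illustrative sketch only): an element-wise kernel that reuses its
  // input buffer for the output when it is safe to do so, and otherwise
  // allocates a fresh output. The input/output indices are hypothetical.
  //
  //   Tensor* output = nullptr;
  //   OP_REQUIRES_OK(context,
  //                  context->forward_input_or_allocate_output(
  //                      {0}, 0, context->input(0).shape(), &output));
  //   // ... write results into *output ...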

  // Output

  // Returns the named list-valued output in "list", as defined in the OpDef.
  // If the named output is not list-valued, returns a one-element list.
  Status output_list(StringPiece name, OpOutputList* list);

  // If output_required(index) returns true, the OpKernel's Compute() method
  // should call allocate_output(index, ...), set_output(index, ...),
  // set_output_ref(index, ...), or set the status to a non-ok value.
  // If it returns false, it may output, but is not required to do so.
  bool output_required(int index) const {
    return !params_->outputs_required_array ||
           params_->outputs_required_array[index];
  }
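
  // Example (illustrative sketch only): skipping work for an output that no
  // consumer requires. The output index and shape are hypothetical.
  //
  //   if (context->output_required(1)) {
  //     Tensor* summary = nullptr;
  //     OP_REQUIRES_OK(context,
  //                    context->allocate_output(1, TensorShape({}), &summary));
  //     // ... fill *summary ...
  //   }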

  // If output_expects_forwarding returns true, the OpKernel's Compute() method
  // should not allocate the output with allocate_output but instead needs to
  // use forward_input.
  bool output_expects_forwarding(int index) const {
    return params_->forward_from_array != nullptr &&
           params_->forward_from_array[index] >= 0;
  }

  // Allocation of tensors during kernel execution inside the Compute
  // method:
  //
  // There are two methods to allocate Tensors when an Op kernel
  // executes.
  //
  // 1) allocate_output. This should be used to allocate any tensor
  // that is going to be used as an output from the Op at the end of
  // the current execution. The caller indicates which output the
  // Tensor will be assigned to, and the call returns the
  // newly-allocated Tensor. The Tensor can subsequently be assigned
  // to during kernel execution, and will be used as the designated
  // output when the kernel execution completes.
  //
  // 2) allocate_temp. This should be used to allocate any scratch
  // storage that is needed while the kernel is executing, and will
  // not be retained by the Op.
  //
  // In some cases a Tensor needs to be used as an output even though
  // it was previously allocated elsewhere. The Tensor may have been
  // passed as an input, or stored in a Tensor during a
  // previous kernel execution, or allocated earlier in the kernel
  // execution at a time when it was not known which output it would
  // be assigned to. In this case the kernel can use set_output or
  // set_output_ref to indicate that the tensor should be used as the
  // designated output. It is legal to use any previously-allocated
  // Tensor as an argument to set_output or set_output_ref, including
  // Tensors allocated via allocate_temp. There may be a performance
  // penalty to using a Tensor that was not allocated using
  // allocate_output. This is because allocate_output uses the
  // AllocatorAttributes stored in output_attr_array for the
  // designated output. In some cases, using the wrong attributes may
  // cause an extra copy of the Tensor's buffer.

  // Allocates output for the specified output index with shape.
  // OpKernelContext retains ownership of the returned pointer. See
  // comment above.
  //
  // If memory allocation fails, returns an error status.
  //
  // REQUIRES: !IsRefType(expected_output_dtype(index))
  Status allocate_output(int index, const TensorShape& shape,
                         Tensor** tensor) TF_MUST_USE_RESULT;
  Status allocate_output(StringPiece name, const TensorShape& shape,
                         Tensor** tensor) TF_MUST_USE_RESULT;
  // The following methods use the supplied attributes instead of
  // those in output_attr_array. The caller is responsible for
  // ensuring that the attributes are "compatible" with the
  // output_attr_array, e.g. the tensor is allocated on the correct
  // device. See comment above.
  Status allocate_output(int index, const TensorShape& shape, Tensor** tensor,
                         AllocatorAttributes attr) TF_MUST_USE_RESULT;
  Status allocate_output(StringPiece name, const TensorShape& shape,
                         Tensor** tensor,
                         AllocatorAttributes attr) TF_MUST_USE_RESULT;

  // Allocates a temporary Tensor of the specified type and
  // shape. Devices such as GPUs that enqueue Ops for lazy execution
  // may retain references to the temporary tensors after the Op's
  // Compute method has run. See comment above.
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp, AllocatorAttributes allocator_attr,
                       const AllocationAttributes& allocation_attr);
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp, AllocatorAttributes allocator_attr);
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp);
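
  // Example (illustrative sketch only): allocating one output and one scratch
  // tensor inside Compute(). The shapes and dtype are hypothetical.
  //
  //   Tensor* output = nullptr;
  //   OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({2, 2}),
  //                                                    &output));
  //   Tensor scratch;
  //   OP_REQUIRES_OK(context, context->allocate_temp(DT_FLOAT, TensorShape({2}),
  //                                                  &scratch));
  //   // ... use scratch while computing, write results into *output ...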

  // Copies a tensor (allocated by the caller) to the specified output
  // index.  REQUIRES: !IsRefType(expected_output_dtype(index))
  // REQUIRES: 'tensor' must have the same MemoryType as
  // output_memory_types[index]. See comment above.
  Status set_output(StringPiece name, const Tensor& tensor);
  Status set_output(StringPiece name, Tensor&& tensor);
  void set_output(int index, const Tensor& tensor);
  void set_output(int index, Tensor&& tensor);

  // To output a reference.  Caller retains ownership of mu and tensor_for_ref,
  // and they must outlive all uses within the step. See comment above.
  // REQUIRES: IsRefType(expected_output_dtype(index))
  Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref);

  // Returns nullptr if allocate_output() or set_output() have not been called.
  Status mutable_output(StringPiece name, Tensor** tensor);

  // Return the DeviceContext that should be used for this Op.
  //
  // If using the templated function, the type must be a subclass
  // of DeviceContext.
  //
  // Returns nullptr if the device did not provide one.
  template <typename T>
  T* op_device_context();
  DeviceContext* op_device_context() {
    DeviceContext* ret = params_->op_device_context;
    if (ret == nullptr) {
      auto* dev_info = device()->tensorflow_accelerator_device_info();
      if (dev_info) ret = dev_info->default_context;
    }
    return ret;
  }

  AllocatorAttributes input_alloc_attr(int index) const {
    if (params_->input_alloc_attrs.empty()) {
      return AllocatorAttributes();
    } else {
      DCHECK_GE(index, 0);
      DCHECK_LT(index, params_->input_alloc_attrs.size());
      return params_->input_alloc_attrs[index];
    }
  }

  AllocatorAttributes output_alloc_attr(int index) const {
    return params_->output_attr_array[index];
  }

  gtl::InlinedVector<WrappedAllocator, 4> ConsumeWrappedAllocators() {
    gtl::InlinedVector<WrappedAllocator, 4> retrieved;
    if (tracking_state_) {
      mutex_lock lock(tracking_state_->mu);
      retrieved.swap(tracking_state_->wrapped_allocators);
    }
    return retrieved;
  }

  // Communication.
  //
  // An op kernel communicates with the outside environment through
  // Rendezvous Send() and Recv().
  RendezvousInterface* rendezvous() const { return params_->rendezvous; }

  CollectiveExecutor* collective_executor() const {
    return params_->collective_executor;
  }

  // An op kernel can access the session state it belongs to.
  SessionState* session_state() const { return params_->session_state; }

  // Unique identifier of the session it belongs to. Can be empty.
  std::string session_handle() const { return params_->session_handle; }

  // Metadata about the session. Can be nullptr.
  const SessionMetadata* session_metadata() const {
    return params_->session_metadata;
  }

  // An op kernel can access the tensor store of the run it belongs to.
  TensorStore* tensor_store() const { return params_->tensor_store; }

  // Function call support.
  //
  // If this kernel invocation is within a function execution,
  // call_frame() returns the call frame for the function call.
  CallFrameInterface* call_frame() const { return params_->call_frame; }

  // If not nullptr, the kernel can invoke functions defined in the
  // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...).
  FunctionLibraryRuntime* function_library() const {
    return params_->function_library;
  }

  std::function<void(std::function<void()>)>* runner() const {
    return params_->runner;
  }
  StepStatsCollectorInterface* stats_collector() const {
    return params_->stats_collector;
  }

  // Shared resources accessible to this kernel.
  ResourceMgr* resource_manager() const { return params_->resource_manager; }

  checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const {
    return params_->slice_reader_cache;
  }

  // Execution.
  //
  // OpKernels can use these eigen devices to carry out their
  // numerical computation.
  const Eigen::ThreadPoolDevice& eigen_cpu_device() const {
    return *device()->eigen_cpu_device();
  }
  const Eigen::GpuDevice& eigen_gpu_device() const {
    return params_->eigen_gpu_device->device();
  }
  template <typename EigenDeviceType>
  const EigenDeviceType& eigen_device() const;

  // Error handling.

  // If expected_inputs == inputs() and expected_outputs == output_types(),
  // returns OK, else returns INVALID_ARGUMENT with an error message.
  // Recommended for Ops with dynamic signatures, where validation can only
  // be performed at runtime.
  Status MatchSignature(const DataTypeSlice expected_inputs,
                        const DataTypeSlice expected_outputs);

  // An OpKernel should call SetStatus() if Compute() encounters an
  // error.
  void SetStatus(const Status& status);
  const Status& status() const { return status_; }

  // Cancellation.
  //
  // EXPERIMENTAL. See the implementation in tensorflow::FIFOQueue for an
  // example of how to use this API.
  CancellationManager* cancellation_manager() const {
    return params_->cancellation_manager;
  }

  // Other accessors.

  // For control flow.
  FrameAndIter frame_iter() const { return params_->frame_iter; }
  bool is_input_dead() const { return params_->is_input_dead; }

  // May be used, e.g., to get GPU handles, etc.
  // TODO(tucker): Add example usage.
  DeviceBase* device() const { return params_->device; }

  // Per-step container for use by white-listed internal ops.
  ScopedStepContainer* step_container() const {
    return params_->step_container;
  }

  // Access to distributed coordination service.
  CoordinationServiceAgent* coordination_service_agent() const {
    return params_->coordination_service_agent;
  }

  // Helper routines for the OP_REQUIRES macros
  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);
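
  // Example (illustrative sketch only): the OP_REQUIRES / OP_REQUIRES_OK
  // macros (see op_requires.h) report failures through the CtxFailure helpers
  // above; kernels should prefer the macros to calling SetStatus() directly.
  // The shape check below is hypothetical.
  //
  //   OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
  //               errors::InvalidArgument("input must be a vector, got shape ",
  //                                       input.shape().DebugString()));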
1165 
1166   // Unrecommended functions: these are functions that have some
1167   // current uses but are not recommended for use, and may go away at
1168   // some future major version release.
1169   //
1170   // The following functions all have versions that return Status
1171   // to capture error conditions, and are strongly preferred.
1172   Tensor* mutable_output(int index);
1173   mutex* input_ref_mutex(int index);
1174   void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref);
1175   TensorValue release_output(int index);
1176 
track_allocations()1177   bool track_allocations() const { return params_->track_allocations; }
1178 
1179   // Records temp memory allocation. Tensor object is recorded to identify the
1180   // case where temp memory is used as output memory.
1181   void record_temp_memory_allocation(int64_t size, const Tensor& t)
1182       TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
1183 
1184   // Returns recorded size of temporary memory;
1185   int64_t temp_memory_allocated() const
1186       TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
1187 
1188   // Records persistent memory allocation, size can be negative indicating
1189   // deallocation.
1190   void record_persistent_memory_allocation(int64_t size, int64_t alloc_id = -1)
1191       TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
1192 
1193   // Returns recorded size and ids of persistent memory.
1194   int64_t persistent_memory_allocated() const
1195       TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
1196 
1197   std::vector<int64_t> persistent_alloc_ids() const
1198       TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
1199 
1200   // Resets counters for temp and persistent memory and recorded ids.
1201   void clear_recorded_memory() TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);

  bool input_is_ref(int index) const;

  void set_record_memory_consumption(bool v);

  // Used by OpKernel implementations to track actively running deferred ops.
  //
  // A deferred op is one whose Compute method returns (or whose ComputeAsync
  // method invokes the callback) when work is scheduled onto a device. At that
  // point, we don't know when the work will actually complete (or if it has
  // already completed) on the device. These functions allow the executor to
  // track the status of deferred ops and act accordingly.
  //
  // Deferred OpKernel implementations must use these methods to get two
  // functions, and must then call those two functions in pairs, before and
  // after device execution, respectively.
  TF_MUST_USE_RESULT std::function<void()> inc_num_deferred_ops_function() {
    DCHECK(params_->op_kernel->is_deferred());
    return params_->inc_num_deferred_ops_function
               ? params_->inc_num_deferred_ops_function
               : []() {};
  }
  TF_MUST_USE_RESULT std::function<void()> dec_num_deferred_ops_function() {
    DCHECK(params_->op_kernel->is_deferred());
    return params_->dec_num_deferred_ops_function
               ? params_->dec_num_deferred_ops_function
               : []() {};
  }
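
  // Illustrative sketch (usage only): a deferred kernel typically obtains both
  // functions and calls them as a pair around the asynchronous device work,
  // e.g.
  //
  //   auto inc = ctx->inc_num_deferred_ops_function();
  //   auto dec = ctx->dec_num_deferred_ops_function();
  //   inc();
  //   EnqueueDeviceWork(..., /*on_done=*/[dec]() { dec(); });
  //
  // where `EnqueueDeviceWork` is a hypothetical stand-in for whatever
  // mechanism schedules the work and signals its completion.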

  Allocator* get_allocator(AllocatorAttributes attr);

 private:
  bool record_memory_consumption_ = false;

  // Internal common method used when allocating tensor memory
  Status allocate_tensor(DataType type, const TensorShape& shape,
                         Tensor* out_tensor,
                         AllocatorAttributes allocator_attr) {
    return allocate_tensor(type, shape, out_tensor, allocator_attr,
                           AllocationAttributes());
  }

  Status allocate_tensor(DataType type, const TensorShape& shape,
                         Tensor* out_tensor, AllocatorAttributes allocator_attr,
                         const AllocationAttributes& allocation_attr);

  // Helpers for `set_output()`.

  // Returns `true` if the tensor was copied into an allocated output.
  bool maybe_set_output_by_allocate_and_copy(int index, const Tensor& tensor);

  void maybe_track_allocations_for_set_output(const Tensor& tensor);

  Status get_input_index(StringPiece name, int* out_index) const;
  Status get_output_index(StringPiece name, int* out_index) const;

  // Initialize the allocated_scope_ids_ set the first time this method is
  // called.
  void maybe_initialize_scope_id_set();

  Status status_;
  friend class CollectiveExecutor;  // for access to params_
  Params* params_;                  // not owned
  gtl::InlinedVector<TensorValue, 4> outputs_;

  // Keep track of calls to ScopedAllocator.
  // TODO(ayushd): change to absl::flat_hash_set.
  std::unique_ptr<std::unordered_set<int32>> allocated_scope_ids_;

  // The following data members are only used when allocation tracking is
  // enabled, memory consumption is being recorded, or tensor access is being
  // recorded.
  struct TrackingState {
    mutable mutex mu;
    gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators
        TF_GUARDED_BY(mu);

    mutable mutex stats_mu;
    int64_t temp_memory_allocated TF_GUARDED_BY(stats_mu) = 0;

    int64_t persistent_memory_allocated TF_GUARDED_BY(stats_mu) = 0;
    gtl::InlinedVector<std::pair<const void*, int64_t>, 2>
        temp_tensor_buffer_and_size TF_GUARDED_BY(stats_mu);
    gtl::InlinedVector<int64_t, 2> persistent_alloc_ids TF_GUARDED_BY(stats_mu);
  };
  std::unique_ptr<TrackingState> tracking_state_;

  // For access to `params_->op_kernel`.
  friend void CheckNotInComputeAsync(OpKernelContext* ctx,
                                     const char* correct_macro_name);

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext);
};

template <>
const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const;

template <>
const Eigen::GpuDevice& OpKernelContext::eigen_device() const;

// Register your OpKernel by specifying the Op's name, the device the
// kernel runs on, any type attr constraints for this kernel, any
// host-memory args, and the class to instantiate.  Examples:
//
//  // A kernel that supports all types.
//  REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp);
//
//  // The following are equivalent ways of specifying that the kernel only
//  // works if the "T" type attr is set to DT_FLOAT.
//  REGISTER_KERNEL_BUILDER(
//      Name("Sub").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//      SubOp<float>);
//  // (You would then repeat this for every type supported by "Sub".)
//
//  // This form allows you to specify a list of types as the constraint.
//  REGISTER_KERNEL_BUILDER(Name("Sub")
//                              .Device(DEVICE_CPU)
//                              .TypeConstraint("T", {DT_FLOAT}),
//                          SubOp<float>);
//
//  // A kernel that expects one of the input tensors in host memory.
//  REGISTER_KERNEL_BUILDER(
//      Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp);
//
//  // A kernel that works on any device. Kernels using DEVICE_DEFAULT
//  // must always run on host and all inputs and outputs must use `HostMemory`.
//  // Kernels for data management, control-flow primitives or working with
//  // tensor shapes for various devices (including `PluggableDevices`) are
//  // typical uses.
//  REGISTER_KERNEL_BUILDER(
//     Name("TensorListLength").Device(DEVICE_DEFAULT).HostMemory("length"),
//     TensorListLength);
//
// See kernel_def_builder for details.

// Instantiates an OpKernel that has been registered.  Returns nullptr if
// there is no kernel registered for that device type / input signature
// combination (and sets *status to NOT_FOUND), or if there is an error
// during construction (and sets *status to INVALID_ARGUMENT).  Otherwise,
// the caller takes ownership of the returned pointer.
// EXPECTED USAGE: unique_ptr<OpKernel> op = CreateOpKernel(...);
// REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
                                         DeviceBase* device,
                                         Allocator* allocator,
                                         const NodeDef& node_def,
                                         int graph_def_version, Status* status);
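
// Illustrative sketch (usage only; `device` and `node_def` are assumed to be
// a valid DeviceBase* and a NodeDef with all attrs filled in, and
// TF_GRAPH_DEF_VERSION comes from tensorflow/core/public/version.h):
//
//   Status status;
//   std::unique_ptr<OpKernel> op = CreateOpKernel(
//       DeviceType(DEVICE_CPU), device, cpu_allocator(), node_def,
//       TF_GRAPH_DEF_VERSION, &status);
//   if (!status.ok()) { /* handle the error; `op` will be nullptr */ }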

std::unique_ptr<OpKernel> CreateOpKernel(
    DeviceType device_type, DeviceBase* device, Allocator* allocator,
    const std::shared_ptr<const NodeProperties>& props, int graph_def_version,
    Status* status);

Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      const std::shared_ptr<const NodeProperties>& props,
                      int graph_def_version, OpKernel** kernel);

Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      ResourceMgr* resource_mgr,
                      const std::shared_ptr<const NodeProperties>& props,
                      int graph_def_version, OpKernel** kernel);

// Returns into 'device_types' the subset of prioritized_types that this
// binary has registered for the given NodeDef.
//
// REQUIRES: * 'device_types' is not nullptr.
//           * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
Status SupportedDeviceTypesForNode(
    const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
    PrioritizedDeviceTypeVector* device_types,
    const DeviceNameUtils::ParsedName* local_address_spec = nullptr);

// Returns a message with a description of the kernels registered for op
// `op_name`.
std::string KernelsRegisteredForOp(StringPiece op_name);

// Call once after Op registration has completed.
Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry);

// -----------------------------------------------------------------------------
// OpKernel registration implementation follows, please ignore.

// Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax.
namespace register_kernel {

class Name : public KernelDefBuilder {
 public:
  explicit Name(const char* op);
};

}  // namespace register_kernel

// Kernel registration appears as:
//   REGISTER_KERNEL_BUILDER(Name("OpName").Device(DEVICE_CPU)..., OpImpl)
// We'd like to have "OpName" as a constant-expression, without requiring that
// of the overall KernelDefBuilder expression (beginning with the
// register_kernel::Name constructor above).
//
// So, we pull the "OpName" part to a separate macro-level argument. This
// involves treating Name("OpName") as a macro call, via token-pasting (e.g.
// M_## =>  M_Name("OpName")), and having it expand to '"OpName",
// Name("OpName")' which is then usable as two arguments.
#define TF_EXTRACT_KERNEL_NAME_Name(name_str) \
  name_str, ::tensorflow::register_kernel::Name(name_str)
#define TF_EXTRACT_KERNEL_NAME_IMPL(m, ...) m(__VA_ARGS__)
#define TF_EXTRACT_KERNEL_NAME(m, kernel_builder, ...)                    \
  TF_EXTRACT_KERNEL_NAME_IMPL(m, TF_EXTRACT_KERNEL_NAME_##kernel_builder, \
                              __VA_ARGS__)
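
// For example, roughly (not the literal preprocessor output): with
// kernel_builder = Name("MyOp").Device(DEVICE_CPU), the token paste yields
//   TF_EXTRACT_KERNEL_NAME_Name("MyOp").Device(DEVICE_CPU)
// which expands to
//   "MyOp", ::tensorflow::register_kernel::Name("MyOp").Device(DEVICE_CPU)
// so the macros below receive the op name string and the builder expression
// as two separate arguments ("MyOp" is purely illustrative).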

// REGISTER_KERNEL_BUILDER_IMPL_2, with a unique 'ctr' as the first argument.
// TODO(dodgen): There are some uses of this macro inside functions, where
// kernel_builder refers to (non-const) locals (they should be fixed). To
// accommodate those, kernel_builder.Build() appears as an argument to an
// immediately-called lambda (not in the lambda itself).
#define REGISTER_KERNEL_BUILDER_IMPL_3(ctr, op_name, kernel_builder_expr,   \
                                       is_system_kernel, ...)               \
  static ::tensorflow::InitOnStartupMarker const register_kernel_##ctr      \
      TF_ATTRIBUTE_UNUSED =                                                 \
          TF_INIT_ON_STARTUP_IF(is_system_kernel ||                         \
                                (SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__) && \
                                 SHOULD_REGISTER_OP(op_name)))              \
          << ([](::tensorflow::KernelDef const* kernel_def) {               \
               ::tensorflow::kernel_factory::OpKernelRegistrar registrar(   \
                   kernel_def, #__VA_ARGS__,                                \
                   [](::tensorflow::OpKernelConstruction* context)          \
                       -> ::tensorflow::OpKernel* {                         \
                     return new __VA_ARGS__(context);                       \
                   });                                                      \
               (void)registrar;                                             \
               return ::tensorflow::InitOnStartupMarker{};                  \
             })(kernel_builder_expr.Build());

// REGISTER_KERNEL_BUILDER_IMPL, but with kernel_builder split to op_name,
// kernel_builder_expr.
#define REGISTER_KERNEL_BUILDER_IMPL_2(op_name, kernel_builder_expr, \
                                       is_system_kernel, ...)        \
  TF_NEW_ID_FOR_INIT(REGISTER_KERNEL_BUILDER_IMPL_3, op_name,        \
                     kernel_builder_expr, is_system_kernel, __VA_ARGS__)

// REGISTER_KERNEL_BUILDER, but with is_system_kernel bound.
#define REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, is_system_kernel, ...) \
  TF_EXTRACT_KERNEL_NAME(REGISTER_KERNEL_BUILDER_IMPL_2, kernel_builder,    \
                         is_system_kernel, __VA_ARGS__)

#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
  TF_ATTRIBUTE_ANNOTATE("tf:kernel")                 \
  REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, false, __VA_ARGS__)

// The `REGISTER_SYSTEM_KERNEL_BUILDER()` macro acts like
// `REGISTER_KERNEL_BUILDER()` except that the kernel is registered
// unconditionally, even when selective registration is used.
#define REGISTER_SYSTEM_KERNEL_BUILDER(kernel_builder, ...) \
  TF_ATTRIBUTE_ANNOTATE("tf:kernel")                        \
  TF_ATTRIBUTE_ANNOTATE("tf:kernel:system")                 \
  REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, true, __VA_ARGS__)
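
// Illustrative usage (the kernel class name here is hypothetical):
//   REGISTER_SYSTEM_KERNEL_BUILDER(Name("MySystemOp").Device(DEVICE_CPU),
//                                  MySystemOp);
// keeps the kernel registered even in selectively-registered builds.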

// Checks whether a given kernel is registered on device_type.
bool KernelDefAvailable(const DeviceType& device_type, const NodeDef& node_def);

// If the node given by node_name, experimental_debug_info, node_op,
// node_device and node_attrs has a corresponding kernel registered on
// device_type, returns OK and fills in the kernel def and kernel_class_name.
// <def> and <kernel_class_name> may be null.
Status FindKernelDef(
    const DeviceType& device_type, StringPiece node_name,
    bool has_experimental_debug_info,
    const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
    StringPiece node_op, StringPiece node_device, AttrSlice node_attrs,
    const KernelDef** def, std::string* kernel_class_name);

// If node_def has a corresponding kernel registered on device_type,
// returns OK and fills in the kernel def and kernel_class_name. <def> and
// <kernel_class_name> may be null.
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
                     const KernelDef** def, std::string* kernel_class_name);

// Writes a list of all registered kernels to LOG(INFO), to help users debug
// missing kernel errors.
void LogAllRegisteredKernels();

// Gets a list of all registered kernels.
KernelList GetAllRegisteredKernels();

// Gets a list of all registered kernels for which `predicate` returns true.
KernelList GetFilteredRegisteredKernels(
    const std::function<bool(const KernelDef&)>& predicate);
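
// Illustrative sketch: for example, to list only the kernels registered for
// GPU one might write
//   KernelList gpu_kernels = GetFilteredRegisteredKernels(
//       [](const KernelDef& k) { return k.device_type() == "GPU"; });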

// Gets a list of all registered kernels for a given op.
KernelList GetRegisteredKernelsForOp(StringPiece op_name);

namespace kernel_factory {

// OpKernelFactory is responsible for creating OpKernels when TensorFlow needs
// them. You register factories with the TensorFlow core by constructing an
// OpKernelRegistrar and passing the factory as a constructor parameter.
class OpKernelFactory {
 public:
  virtual OpKernel* Create(OpKernelConstruction* context) = 0;
  virtual ~OpKernelFactory() = default;
};
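
// Illustrative sketch (`MyOp` is a hypothetical OpKernel subclass): a minimal
// factory simply forwards to the kernel's constructor.
//
//   class MyOpFactory : public OpKernelFactory {
//    public:
//     OpKernel* Create(OpKernelConstruction* context) override {
//       return new MyOp(context);
//     }
//   };
//
// Most kernels never write a factory by hand; REGISTER_KERNEL_BUILDER
// supplies one automatically.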

class OpKernelRegistrar {
 public:
  // Registers the given kernel factory with TensorFlow. TF will call the
  // factory Create() method when it determines that a kernel matching the given
  // KernelDef is required.
  OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    std::unique_ptr<OpKernelFactory> factory)
      TF_ATTRIBUTE_NOINLINE {
    InitInternal(kernel_def, kernel_class_name, std::move(factory));
  }

  // Registers the given factory function with TensorFlow. This is equivalent
  // to registering a factory whose Create function invokes `create_fn`.
  OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    OpKernel* (*create_fn)(OpKernelConstruction*))
      TF_ATTRIBUTE_NOINLINE {
    InitInternal(kernel_def, kernel_class_name,
                 absl::make_unique<PtrOpKernelFactory>(create_fn));
  }

 private:
  struct PtrOpKernelFactory : public OpKernelFactory {
    explicit PtrOpKernelFactory(OpKernel* (*create_func)(OpKernelConstruction*))
        : create_func_(create_func) {}

    OpKernel* Create(OpKernelConstruction* context) override;

    OpKernel* (*create_func_)(OpKernelConstruction*);
  };

  void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    std::unique_ptr<OpKernelFactory> factory);
};

}  // namespace kernel_factory

// -----------------------------------------------------------------------------
// Template and inline method implementations, please ignore

template <class T>
Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const {
  return GetNodeAttr(def(), attr_name, value);
}
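
// Illustrative sketch: attrs are typically read in a kernel constructor, where
// `context` is the OpKernelConstruction* argument and "axis" is a hypothetical
// attr name declared by the op.
//
//   int64_t axis;
//   OP_REQUIRES_OK(context, context->GetAttr("axis", &axis));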

inline DataType OpKernelContext::input_dtype(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  const TensorValue& value(params_->inputs[index]);
  return value.dtype();
}

inline MemoryType OpKernelContext::input_memory_type(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  return op_kernel().input_memory_types()[index];
}

inline DataType OpKernelContext::expected_output_dtype(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  return params_->op_kernel->output_type(index);
}

inline MemoryType OpKernelContext::output_memory_type(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  return op_kernel().output_memory_types()[index];
}

inline bool OpKernelContext::input_is_ref(int index) const {
  const TensorValue& value(params_->inputs[index]);
  return value.is_ref();
}

// no input if tensor == nullptr.
inline bool OpKernelContext::has_input(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  return params_->inputs[index].tensor != nullptr;
}

inline mutex* OpKernelContext::input_ref_mutex(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  DCHECK(input_is_ref(index));
  return params_->inputs[index].mutex_if_ref;
}

inline Tensor* OpKernelContext::mutable_output(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  return outputs_[index].tensor;
}

inline TensorValue OpKernelContext::release_output(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  TensorValue value = outputs_[index];
  outputs_[index] = TensorValue();
  return value;
}

template <typename T>
T* OpKernelContext::op_device_context() {
  static_assert(std::is_base_of<DeviceContext, T>::value,
                "T is not a subclass of DeviceContext");
  return static_cast<T*>(op_device_context());
}

inline const Tensor& OpInputList::operator[](int i) const {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->input(start_ + i);
}

inline mutex* OpMutableInputList::ref_mutex(int i) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->input_ref_mutex(start_ + i);
}

inline Tensor OpMutableInputList::at(int i, bool lock_held) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->mutable_input(start_ + i, lock_held);
}

inline Tensor* OpOutputList::operator[](int i) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->mutable_output(start_ + i);
}

inline bool OpOutputList::required(int i) const {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->output_required(start_ + i);
}

inline DataType OpOutputList::expected_output_dtype(int i) const {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->expected_output_dtype(start_ + i);
}

inline Status OpOutputList::allocate(int i, const TensorShape& shape,
                                     Tensor** output) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->allocate_output(start_ + i, shape, output);
}

inline void OpOutputList::set(int i, const Tensor& tensor) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  ctx_->set_output(start_ + i, tensor);
}

inline void OpOutputList::set(int i, Tensor&& tensor) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  ctx_->set_output(start_ + i, std::move(tensor));
}

inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  ctx_->set_output_ref(start_ + i, mu, tensor_for_ref);
}

// Generate a fatal error if OP_REQUIRES or OP_REQUIRES_OK are used in
// AsyncOpKernel implementations. If these macros are used and the condition
// does not hold, the `done` callback will never be called and the system will
// deadlock, so a crash failure is preferable. Since the OP_REQUIRES[_OK] macros
// are legal to use in AsyncOpKernel constructors, we use overload resolution
// to distinguish between OpKernelConstruction* and OpKernelContext* context
// types.
class XlaOpKernelContext;
inline void CheckNotInComputeAsync(XlaOpKernelContext*, const char*) {}
inline void CheckNotInComputeAsync(OpKernelConstruction*, const char*) {}
void CheckNotInComputeAsync(OpKernelContext* ctx,
                            const char* correct_macro_name);
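
// Illustrative sketch: inside ComputeAsync, prefer the *_ASYNC variants from
// op_requires.h so that `done` is still invoked on failure, e.g.
//   OP_REQUIRES_OK_ASYNC(ctx, SomeStatusReturningCall(), done);
// where SomeStatusReturningCall is a hypothetical helper returning a Status.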

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_