1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
18 
19 #include <atomic>
20 #include <functional>
21 #include <unordered_set>
22 #include <utility>
23 #include <vector>
24 
25 #include "tensorflow/core/framework/allocator.h"
26 #include "tensorflow/core/framework/cancellation.h"
27 #include "tensorflow/core/framework/control_flow.h"
28 #include "tensorflow/core/framework/device_base.h"
29 #include "tensorflow/core/framework/graph.pb.h"
30 #include "tensorflow/core/framework/kernel_def.pb.h"
31 #include "tensorflow/core/framework/kernel_def_builder.h"
32 #include "tensorflow/core/framework/node_def.pb.h"
33 #include "tensorflow/core/framework/node_def_util.h"
34 #include "tensorflow/core/framework/op.h"  // TODO(b/62899350): Remove
35 #include "tensorflow/core/framework/rendezvous.h"
36 #include "tensorflow/core/framework/selective_registration.h"
37 #include "tensorflow/core/framework/session_state.h"
38 #include "tensorflow/core/framework/tensor.h"
39 #include "tensorflow/core/framework/tensor_shape.h"
40 #include "tensorflow/core/framework/tensor_shape.pb.h"  // TODO(b/62899350): Remove
41 #include "tensorflow/core/framework/tracking_allocator.h"
42 #include "tensorflow/core/framework/types.h"
43 #include "tensorflow/core/framework/types.pb.h"
44 #include "tensorflow/core/framework/unique_tensor_references.h"
45 #include "tensorflow/core/lib/core/errors.h"
46 #include "tensorflow/core/lib/core/status.h"
47 #include "tensorflow/core/lib/gtl/array_slice.h"
48 #include "tensorflow/core/lib/gtl/manual_constructor.h"
49 #include "tensorflow/core/platform/env.h"
50 #include "tensorflow/core/platform/logging.h"
51 #include "tensorflow/core/platform/macros.h"
52 #include "tensorflow/core/platform/mutex.h"
53 #include "tensorflow/core/platform/profile_utils/cpu_utils.h"
54 #include "tensorflow/core/platform/thread_annotations.h"
55 #include "tensorflow/core/platform/types.h"
56 #include "tensorflow/core/protobuf/config.pb.h"
57 
58 namespace Eigen {
59 struct ThreadPoolDevice;
60 struct GpuDevice;
61 struct SyclDevice;
62 }  // end namespace Eigen
63 
64 namespace tensorflow {
65 
66 namespace checkpoint {
67 class TensorSliceReaderCacheWrapper;
68 }  // namespace checkpoint
69 
70 class AsyncOpKernel;
71 class CallFrameInterface;
72 class DeviceMgr;
73 class FunctionLibraryRuntime;
74 class OpKernelConstruction;  // declared below
75 class OpKernelContext;       // declared below,
76 class OpRegistryInterface;
77 class ResourceMgr;
78 class ScopedStepContainer;
79 class CollectiveExecutor;
80 class StepStatsCollectorInterface;
81 
82 class OpKernel {
83  public:
84   // OpKernel won't be instantiated by the scheduler, so you may perform
85   // expensive initialization in the descendant's constructor.
86   explicit OpKernel(OpKernelConstruction* context);
87 
88   // Specialized constructor that enables the descendant to provide a different
89   // `NodeDef` value. For example, this constructor can be used to provide a
90   // stripped-down `NodeDef` that does not contain the full set of attrs (such
91   // as tensor values) if the descendant stores them in a different form.
92   explicit OpKernel(OpKernelConstruction* context,
93                     std::unique_ptr<const NodeDef> node_def,
94                     bool is_deferred = false);
95 
96   // Specialized constructor that allows a kernel implementation to mark itself
97   // as a "deferred" op. If true, the executor will provide access to the
98   // `OpKernelContext::inc_num_deferred_ops_function()` and
99   // `OpKernelContext::dec_num_deferred_ops_function()` methods at run-time.
100   explicit OpKernel(OpKernelConstruction* context, bool is_deferred);
101 
102   virtual ~OpKernel();
103 
104   // An OpKernel's computation can be either synchronous or
105   // asynchronous. All OpKernel Compute() methods must be thread-safe as they
106   // may be called concurrently (e.g. by multiple executions of the same graph
107   // concurrently).
108   //
109   // Most OpKernels should compute synchronously. They should
110   // subclass OpKernel and override the Compute() method and have it
111   // return after completing the supplied work.
112   //
113   // A synchronous OpKernel *MUST NOT* block the calling thread on a
114   // synchronization mechanism (condition variable, Notification, etc.) that
115   // will be unblocked by the execution of another OpKernel. Execution may
116   // deadlock in that case, because the executor may use a bounded number of
117   // threads.
118   //
119   // If an OpKernel must block on the execution of another OpKernel (e.g. a
120   // RecvOp, or a DequeueOp), the implementation *MUST* subclass AsyncOpKernel,
121   // and override `AsyncOpKernel::ComputeAsync()`. In addition, because the
122   // unblocking kernel may never run (due to an error or cancellation), in most
123   // cases the AsyncOpKernel should implement cancellation support via
124   // `ctx->cancellation_manager()`.
125   //
126   // In both cases, implementations of Compute() and ComputeAsync()
127   // get inputs and write outputs through the given OpKernelContext
128   // and return a status via context->SetStatus(). They must be
129   // thread-safe.
130 
131   // Synchronous compute.
132   //
133   // "context" is guaranteed to be alive until Compute() returns.
134   virtual void Compute(OpKernelContext* context) = 0;
135 
136   // Returns nullptr iff this op kernel is synchronous.
137   virtual AsyncOpKernel* AsAsync() { return nullptr; }
138   virtual const AsyncOpKernel* AsAsync() const { return nullptr; }
139 
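  // Illustrative sketch of a minimal synchronous kernel following the rules
  // above (the op and kernel names are hypothetical; OP_REQUIRES_OK and the
  // registration macro are declared further down in this header):
  //
  //   class AddOneOp : public OpKernel {
  //    public:
  //     explicit AddOneOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  //
  //     void Compute(OpKernelContext* ctx) override {
  //       const Tensor& input = ctx->input(0);
  //       Tensor* output = nullptr;
  //       OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
  //       auto in = input.flat<float>();
  //       auto out = output->flat<float>();
  //       for (int64 i = 0; i < in.size(); ++i) out(i) = in(i) + 1.0f;
  //     }
  //   };
  //
  //   REGISTER_KERNEL_BUILDER(Name("AddOne").Device(DEVICE_CPU), AddOneOp);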
140   // Initial time (in CPU cycles) we expect an operation to take.  Used to
141   // determine whether an operation should be placed in a threadpool.  Operations
142   // start out "expensive".
143   static const uint64 kInitialCostEstimateCycles = 100 * 1000 * 1000;
144   static const uint64 kOpIsExpensiveThresholdCycles = 5000;
145   static const uint64 kCostDecay = 10;
146 
147   // Returns true iff this op kernel is considered "expensive". The
148   // runtime may use this flag to optimize graph execution for example
149   // to "inline" inexpensive kernels.
150   virtual bool IsExpensive() {
151     return expensive_ && (cost_estimate_.load(std::memory_order_relaxed) >
152                           kOpIsExpensiveThresholdCycles);
153   }
154 
155   // Returns a pointer to the tensor stored inside constant ops.
156   virtual const Tensor* const_tensor() const { return nullptr; }
157 
158   // Updates the dynamic cost estimate, which is used to determine whether this
159   // op is expensive. The new cost estimate is a weighted average of the old
160   // cost estimate and the latest cost.
161   void UpdateCostEstimate(uint64 elapsed_cycles) {
162     // N.B. Updates to `cost_estimate_` are atomic but unlocked.  Simultaneous
163     // updates may result in one or more updates being ignored.  This does not
164     // affect correctness but may slow down the update frequency.
165     cost_estimate_.store(
166         (kCostDecay - 1) * cost_estimate_.load(std::memory_order_relaxed) /
167                 kCostDecay +
168             (elapsed_cycles / kCostDecay),
169         std::memory_order_relaxed);
170   }
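  // For example, with kCostDecay = 10, an existing estimate of 1000 cycles and
  // a measured elapsed_cycles of 2000 yields a new estimate of
  // (9 * 1000) / 10 + 2000 / 10 = 900 + 200 = 1100 cycles.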
171 
172   // Accessors.
173   const NodeDef& def() const { return *def_; }
174   const string& name() const;              // Same as def().name()
175   absl::string_view name_view() const { return name_view_; }
176   const string& type_string() const;       // Same as def().op()
177   absl::string_view type_string_view() const { return type_string_view_; }
178   const string& requested_device() const;  // Same as def().device()
179 
180   int num_inputs() const { return input_types_.size(); }
181   DataType input_type(int i) const { return input_types_[i]; }
182   const DataTypeVector& input_types() const { return input_types_; }
183   const MemoryTypeVector& input_memory_types() const {
184     return input_memory_types_;
185   }
186   const string& requested_input(int i) const;  // Same as def().input(i)
187 
188   int num_outputs() const { return output_types_.size(); }
189   DataType output_type(int o) const { return output_types_[o]; }
190   const DataTypeVector& output_types() const { return output_types_; }
191   const MemoryTypeVector& output_memory_types() const {
192     return output_memory_types_;
193   }
194 
195   Status InputRange(StringPiece input_name, int* start, int* stop) const;
196   Status OutputRange(StringPiece output_name, int* start, int* stop) const;
197 
198   // Returns `true` if and only if this kernel uses deferred execution.
199   bool is_deferred() const { return is_deferred_; }
200 
201   // Returns a trace string for the current computation; op name/type and input
202   // tensor shape/dtype are encoded for profiler cost analysis. Most OpKernels
203   // should use the default implementation.
204   // Override this function to add OpKernel-specific attributes that are
205   // necessary for cost analysis.
206   virtual string TraceString(OpKernelContext* ctx, bool verbose);
207 
208  private:
209   const std::unique_ptr<const NodeDef> def_;
210   const DataTypeVector input_types_;
211   const MemoryTypeVector input_memory_types_;
212   const DataTypeVector output_types_;
213   const MemoryTypeVector output_memory_types_;
214   NameRangeMap input_name_map_;
215   NameRangeMap output_name_map_;
216   const absl::string_view name_view_;
217   const absl::string_view type_string_view_;
218   const int graph_def_version_;
219   const bool is_deferred_;
220   bool expensive_;
221   std::atomic_uint_fast64_t cost_estimate_;
222 
223   TF_DISALLOW_COPY_AND_ASSIGN(OpKernel);
224 };
225 
226 class AsyncOpKernel : public OpKernel {
227  public:
228   using OpKernel::OpKernel;  // Lift OpKernel constructors.
229 
230   // Asynchronous compute.
231   //
232   // Implementations of ComputeAsync() must ensure that `done` is (eventually)
233   // called exactly once to signal the completion of the computation. The
234   // implementation of ComputeAsync() must not block on the execution of another
235   // OpKernel. `done` may be called by the current thread, or by another thread.
236   // `context` is guaranteed to stay alive until the `done` callback starts.
237   //
238   // Since it is possible that the unblocking kernel may never run (due to an
239   // error or cancellation), in most cases the AsyncOpKernel should implement
240   // cancellation support via `context->cancellation_manager()`.
241   //
242   // WARNING: As soon as the `done` callback starts, `context` and `this` may be
243   // deleted. No code depending on these objects should execute after the call
244   // to `done`.
245   typedef std::function<void()> DoneCallback;
246   virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0;
247 
248   AsyncOpKernel* AsAsync() override { return this; }
249   const AsyncOpKernel* AsAsync() const override { return this; }
250 
251   void Compute(OpKernelContext* context) override;
252 };
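// Illustrative sketch of an AsyncOpKernel (hypothetical names), showing the
// `done` callback and the cancellation support recommended above:
//
//   class WaitForEventOp : public AsyncOpKernel {
//    public:
//     explicit WaitForEventOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {}
//
//     void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
//       CancellationManager* cm = ctx->cancellation_manager();
//       // Hand the work and the callback to some external event system (the
//       // helper below is hypothetical). That system must invoke `done`
//       // exactly once, either when the event fires or when `cm` reports
//       // cancellation, and must not touch `ctx` after calling `done`.
//       EnqueueWait(ctx, cm, std::move(done));
//     }
//   };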
253 
254 // Wraps a tensor that is held by an Op across calls to Compute(). For memory
255 // safety when using asynchronous devices like GPUs, the system must be notified
256 // when a Tensor is used inside an Op execution. The wrapper ensures that all
257 // uses of the Tensor are tracked, because in order to retrieve the Tensor the
258 // caller must use AccessTensor which notifies the context.
259 class PersistentTensor {
260  public:
261   PersistentTensor() {}
262   explicit PersistentTensor(const Tensor& tensor) : tensor_(tensor) {}
263 
264   // Caller does not own the returned Tensor*.
265   Tensor* AccessTensor(OpKernelConstruction* context);
266   // Caller does not own the returned Tensor*.
267   Tensor* AccessTensor(OpKernelContext* context);
268 
269   // The check for initialization does not need to access the
270   // underlying tensor buffer.
271   bool IsInitialized() const { return tensor_.IsInitialized(); }
272 
273   int64 NumElements() const { return tensor_.NumElements(); }
274 
275   int64 AllocatedBytes() const { return tensor_.AllocatedBytes(); }
276 
277  private:
278   Tensor tensor_;
279 };
280 
281 class OpKernelConstruction {
282  public:
283   OpKernelConstruction(DeviceType device_type, DeviceBase* device,
284                        Allocator* allocator, const NodeDef* node_def,
285                        const OpDef* op_def, FunctionLibraryRuntime* flib,
286                        ResourceMgr* resource_mgr,
287                        const DataTypeSlice& input_types,
288                        const MemoryTypeSlice& input_memory_types,
289                        const DataTypeSlice& output_types,
290                        const MemoryTypeSlice& output_memory_types,
291                        int graph_def_version, Status* status);
292 
293   Env* env() const { return device_->env(); }
294 
295   // Allocation of tensors during kernel construction:
296   //
297   // It is legal to temporarily allocate scratch tensor storage during
298   // Op kernel construction. Scratch tensors should be allocated using
299   // allocate_temp below. Some kernels need to keep tensors in between
300   // invocations. If such a Tensor is allocated during kernel
301   // construction this must be done using allocate_persistent, and the
302   // Op may only store the returned PersistentTensor object. When the
303   // Tensor is needed in a subsequent invocation, it can be retrieved
304   // from the PersistentTensor using the AccessTensor method. This
305   // ensures that the system is made aware of any use of the tensor's
306   // allocated memory, which is needed for correctness on asynchronous
307   // devices such as GPUs.
308 
309   // Allocates a temporary Tensor of the specified type and shape. The
310   // Tensor must not be used after kernel construction is
311   // complete. See comment above.
312   Status allocate_temp(DataType type, const TensorShape& shape,
313                        Tensor* out_temp);
314   Status allocate_temp(DataType type, const TensorShape& shape,
315                        Tensor* out_temp, AllocatorAttributes allocator_attr);
316 
317   // Allocates a Tensor of the specified type and shape which the Op
318   // plans to maintain as persistent state. out_persistent holds the
319   // PersistentTensor which is the object the caller should store. For
320   // convenience, if out_tensor is non-null then it will be filled in
321   // with a Tensor* pointing to the newly-allocated tensor which the
322   // caller can use instead of calling
323   // out_persistent->AccessTensor. The caller does not own out_tensor
324   // and should not keep a copy of it. See comment above.
325   Status allocate_persistent(DataType type, const TensorShape& shape,
326                              PersistentTensor* out_persistent,
327                              Tensor** out_tensor);
328 
329   // User-supplied configuration of this operation.
330   const NodeDef& def() const { return *def_; }
331 
332   // For inspecting the inputs to this operation.
333   int num_inputs() const { return input_types_.size(); }
334   DataType input_type(int i) const { return input_types_[i]; }
335   const DataTypeSlice& input_types() const { return input_types_; }
336   const MemoryTypeSlice& input_memory_types() const {
337     return input_memory_types_;
338   }
339 
340   // For inspecting the outputs expected from this operation.
341   int num_outputs() const { return output_types_.size(); }
342   DataType output_type(int i) const { return output_types_[i]; }
343   const DataTypeSlice& output_types() const { return output_types_; }
344   const MemoryTypeSlice& output_memory_types() const {
345     return output_memory_types_;
346   }
347 
348   // If expected_inputs == inputs() and expected_outputs == output_types(),
349   // returns OK, else returns INVALID_ARGUMENT with an error message.
350   // Recommended for Ops with dynamic signatures.
351   Status MatchSignature(const DataTypeSlice expected_inputs,
352                         const DataTypeSlice expected_outputs);
353 
354   // For recording configuration errors during construction.
355   void SetStatus(const Status& status);
356   const Status& status() const { return *status_; }
357 
358   // Look up the attr with name attr_name and set *value to its value.  If no
359   // attr with attr_name is found in def(), or the attr does not have
360   // a matching type, a non-ok status will be returned.
361   template <class T>
362   Status GetAttr(StringPiece attr_name, T* value) const;
363 
364   // Return true if the attr_name is defined in def().
365   bool HasAttr(StringPiece attr_name) const;
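  // For example, a kernel constructor can read and validate an attr like so
  // (the attr name "axis" and the member axis_ are illustrative):
  //
  //   explicit MyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  //     OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
  //     OP_REQUIRES(ctx, axis_ >= 0,
  //                 errors::InvalidArgument("axis must be non-negative"));
  //   }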
366 
367   // Return the device type.
368   const DeviceType& device_type() const { return device_type_; }
369 
370   // If not nullptr, the kernel can instantiate functions defined in
371   // the library. E.g.,
372   // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...).
373   FunctionLibraryRuntime* function_library() const { return flib_; }
374 
375   // Shared resources accessible to this kernel.
376   ResourceMgr* resource_manager() const { return resource_mgr_; }
377 
378   // The GraphDef version whose behavior we should follow.
379   int graph_def_version() const { return graph_def_version_; }
380 
381   // Helper routines for the OP_REQUIRES macros
382   void CtxFailure(const Status& s);
383   void CtxFailureWithWarning(const Status& s);
384   void CtxFailure(const char* file, int line, const Status& s);
385   void CtxFailureWithWarning(const char* file, int line, const Status& s);
386 
387   // Unrecommended functions: these are functions that have some
388   // current uses but are not recommended for use, and may go away at
389   // some future major version release.
390 
391   // May be used, e.g., to get GPU handles, etc.
392   //
393   // Currently only used to call MakeTensorFromProto() for
394   // implementing ConstantOp for every device.  See comments
395   // on Device::MakeTensorFromProto for longer-term replacement
396   // ideas.
397   DeviceBase* device() const { return device_; }
398 
399  private:
400   const DeviceType device_type_;
401   DeviceBase* const device_;
402   Allocator* allocator_;
403   const NodeDef* def_;
404   const OpDef* op_def_;
405   FunctionLibraryRuntime* flib_;
406   ResourceMgr* const resource_mgr_;
407   DataTypeSlice input_types_;
408   MemoryTypeSlice input_memory_types_;
409   DataTypeSlice output_types_;
410   MemoryTypeSlice output_memory_types_;
411   const int graph_def_version_;
412   Status* status_;
413 
414   // Allow OpKernel (but not its subclasses) to access op_def_.
415   // TODO(irving): Remove protos from this header entirely.
416   friend class OpKernel;
417 
418   TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction);
419 };
420 
421 // TODO(mrry): Consider converting to a random_access_iterator, and upgrading
422 // tensorflow::gtl::iterator_range to make the below container classes
423 // unnecessary.
424 template <typename ListType, typename ElementType>
425 class OpArgIterator {
426  public:
427   using iterator_category = std::forward_iterator_tag;
428   using value_type = ElementType;
429   using pointer = ElementType*;
430   using const_pointer = const ElementType*;
431   using reference = ElementType&;
432   using const_reference = const ElementType&;
433   using difference_type = ptrdiff_t;
434 
435   OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {}
436 
437   bool operator==(const OpArgIterator& rhs) {
438     DCHECK(list_ == rhs.list_);
439     return i_ == rhs.i_;
440   }
441 
442   bool operator!=(const OpArgIterator& rhs) {
443     DCHECK(list_ == rhs.list_);
444     return i_ != rhs.i_;
445   }
446 
447   OpArgIterator operator++() {  // prefix ++it
448     ++i_;
449     return *this;
450   }
451 
452   OpArgIterator operator++(int) {  // postfix it++
453     OpArgIterator old_value = *this;
454     ++i_;
455     return old_value;
456   }
457 
458   reference operator*() { return (*list_)[i_]; }
459   pointer operator->() { return &(*list_)[i_]; }
460 
461   const_reference operator*() const { return (*list_)[i_]; }
462   const_pointer operator->() const { return &(*list_)[i_]; }
463 
464  private:
465   const ListType* const list_;
466   int i_;
467 };
468 
469 // Utility class for representing a list of immutable input tensors
470 // that are passed to the op as a single named argument.
471 class OpInputList {
472  public:
473   typedef OpArgIterator<OpInputList, const Tensor> Iterator;
474   OpInputList() : ctx_(nullptr), start_(0), stop_(0) {}
475   OpInputList(OpKernelContext* ctx, int start, int stop)
476       : ctx_(ctx), start_(start), stop_(stop) {}
477   OpInputList& operator=(const OpInputList& other) = default;
478   const Tensor& operator[](int i) const;
479   int size() const { return stop_ - start_; }
480   Iterator begin() const { return Iterator(this, 0); }
481   Iterator end() const { return Iterator(this, size()); }
482 
483  private:
484   OpKernelContext* ctx_;  // not owned
485   int start_;
486   int stop_;
487 };
488 
489 // Utility class for representing a list of mutable ("ref") input tensors
490 // that are passed to the op as a single named argument.
491 class OpMutableInputList {
492  public:
493   typedef OpArgIterator<OpMutableInputList, Tensor*> Iterator;
494   OpMutableInputList(OpKernelContext* ctx, int start, int stop)
495       : ctx_(ctx), start_(start), stop_(stop) {}
496   OpMutableInputList() : ctx_(nullptr), start_(0), stop_(0) {}
497   OpMutableInputList& operator=(const OpMutableInputList& other) = default;
498   Tensor at(int i, bool lock_held);
499   mutex* ref_mutex(int i);
500   int size() const { return stop_ - start_; }
501   Iterator begin() const { return Iterator(this, 0); }
502   Iterator end() const { return Iterator(this, size()); }
503 
504  private:
505   OpKernelContext* ctx_;  // not owned
506   int start_;
507   int stop_;
508 };
509 
510 // Utility class for representing a list of output tensors that are
511 // grouped as a single named output.
512 class OpOutputList {
513  public:
514   typedef OpArgIterator<OpOutputList, const Tensor*> Iterator;
515   OpOutputList() : ctx_(nullptr), start_(0), stop_(0) {}
516   OpOutputList(OpKernelContext* ctx, int start, int stop)
517       : ctx_(ctx), start_(start), stop_(stop) {}
518   OpOutputList& operator=(const OpOutputList& other) = default;
519   Tensor* operator[](int i);
520   bool required(int i) const;
521   DataType expected_output_dtype(int i) const;
522   Status allocate(int i, const TensorShape& shape, Tensor** output);
523   void set(int i, const Tensor& tensor);
524   void set_ref(int i, mutex* mu, Tensor* tensor_for_ref);
525   int size() const { return stop_ - start_; }
526   Iterator begin() const { return Iterator(this, 0); }
527   Iterator end() const { return Iterator(this, size()); }
528 
529  private:
530   OpKernelContext* ctx_;  // not owned
531   int start_;
532   int stop_;
533 };
534 
535 // Holds a tensor or tensor reference. For tensor references, we need
536 // a mutex to prevent concurrent access to the tensor.
537 struct TensorValue {
538   TensorValue() : mutex_if_ref(nullptr), tensor(nullptr) {}
539   explicit TensorValue(Tensor* t) : mutex_if_ref(nullptr), tensor(t) {}
540   TensorValue(mutex* mu, Tensor* t) : mutex_if_ref(mu), tensor(t) {}
541   Tensor* operator->() const { return tensor; }
542   bool is_ref() const { return mutex_if_ref != nullptr; }
543 
544   // Return the dtype of the Tensor. For references, return the reference type.
545   DataType dtype() const {
546     if (is_ref()) {
547       return MakeRefType(tensor->dtype());
548     } else {
549       return tensor->dtype();
550     }
551   }
552 
553   // Return the dtype of the Tensor. For references, return the reference type.
554   // This variation on the dtype() acquires the lock for references.
555   //
556   // TODO(b/133843385): Disallow dtype modifications
557   DataType dtype_safe() const {
558     if (is_ref()) {
559       tf_shared_lock ml(*mutex_if_ref);
560       return MakeRefType(tensor->dtype());
561     } else {
562       return tensor->dtype();
563     }
564   }
565 
566   mutex* mutex_if_ref;  // nullptr if not a ref, != nullptr if a ref
567   Tensor* tensor;
568 };
569 
570 // Used to store partitioned graphs from function-calling ops.
571 struct GraphCollector {
572   mutex mu;
573   std::vector<GraphDef> partitioned_graphs GUARDED_BY(mu);
574   GraphDef raw_graph GUARDED_BY(mu);
575   GraphDef optimized_graph GUARDED_BY(mu);
576 
577   bool dirty GUARDED_BY(mu);
578 
579   GraphCollector() : dirty(false) {}
580 
581   void CollectRawGraph(const GraphDef& graph) {
582     mutex_lock ml(mu);
583     raw_graph.MergeFrom(graph);
584     dirty = true;
585   }
586 
587   void CollectOptimizedGraph(const GraphDef& graph) {
588     mutex_lock ml(mu);
589     optimized_graph.MergeFrom(graph);
590     dirty = true;
591   }
592 
593   void CollectPartitionedGraph(const GraphDef& graph) {
594     mutex_lock ml(mu);
595     partitioned_graphs.push_back(graph);
596     dirty = true;
597   }
598 
599   void ClearGraphs() EXCLUSIVE_LOCKS_REQUIRED(mu) {
600     raw_graph.Clear();
601     optimized_graph.Clear();
602     partitioned_graphs.clear();
603     dirty = false;
604   }
605 
606   bool HasUpdatedGraphs() {
607     mutex_lock ml(mu);
608     return dirty;
609   }
610 };
611 
612 class OpKernelContext {
613  public:
614   // The first element of a WrappedAllocator is a "base" Allocator and
615   // the second element is that Allocator wrapped by a
616   // TrackingAllocator
617   typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator;
618 
619   // TODO(zhifengc): Do some cleanup of Params.
620   // The Params struct is passed in to initialize an OpKernelContext,
621   // and must outlive the OpKernelContext.
622   struct Params {
623     ~Params() { delete eigen_gpu_device; }
624 
625     // The step being executed.
626     int64 step_id = 0;
627 
628     // True if the op is created by eager runtime.
629     bool is_eager = false;
630 
631     // The op kernel being computed.
632     OpKernel* op_kernel = nullptr;
633 
634     // The device on which the kernel is running.
635     DeviceBase* device = nullptr;
636 
637     // The Eigen GPU device wrapper, which may include a per-op
638     // wrapped allocator. The concrete type of this object depends on
639     // the type of this->device, so eigen_gpu_device can't be an
640     // inline member and must be heap allocated. However, we don't
641     // want to allocate a new eigen_gpu_device for every Op that is
642     // executed. Instead this member is allocated on first use using
643     // ensure_eigen_gpu_device, and then if the Params structure is
644     // re-used for subsequent Ops, the eigen_gpu_device is
645     // ReInitialized in the OpKernelContext constructor. Unlike the
646     // other pointers in Params, this one is owned by Params.
647     PerOpGpuDevice* eigen_gpu_device = nullptr;
648 
649     inline void ensure_eigen_gpu_device() {
650       DCHECK(device);
651       if (nullptr == eigen_gpu_device) {
652         // Surprisingly, MakeGpuDevice will return nullptr if the
653         // device is not a GPU device. This is ok, since those devices
654         // will never use eigen_gpu_device. It seems better to have
655         // ensure_eigen_gpu_device fall through and regenerate the
656         // nullptr every time an OpKernelContext is instantiated, than
657         // to do an unnecessary allocation of a dummy eigen GPU
658         // device for CPU device Ops.
659         eigen_gpu_device = device->MakeGpuDevice();
660       }
661     }
662 
663     bool track_allocations = false;
664     bool log_memory = false;
665     bool record_tensor_accesses = false;
666 
667     // Array indexed by output number for this node
668     const AllocatorAttributes* output_attr_array = nullptr;
669 
670     // Shared resources accessible by this op kernel invocation.
671     ResourceMgr* resource_manager = nullptr;
672 
673     // Per-step resources accessible by this op kernel invocation should be
674     // stored in this container.
675     ScopedStepContainer* step_container = nullptr;
676 
677     // Mechanism used by this op kernel invocation to communicate with
678     // computations running on other devices.
679     RendezvousInterface* rendezvous = nullptr;
680     const std::function<Status(const int64, const DeviceMgr*, Rendezvous** r)>*
681         create_rendezvous;
682 
683     // Mechanism for executing a collective op that needs to coordinate
684     // with parallel instances running on other devices.
685     CollectiveExecutor* collective_executor = nullptr;
686 
687     // The session state for this op.
688     SessionState* session_state = nullptr;
689 
690     // Unique session identifier. Can be empty.
691     string session_handle;
692 
693     // Metadata about the session. Can be nullptr.
694     const SessionMetadata* session_metadata = nullptr;
695 
696     // The tensor store for this op.
697     TensorStore* tensor_store = nullptr;
698 
699     // Mechanism used by this op kernel invocation to register a callback
700     // for its cancellation.
701     CancellationManager* cancellation_manager = nullptr;
702 
703     // Inputs to this op kernel.
704     const gtl::InlinedVector<TensorValue, 4>* inputs = nullptr;
705     bool is_input_dead = false;
706 
707     const gtl::InlinedVector<AllocatorAttributes, 4>* input_alloc_attrs =
708         nullptr;
709 
710     // Device context.
711     DeviceContext* op_device_context = nullptr;
712 
713     // Control-flow op support.
714     FrameAndIter frame_iter;
715 
716     // Function call support.
717     CallFrameInterface* call_frame = nullptr;
718     FunctionLibraryRuntime* function_library = nullptr;
719     std::function<void(std::function<void()>)>* runner = nullptr;
720     StepStatsCollectorInterface* stats_collector = nullptr;
721     GraphCollector* graph_collector = nullptr;
722 
723     // TensorSliceReaderCache support.
724     checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
725 
726     // Support for forwarding reservations (used by ScopedAllocator).
727     static const int kNeverForward = -2;
728     static const int kNoReservation = -1;
729     // Values in [0,...) represent reservations for the indexed output.
730     const int* forward_from_array = nullptr;
731 
732     // For tracking actively running deferred ops.
733     std::function<void()> inc_num_deferred_ops_function;
734     std::function<void()> dec_num_deferred_ops_function;
735   };
736 
737   // params must outlive the OpKernelContext.
738   explicit OpKernelContext(Params* params);
739   OpKernelContext(Params* params, int num_outputs);
740   ~OpKernelContext();
741 
742   Env* env() const { return params_->device->env(); }
743 
744   int64 step_id() const { return params_->step_id; }
745 
746   bool is_eager() const { return params_->is_eager; }
747 
748   const OpKernel& op_kernel() const { return *params_->op_kernel; }
749 
750   // Input/output signature.
751 
752   int num_inputs() const { return params_->inputs->size(); }
753   DataType input_dtype(int index) const;
754   Status input_dtype(StringPiece name, DataType* dtype) const;
755   MemoryType input_memory_type(int index) const;
756 
757   int num_outputs() const { return outputs_.size(); }
758   DataType expected_output_dtype(int index) const;
759   MemoryType output_memory_type(int index) const;
760 
761   // Input
762 
763   // Returns an immutable input tensor. May only be used for non-Ref
764   // inputs. For Ref inputs use mutable_input below.
765   // REQUIRES: !IsRefType(input_dtype(index))
766   // TODO(mrry): Convert this to return Status.
767   const Tensor& input(int index);
768 
769   // Returns the named immutable input tensor in "tensor", as defined
770   // in the OpDef. May only be used for non-Ref inputs. For Ref inputs
771   // use mutable_input below.
772   // REQUIRES: !IsRefType(input_dtype(index))
773   // REQUIRES: the named input must not be a list.
774   Status input(StringPiece name, const Tensor** tensor);
775 
776   // Returns the named list-valued immutable input in "list", as
777   // defined in the OpDef.  If the named output is not list-valued,
778   // returns a one-element list. May only be used for non-Ref
779   // inputs. For Ref inputs use mutable_input below.
780   // REQUIRES: !IsRefType(input_dtype(index))
781   Status input_list(StringPiece name, OpInputList* list);
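  // For example, a variadic input declared in the OpDef as "values: N * T"
  // (names illustrative) can be traversed as:
  //
  //   OpInputList values;
  //   OP_REQUIRES_OK(ctx, ctx->input_list("values", &values));
  //   for (const Tensor& t : values) {
  //     // ... inspect t.shape(), t.dtype(), etc.
  //   }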
782 
783   // For mutable inputs, use the following together to make sure there
784   // is no concurrent access to mutable_input(), e.g.:
785   // {
786   //   Tensor& t = context->mutable_input(index);
787   //   mutex_lock lock(*context->input_ref_mutex(index));
788   //   // modify the values in t
789   // }
790   // REQUIRES: IsRefType(input_dtype(index))
791   Status input_ref_mutex(StringPiece name, mutex** out_mutex);
792 
793   // Returns a mutable input tensor. Must be used to access Ref
794   // inputs.  REQUIRES: IsRefType(input_dtype(index)). The caller may
795   // modify the values stored in the Tensor buffer, and modifications
796   // will be visible to other Ops reading the same ref tensor. If
797   // !lock_held the input mutex will be acquired before returning the
798   // Tensor.
799   // TODO(mrry): Convert this to return Status.
800   Tensor mutable_input(int index, bool lock_held);
801 
802   // Returns the named mutable input tensor in "tensor", as defined in
803   // the OpDef. Must be used to access Ref inputs. The values stored
804   // in the Tensor buffer may be modified, and modifications will be
805   // visible to other Ops reading the same ref tensor. If !lock_held
806   // the input mutex will be acquired before returning the Tensor.
807   // REQUIRES: the named input must not be a list.
808   // REQUIRES: the named input must be a ref tensor.
809   Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held);
810 
811   // Returns the named list-valued mutable input in "list", as defined
812   // in the OpDef.  If the named input is not list-valued, returns a
813   // one-element list. Must be used to access Ref inputs. The values
814   // stored in the Tensor buffer may be modified, and modifications
815   // will be visible to other Ops reading the same ref tensor.
816   // REQUIRES: the named input must be a ref tensor.
817   Status mutable_input_list(StringPiece name, OpMutableInputList* list);
818 
819   // Replace the corresponding Ref Input to use the storage buffer
820   // used by tensor. If !lock_held the input mutex will be acquired
821   // before returning the Tensor.
822   // REQUIRES: IsRefType(input_dtype(index)).
823   void replace_ref_input(int index, const Tensor& tensor, bool lock_held);
824 
825   // Replace the corresponding named Ref Input to use the storage
826   // buffer used by tensor. If !lock_held the input mutex will be
827   // acquired before returning the Tensor.
828   // REQUIRES: IsRefType(input_dtype(index)).
829   Status replace_ref_input(StringPiece name, const Tensor& tensor,
830                            bool lock_held);
831 
832   // Deletes the Tensor object used as the Ref Input at
833   // input_index. This is not usually necessary and should be used
834   // with caution. If !lock_held the input mutex will be acquired
835   // before returning the Tensor.
836   // REQUIRES: IsRefType(input_dtype(input_index)).
837   void delete_ref_input(int input_index, bool lock_held);
838 
839   // Return true if there is input at the given index. An operator has no
840   // input at index if its tensor is null. This is primarily used by the
841   // merge operator.
842   // TODO(mrry): Convert this to return Status.
843   bool has_input(int index) const;
844 
845   // Returns true if all inputs are the same shape, otherwise sets the
846   // status to a non-OK value and returns false.
847   // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
848   bool ValidateInputsAreSameShape(OpKernel* op);
849 
850   // If non-null, kernels should populate with any partition subgraphs created.
851   GraphCollector* graph_collector() { return params_->graph_collector; }
852 
853   // Input to output forwarding.
854 
855   // Set the output Ref Tensor at output_index to be an alias of the
856   // input Ref Tensor at input_index.
857   // REQUIRES: IsRefType(input_dtype(input_index)).
858   // REQUIRES: IsRefType(output_dtype(output_index)).
859   void forward_ref_input_to_ref_output(int input_index, int output_index);
860 
861   // Returns true when an alias to input[input_index], reshaped to output_shape
862   // and safe to use for in-place computation, has been written to *output.
863   // Returns false if input[input_index] has a refcount greater than one, or if
864   // its type does not match the expected output type of output[output_index],
865   // or the number of elements in input[input_index] does not equal the number
866   // of elements in output_shape.
867   bool forward_input_to_output_with_shape(int input_index, int output_index,
868                                           const TensorShape& output_shape,
869                                           Tensor** output) TF_MUST_USE_RESULT;
870   Status forward_input_to_output_with_shape(StringPiece input_name,
871                                             StringPiece output_name,
872                                             const TensorShape& output_shape,
873                                             Tensor** output) TF_MUST_USE_RESULT;
874 
875   // Returns a pointer to a Tensor aliasing the underlying buffer backing
876   // input[input_index] iff
877   //   * input[input_index] is not a ref,
878   //   * the data type, shape, memory type, and allocator attributes of
879   //     input[input_index] are compatible with those given in dtype, shape,
880   //     memory_type, and attr,
881   //   * refcount on the underlying buffer is one.
882   //   * Either there is no forwarding reservation for either input_index
883   //     or output_index or the specified input is reserved for the specified
884   //     output. More precisely:
885   //
886   //     These cases mean neither input nor output has a reservation:
887   //        forward_from_array = nullptr
888   //     OR (input_index is not in forward_from_array AND
889   //         (output_index == kNoReservation OR
890   //          forward_from_array[output_index] == kNoReservation))
891   //
892   //     This case means that input_index is reserved for output_index:
893   //        forward_from_array[output_index] == input_index
894   //
895   //     This case means the output is reserved to always be allocated,
896   //     never assigned a forwarded input:
897   //        forward_from_array[output_index] == kNeverForward
898   //
899   // Otherwise returns nullptr.
900   // NOTE: For CUDA kernels that read inputs using the __ldg() intrinsic,
901   // forwarding is only safe if there are no reads via __ldg() after writes
902   // to the same address.
903   std::unique_ptr<Tensor> forward_input(
904       int input_index, int output_index, DataType output_dtype,
905       const TensorShape& output_shape, MemoryType output_memory_type,
906       const AllocatorAttributes& output_attr) TF_MUST_USE_RESULT;
907 
908   // Tries to forward one of the inputs given in input_indices to
909   // output[output_index]. If none of the given inputs can be forwarded, calls
910   // allocate_output() to allocate a new output buffer.
911   Status forward_input_or_allocate_output(
912       gtl::ArraySlice<int> candidate_input_indices, int output_index,
913       const TensorShape& output_shape, Tensor** output) TF_MUST_USE_RESULT;
914   Status forward_input_or_allocate_output(
915       gtl::ArraySlice<StringPiece> candidate_input_names,
916       StringPiece output_name, const TensorShape& output_shape,
917       Tensor** output) TF_MUST_USE_RESULT;
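  // For example, an element-wise kernel that may update its first input in
  // place can write (sketch):
  //
  //   Tensor* output = nullptr;
  //   OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
  //                           {0}, 0, ctx->input(0).shape(), &output));
  //   // `output` now either aliases input 0 or is freshly allocated.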
918 
919   // Tries to reuse one of the inputs given in input_indices as a temporary.
920   // If none of the given inputs can be forwarded, calls
921   // allocate_temp() to allocate a new temporary buffer.
922   Status forward_input_or_allocate_temp(
923       gtl::ArraySlice<int> candidate_input_indices, DataType type,
924       const TensorShape& shape, const AllocatorAttributes& allocator_attr,
925       Tensor* out_temp) TF_MUST_USE_RESULT;
926 
927   Status forward_input_or_allocate_temp(
928       gtl::ArraySlice<int> candidate_input_indices, DataType type,
929       const TensorShape& shape, Tensor* out_temp) TF_MUST_USE_RESULT {
930     return forward_input_or_allocate_temp(candidate_input_indices, type, shape,
931                                           AllocatorAttributes(), out_temp);
932   }
933 
934   // Output
935 
936   // Returns the named list-valued output in "list", as defined in the OpDef.
937   // If the named output is not list-valued, returns a one-element list.
938   Status output_list(StringPiece name, OpOutputList* list);
939 
940   // If output_required(index) returns true, the OpKernel's Compute() method
941   // should call allocate_output(index, ...), set_output(index, ...),
942   // set_output_ref(index, ...), or set the status to a non-ok value.
943   // If it returns false, the kernel may produce output, but is not required to.
944   // TODO(mrry): Convert this to return Status, and implement a string
945   // name version.
946   bool output_required(int index) const {
947     return true;  // TODO(josh11b): implement
948   }
949 
950   // Allocation of tensors during kernel execution inside the Compute
951   // method:
952   //
953   // There are three methods to allocate Tensors when an Op kernel
954   // executes.
955   //
956   // 1) allocate_persistent. This is only needed for Tensors that will
957   // be stored by the Op between invocations, and it *must* be used
958   // for those Tensors. The call returns a PersistentTensor, and that
959   // is the only object the Op is allowed to hold on to between
960   // invocations. When the Tensor is needed in a subsequent
961   // invocation, it can be retrieved from the PersistentTensor using
962   // the AccessTensor method. This ensures that the system is made
963   // aware of any use of the tensor's allocated memory, which is
964   // needed for correctness on asynchronous devices such as GPUs.
965   //
966   // 2) allocate_output. This should be used to allocate any tensor
967   // that is going to be used as an output from the Op at the end of
968   // the current execution. The caller indicates which output the
969   // Tensor will be assigned to, and the call returns the
970   // newly-allocated Tensor. The Tensor can subsequently be assigned
971   // to during kernel execution, and will be used as the designated
972   // output when the kernel execution completes.
973   //
974   // 3) allocate_temp. This should be used to allocate any scratch
975   // storage that is needed while the kernel is executing, and will
976   // not be retained by the Op.
977   //
978   // In some cases a Tensor needs to be used as an output even though
979   // it was previously allocated elsewhere. The Tensor may have been
980   // passed as an input, or stored in a PersistentTensor during a
981   // previous kernel execution, or allocated earlier in the kernel
982   // execution at a time when it was not known which output it would
983   // be assigned to. In this case the kernel can use set_output or
984   // set_output_ref to indicate that the tensor should be used as the
985   // designated output. It is legal to use any previously-allocated
986   // Tensor as an argument to set_output or set_output_ref, including
987   // Tensors allocated via allocate_temp. There may be a performance
988   // penalty to using a Tensor that was not allocated using
989   // allocate_output. This is because allocate_output uses the
990   // AllocatorAttributes stored in output_attr_array for the
991   // designated output. In some cases, using the wrong attributes may
992   // cause an extra copy of the Tensor's buffer.
993 
994   // Allocates output for the specified output index with shape.
995   // OpKernelContext retains ownership of the returned pointer. See
996   // comment above.
997   //
998   // If memory allocation fails, returns an error status.
999   //
1000   // REQUIRES: !IsRefType(expected_output_dtype(index))
1001   Status allocate_output(int index, const TensorShape& shape,
1002                          Tensor** tensor) TF_MUST_USE_RESULT;
1003   Status allocate_output(StringPiece name, const TensorShape& shape,
1004                          Tensor** tensor) TF_MUST_USE_RESULT;
1005   // The following methods use the supplied attributes instead of
1006   // those in output_attr_array. The caller is responsible for
1007   // ensuring that the attributes are "compatible" with the
1008   // output_attr_array, e.g. the tensor is allocated on the correct
1009   // device. See comment above.
1010   Status allocate_output(int index, const TensorShape& shape, Tensor** tensor,
1011                          AllocatorAttributes attr) TF_MUST_USE_RESULT;
1012   Status allocate_output(StringPiece name, const TensorShape& shape,
1013                          Tensor** tensor,
1014                          AllocatorAttributes attr) TF_MUST_USE_RESULT;
1015 
1016   // Allocates a temporary Tensor of the specified type and
1017   // shape. Devices such as GPUs that enqueue Ops for lazy execution
1018   // may retain references to the temporary tensors after the Op's
1019   // Compute method has run. See comment above.
1020   Status allocate_temp(DataType type, const TensorShape& shape,
1021                        Tensor* out_temp, AllocatorAttributes allocator_attr,
1022                        const AllocationAttributes& allocation_attr);
1023   Status allocate_temp(DataType type, const TensorShape& shape,
1024                        Tensor* out_temp, AllocatorAttributes allocator_attr) {
1025     return allocate_temp(type, shape, out_temp, allocator_attr,
1026                          AllocationAttributes());
1027   }
1028   Status allocate_temp(DataType type, const TensorShape& shape,
1029                        Tensor* out_temp) {
1030     return allocate_temp(type, shape, out_temp, AllocatorAttributes());
1031   }
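  // For example, a kernel that needs scratch space shaped like its input can
  // obtain it with (sketch):
  //
  //   Tensor scratch;
  //   OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, ctx->input(0).shape(),
  //                                          &scratch));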
1032 
1033   // Allocates a Tensor of the specified type and shape which the Op
1034   // plans to maintain as persistent state. out_persistent holds the
1035   // PersistentTensor which is the object the caller should store. For
1036   // convenience, if out_tensor is non-null then it will be filled in
1037   // with a Tensor* pointing to the newly-allocated tensor which the
1038   // caller can use instead of calling
1039   // out_persistent->AccessTensor. The caller does not own out_tensor
1040   // and should not keep a copy of it. See comment above.
1041   Status allocate_persistent(DataType type, const TensorShape& shape,
1042                              PersistentTensor* out_persistent,
1043                              Tensor** out_tensor, AllocatorAttributes attr);
1044   Status allocate_persistent(DataType type, const TensorShape& shape,
1045                              PersistentTensor* out_persistent,
1046                              Tensor** out_tensor) {
1047     return allocate_persistent(type, shape, out_persistent, out_tensor,
1048                                AllocatorAttributes());
1049   }
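  // For example, a kernel that caches a tensor across invocations might do
  // (sketch; cached_ is a PersistentTensor member of the kernel):
  //
  //   Tensor* cached = nullptr;
  //   OP_REQUIRES_OK(ctx, ctx->allocate_persistent(DT_FLOAT, shape, &cached_,
  //                                                &cached));
  //   // Fill *cached now; in later invocations retrieve it with
  //   //   Tensor* t = cached_.AccessTensor(ctx);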
1050 
1051   // Copies a tensor (allocated by the caller) to the specified output
1052   // index.  REQUIRES: !IsRefType(expected_output_dtype(index))
1053   // REQUIRES: 'tensor' must have the same MemoryType as
1054   // output_memory_types[index]. See comment above.
1055   Status set_output(StringPiece name, const Tensor& tensor);
1056 
1057   // To output a reference.  Caller retains ownership of mu and tensor_for_ref,
1058   // and they must outlive all uses within the step. See comment above.
1059   // REQUIRES: IsRefType(expected_output_dtype(index))
1060   Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref);
1061 
1062   // Returns nullptr if allocate_output() or set_output() have not been called.
1063   Status mutable_output(StringPiece name, Tensor** tensor);
1064 
1065   // Return the DeviceContext that should be used for this Op.
1066   //
1067   // If using the templated function, the type must be a subclass
1068   // of DeviceContext.
1069   //
1070   // Returns nullptr if the device did not provide one.
1071   template <typename T>
1072   T* op_device_context();
1073   DeviceContext* op_device_context() {
1074     DeviceContext* ret = params_->op_device_context;
1075     if (ret == nullptr) {
1076       auto* dev_info = device()->tensorflow_gpu_device_info();
1077       if (dev_info) ret = dev_info->default_context;
1078     }
1079     return ret;
1080   }
1081 
1082   AllocatorAttributes input_alloc_attr(int index) const {
1083     if (params_->input_alloc_attrs == nullptr) {
1084       return AllocatorAttributes();
1085     } else {
1086       DCHECK_GE(index, 0);
1087       DCHECK_LT(index, params_->input_alloc_attrs->size());
1088       return (*params_->input_alloc_attrs)[index];
1089     }
1090   }
1091 
1092   AllocatorAttributes output_alloc_attr(int index) const {
1093     return params_->output_attr_array[index];
1094   }
1095 
1096   gtl::InlinedVector<WrappedAllocator, 4> ConsumeWrappedAllocators() {
1097     gtl::InlinedVector<WrappedAllocator, 4> retrieved;
1098     if (tracking_state_) {
1099       mutex_lock lock(tracking_state_->mu);
1100       retrieved.swap(tracking_state_->wrapped_allocators);
1101     }
1102     return retrieved;
1103   }
1104 
1105   // Communication.
1106   //
1107   // An op kernel communicates with outside environment through
1108   // Rendezvous Send() and Recv().
1109   RendezvousInterface* rendezvous() const { return params_->rendezvous; }
1110   Status create_rendezvous(const int64 step_id, const DeviceMgr* device_mgr,
1111                            Rendezvous** r) const {
1112     return (*params_->create_rendezvous)(step_id, device_mgr, r);
1113   }
1114 
1115   CollectiveExecutor* collective_executor() const {
1116     return params_->collective_executor;
1117   }
1118 
1119   // An op kernel can access the session state it belongs to.
1120   SessionState* session_state() const { return params_->session_state; }
1121 
1122   // Unique identifier of the session it belongs to. Can be empty.
1123   string session_handle() const { return params_->session_handle; }
1124 
1125   // Metadata about the session. Can be nullptr.
1126   const SessionMetadata* session_metadata() const {
1127     return params_->session_metadata;
1128   }
1129 
1130   // An op kernel can access the tensor store of the run it belongs to.
1131   TensorStore* tensor_store() const { return params_->tensor_store; }
1132 
1133   // Function call support.
1134   //
1135   // If this kernel invocation is within a function execution,
1136   // call_frame() returns the call frame for the function call.
1137   CallFrameInterface* call_frame() const { return params_->call_frame; }
1138 
1139   // If not nullptr, the kernel can invoke functions defined in the
1140   // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...).
1141   FunctionLibraryRuntime* function_library() const {
1142     return params_->function_library;
1143   }
1144 
1145   std::function<void(std::function<void()>)>* runner() const {
1146     return params_->runner;
1147   }
1148   StepStatsCollectorInterface* stats_collector() const {
1149     return params_->stats_collector;
1150   }
1151 
1152   // Shared resources accessible to this kernel.
1153   ResourceMgr* resource_manager() const { return params_->resource_manager; }
1154 
1155   checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const {
1156     return params_->slice_reader_cache;
1157   }
1158 
1159   // Execution.
1160   //
1161   // OpKernels can use these eigen devices to carry out their
1162   // numerical computation.
1163   const Eigen::ThreadPoolDevice& eigen_cpu_device() const {
1164     return *device()->eigen_cpu_device();
1165   }
1166   const Eigen::GpuDevice& eigen_gpu_device() const {
1167     return params_->eigen_gpu_device->device();
1168   }
1169 #ifdef TENSORFLOW_USE_SYCL
1170   const Eigen::SyclDevice& eigen_sycl_device() const {
1171     return *device()->eigen_sycl_device();
1172   }
1173 #endif
1174   template <typename EigenDeviceType>
1175   const EigenDeviceType& eigen_device() const;
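  // For example, element-wise work is commonly expressed directly against
  // these devices (sketch, assuming float inputs and outputs):
  //
  //   output->flat<float>().device(ctx->eigen_cpu_device()) =
  //       ctx->input(0).flat<float>() * 2.0f;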
1176 
1177   // Error handling.
1178 
1179   // If expected_inputs == inputs() and expected_outputs == output_types(),
1180   // returns OK, else returns INVALID_ARGUMENT with an error message.
1181   // Recommended for Ops with dynamic signatures, where validation can only
1182   // be performed at runtime.
1183   Status MatchSignature(const DataTypeSlice expected_inputs,
1184                         const DataTypeSlice expected_outputs);
1185 
1186   // An OpKernel should call SetStatus() if Compute() encounters an
1187   // error.
1188   void SetStatus(const Status& status);
1189   const Status& status() const { return status_; }
1190 
1191   // Cancellation.
1192   //
1193   // EXPERIMENTAL. See the implementation in tensorflow::FIFOQueue for an
1194   // example of how to use this API.
1195   CancellationManager* cancellation_manager() const {
1196     return params_->cancellation_manager;
1197   }
1198 
1199   // Other accessors.
1200 
1201   // For control flow.
1202   FrameAndIter frame_iter() const { return params_->frame_iter; }
1203   bool is_input_dead() const { return params_->is_input_dead; }
1204 
1205   // May be used, e.g., to get GPU handles, etc.
1206   // TODO(tucker): Add example usage.
1207   DeviceBase* device() const { return params_->device; }
1208 
1209   // Retrieve list of referenced tensors in out_vector. Once this is
1210   // called, it is not legal to reference any more tensors.  Should
1211   // not be called from Op kernels.
1212   void retrieve_accessed_tensors(TensorReferenceVector* out_vector);
1213 
1214   // Per-step container for use by white-listed internal ops.
1215   ScopedStepContainer* step_container() const {
1216     return params_->step_container;
1217   }
1218 
1219   // Helper routines for the OP_REQUIRES macros
1220   void CtxFailure(const Status& s);
1221   void CtxFailureWithWarning(const Status& s);
1222   void CtxFailure(const char* file, int line, const Status& s);
1223   void CtxFailureWithWarning(const char* file, int line, const Status& s);
1224 
1225   // Unrecommended functions: these have some existing uses but are not
1226   // recommended for new code, and may be removed in a future major
1227   // version release.
1228   //
1229   // Each of the functions below has a variant that returns Status to
1230   // capture error conditions; those variants are strongly preferred.
1231   Tensor* mutable_output(int index);
1232   void set_output(int index, const Tensor& tensor);
1233   mutex* input_ref_mutex(int index);
1234   void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref);
1235   TensorValue release_output(int index);
1236 
1237   bool track_allocations() const { return params_->track_allocations; }
1238 
1239   // Records temp memory allocation. Tensor object is recorded to identify the
1240   // case where temp memory is used as output memory.
1241   void record_temp_memory_allocation(int64 size, const Tensor& t)
1242       LOCKS_EXCLUDED(tracking_state_->stats_mu);
1243 
1244   // Returns the recorded size of temporary memory.
1245   int64 temp_memory_allocated() const LOCKS_EXCLUDED(tracking_state_->stats_mu);
1246 
1247   // Records a persistent memory allocation; size can be negative,
1248   // indicating a deallocation.
1249   void record_persistent_memory_allocation(int64 size, int64 alloc_id = -1)
1250       LOCKS_EXCLUDED(tracking_state_->stats_mu);
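  // For example, a kernel that builds a long-lived lookup structure might
  // report its footprint so that it appears in memory statistics (a minimal
  // sketch; `table` and its MemoryUsed() method are hypothetical):
  //
  //   context->record_persistent_memory_allocation(table->MemoryUsed());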
1251 
1252   // Returns recorded size and ids of persistent memory.
1253   int64 persistent_memory_allocated() const
1254       LOCKS_EXCLUDED(tracking_state_->stats_mu);
1255 
1256   std::vector<int64> persistent_alloc_ids() const
1257       LOCKS_EXCLUDED(tracking_state_->stats_mu);
1258 
1259   // Resets counters for temp and persistent memory and recorded ids.
1260   void clear_recorded_memory() LOCKS_EXCLUDED(tracking_state_->stats_mu);
1261 
1262   bool input_is_ref(int index) const;
1263 
1264   void set_record_memory_consumption(bool v);
1265 
1266   // Used by OpKernel implementations to track actively running deferred ops.
1267   //
1268   // A deferred op is one whose Compute method returns (or whose ComputeAsync
1269   // method invokes the callback) when work is scheduled onto a device. At that
1270   // point, we don't know when the work will actually complete (or if it has
1271   // already completed) on the device. These functions allow the executor to
1272   // track the status of deferred ops and act accordingly.
1273   //
1274   // Deferred OpKernel implementations must use these methods to obtain the
1275   // two functions, and must then call them in pairs: the first before the
1276   // work is scheduled on the device, and the second after it completes.
1277   TF_MUST_USE_RESULT std::function<void()> inc_num_deferred_ops_function() {
1278     DCHECK(params_->op_kernel->is_deferred());
1279     return params_->inc_num_deferred_ops_function
1280                ? params_->inc_num_deferred_ops_function
1281                : []() {};
1282   }
1283   TF_MUST_USE_RESULT std::function<void()> dec_num_deferred_ops_function() {
1284     DCHECK(params_->op_kernel->is_deferred());
1285     return params_->dec_num_deferred_ops_function
1286                ? params_->dec_num_deferred_ops_function
1287                : []() {};
1288   }
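  // A minimal sketch of the expected calling pattern in a deferred kernel
  // (ScheduleOnDevice is a placeholder for however the kernel enqueues its
  // device work; `done` is the AsyncOpKernel callback):
  //
  //   std::function<void()> inc = context->inc_num_deferred_ops_function();
  //   std::function<void()> dec = context->dec_num_deferred_ops_function();
  //   inc();  // before the work is handed to the device
  //   ScheduleOnDevice(..., /*on_complete=*/[dec, done]() {
  //     dec();  // after the device work has finished
  //     done();
  //   });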
1289 
1290   Allocator* get_allocator(AllocatorAttributes attr);
1291 
1292  private:
1293   bool record_memory_consumption_ = false;
1294 
1295   // Internal method to add a tensor's buffer to the list of buffers
1296   // referenced during the execution of the Op, so that GPUs may
1297   // accurately track the memory that may not be reused until the Op
1298   // execution completes.
1299   void record_tensor_reference(const Tensor& tensor);
1300   void really_record_tensor_reference(const Tensor& tensor);
1301 
1302   // Internal common method used when allocating tensor memory
1303   Status allocate_tensor(DataType type, const TensorShape& shape,
1304                          Tensor* out_tensor,
1305                          AllocatorAttributes allocator_attr) {
1306     return allocate_tensor(type, shape, out_tensor, allocator_attr,
1307                            AllocationAttributes());
1308   }
1309 
1310   Status allocate_tensor(DataType type, const TensorShape& shape,
1311                          Tensor* out_tensor, AllocatorAttributes allocator_attr,
1312                          const AllocationAttributes& allocation_attr);
1313 
1314   // Initialize the allocated_scope_ids_ set the first time this method is
1315   // called.
1316   void maybe_initialize_scope_id_set();
1317 
1318   // This is called by PersistentTensor::AccessTensor whenever the
1319   // wrapped tensor is retrieved, to ensure the runtime knows that the
1320   // Tensor is being accessed within an Op. This is necessary for
1321   // memory safety of devices like GPUs that queue Ops for
1322   // asynchronous execution after the Compute() method completes.
1323   friend class PersistentTensor;
1324   void NotifyUseOfPersistentTensor(const Tensor& tensor);
1325 
1326   Status status_;
1327   friend class CollectiveExecutor;  // for access to params_
1328   Params* params_;                  // not owned
1329   gtl::InlinedVector<TensorValue, 4> outputs_;
1330 
1331   // Keep track of calls to ScopedAllocator.
1332   // TODO(ayushd): change to absl::flat_hash_set.
1333   std::unique_ptr<std::unordered_set<int32>> allocated_scope_ids_;
1334 
1335   // The following data members are only used when allocation tracking is
1336   // enabled, memory consumption is being recorded, or tensor access is being
1337   // recorded.
1338   struct TrackingState {
1339     mutable mutex mu;
1340     gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators GUARDED_BY(mu);
1341 
1342     UniqueTensorReferences referenced_tensors GUARDED_BY(mu);
1343 
1344     mutable mutex stats_mu;
1345     int64 temp_memory_allocated GUARDED_BY(stats_mu) = 0;
1346 
1347     int64 persistent_memory_allocated GUARDED_BY(stats_mu) = 0;
1348     gtl::InlinedVector<std::pair<const void*, int64>, 2>
1349         temp_tensor_buffer_and_size GUARDED_BY(stats_mu);
1350     gtl::InlinedVector<int64, 2> persistent_alloc_ids GUARDED_BY(stats_mu);
1351   };
1352   std::unique_ptr<TrackingState> tracking_state_;
1353 
1354   TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext);
1355 };
1356 
1357 template <>
1358 const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const;
1359 
1360 template <>
1361 const Eigen::GpuDevice& OpKernelContext::eigen_device() const;
1362 
1363 #ifdef TENSORFLOW_USE_SYCL
1364 template <>
1365 const Eigen::SyclDevice& OpKernelContext::eigen_device() const;
1366 #endif
1367 
1368 // Register your OpKernel by specifying the Op's name, the device the
1369 // kernel runs on, any type attr constraints for this kernel, any
1370 // host-memory args, and the class to instantiate.  Examples:
1371 //
1372 //  // A kernel that supports all types.
1373 //  REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp);
1374 //
1375 //  // The following are equivalent ways of specifying that the kernel only
1376 //  // works if the "T" type attr is set to DT_FLOAT.
1377 //  REGISTER_KERNEL_BUILDER(
1378 //      Name("Sub").Device(DEVICE_CPU).TypeConstraint<float>("T"),
1379 //      SubOp<float>);
1380 //  // (You would then repeat this for every type supported by "Sub".)
1381 //
1382 //  // This form allows you to specify a list of types as the constraint.
1383 //  REGISTER_KERNEL_BUILDER(Name("Sub")
1384 //                              .Device(DEVICE_CPU)
1385 //                              .TypeConstraint("T", {DT_FLOAT}),
1386 //                          SubOp<float>);
1387 //
1388 //  // A kernel that expects one of the input tensors in host memory.
1389 //  REGISTER_KERNEL_BUILDER(
1390 //      Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp);
1391 //
1392 // See kernel_def_builder for details.
1393 
1394 // Instantiates an OpKernel that has been registered.  Returns nullptr
1395 // if no kernel is registered for that device / input signature
1396 // combination (and sets a NOT_FOUND *status), or if there is an error
1397 // during construction (and sets an INVALID_ARGUMENT *status).  Otherwise,
1398 // the caller takes ownership of the returned pointer.
1399 // EXPECTED USAGE: unique_ptr<OpKernel> op = CreateOpKernel(...);
1400 // REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
1401 std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
1402                                          DeviceBase* device,
1403                                          Allocator* allocator,
1404                                          const NodeDef& def,
1405                                          int graph_def_version, Status* status);
1406 Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
1407                       Allocator* allocator, FunctionLibraryRuntime* flib,
1408                       const NodeDef& def, int graph_def_version,
1409                       OpKernel** kernel);
1410 Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
1411                       Allocator* allocator, FunctionLibraryRuntime* flib,
1412                       ResourceMgr* resource_mgr, const NodeDef& def,
1413                       int graph_def_version, OpKernel** kernel);
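// A minimal sketch of the Status-returning overload, from inside a helper
// that itself returns Status (`device`, `flib` and a fully-specified
// `node_def` are assumed to be available; TF_GRAPH_DEF_VERSION comes from
// public/version.h):
//
//   OpKernel* kernel = nullptr;
//   TF_RETURN_IF_ERROR(CreateOpKernel(
//       DeviceType(DEVICE_CPU), device,
//       device->GetAllocator(AllocatorAttributes()), flib, node_def,
//       TF_GRAPH_DEF_VERSION, &kernel));
//   std::unique_ptr<OpKernel> owned_kernel(kernel);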
1414 
1415 // Returns into 'device_types' the subset of prioritized_types that this
1416 // binary has registered for the given NodeDef.
1417 //
1418 // REQUIRES: * 'device_types' is not nullptr.
1419 //           * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
1420 Status SupportedDeviceTypesForNode(
1421     const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
1422     PrioritizedDeviceTypeVector* device_types,
1423     const DeviceNameUtils::ParsedName* local_address_spec = nullptr);
1424 
1425 // Returns a message with a description of the kernels registered for op
1426 // `op_name`.
1427 string KernelsRegisteredForOp(StringPiece op_name);
1428 
1429 // Call once after Op registration has completed.
1430 Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry);
1431 
1432 // -----------------------------------------------------------------------------
1433 // OpKernel registration implementation follows, please ignore.
1434 
1435 // Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax.
1436 namespace register_kernel {
1437 
1438 class Name : public KernelDefBuilder {
1439  public:
1440   // With selective registration, kernels whose implementation class is not used
1441   // by any kernel are disabled with the SHOULD_REGISTER_OP_KERNEL call in
1442   // REGISTER_KERNEL_BUILDER_UNIQ. However, an unused kernel that shares an
1443   // implementation class with a used kernel would get through that mechanism.
1444   //
1445   // This mechanism stops that registration by changing the name of the kernel
1446   // for the unused op to one that is ignored by
1447   // OpKernelRegistrar::InitInternal.  Note that this method alone is
1448   // not sufficient - the compiler can't evaluate the entire KernelDefBuilder at
1449   // compilation time, so this method doesn't actually reduce code size.
1450   explicit Name(const char* op)
1451       : KernelDefBuilder(SHOULD_REGISTER_OP(op) ? op : "_no_register") {}
1452 };
1453 
1454 namespace system {
1455 
1456 class Name : public KernelDefBuilder {
1457  public:
1458   // For system kernels, we ignore selective registration and
1459   // unconditionally register the kernel.
1460   explicit Name(const char* op) : KernelDefBuilder(op) {}
1461 };
1462 
1463 }  // namespace system
1464 
1465 }  // namespace register_kernel
1466 
1467 #define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
1468   REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)
1469 
1470 #define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
1471   REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)
1472 
1473 #define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)        \
1474   constexpr bool should_register_##ctr##__flag =                      \
1475       SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__);                        \
1476   static ::tensorflow::kernel_factory::OpKernelRegistrar              \
1477       registrar__body__##ctr##__object(                               \
1478           should_register_##ctr##__flag                               \
1479               ? ::tensorflow::register_kernel::kernel_builder.Build() \
1480               : nullptr,                                              \
1481           #__VA_ARGS__,                                               \
1482           [](::tensorflow::OpKernelConstruction* context)             \
1483               -> ::tensorflow::OpKernel* {                            \
1484             return new __VA_ARGS__(context);                          \
1485           });
1486 
1487 // The `REGISTER_SYSTEM_KERNEL_BUILDER()` macro acts as
1488 // `REGISTER_KERNEL_BUILDER()` except that the kernel is registered
1489 // unconditionally even when selective registration is used.
1490 #define REGISTER_SYSTEM_KERNEL_BUILDER(kernel_builder, ...)               \
1491   REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, \
1492                                              __VA_ARGS__)
1493 
1494 #define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
1495   REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)
1496 
1497 #define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)    \
1498   static ::tensorflow::kernel_factory::OpKernelRegistrar                 \
1499       registrar__body__##ctr##__object(                                  \
1500           ::tensorflow::register_kernel::system::kernel_builder.Build(), \
1501           #__VA_ARGS__,                                                  \
1502           [](::tensorflow::OpKernelConstruction* context)                \
1503               -> ::tensorflow::OpKernel* {                               \
1504             return new __VA_ARGS__(context);                             \
1505           });
1506 
1507 // Checks whether a given kernel is registered on device_type.
1508 bool KernelDefAvailable(const DeviceType& device_type, const NodeDef& node_def);
1509 
1510 // If a node with the given node_name, experimental_debug_info, node_op,
1511 // node_device and node_attrs has a corresponding kernel registered on
1512 // device_type, returns OK and fills in the kernel def and kernel_class_name.
1513 // <def> and <kernel_class_name> may be null.
1514 Status FindKernelDef(
1515     const DeviceType& device_type, StringPiece node_name,
1516     bool has_experimental_debug_info,
1517     const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
1518     StringPiece node_op, StringPiece node_device, AttrSlice node_attrs,
1519     const KernelDef** def, string* kernel_class_name);
1520 
1521 // If node_def has a corresponding kernel registered on device_type,
1522 // returns OK and fills in the kernel def and kernel_class_name. <def> and
1523 // <kernel_class_name> may be null.
1524 Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
1525                      const KernelDef** def, string* kernel_class_name);
1526 
1527 // Writes a list of all registered kernels to LOG(INFO), to help users debug
1528 // missing kernel errors.
1529 void LogAllRegisteredKernels();
1530 
1531 // Gets a list of all registered kernels.
1532 KernelList GetAllRegisteredKernels();
1533 
1534 // Gets a list of all registered kernels for which predicate returns true
1535 KernelList GetFilteredRegisteredKernels(
1536     const std::function<bool(const KernelDef&)>& predicate);
1537 
1538 // Gets a list of all registered kernels for a given op
1539 KernelList GetRegisteredKernelsForOp(StringPiece op_name);
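// For example, to log which kernels this binary has registered for GPU
// (a minimal sketch):
//
//   KernelList gpu_kernels = GetFilteredRegisteredKernels(
//       [](const KernelDef& k) { return k.device_type() == "GPU"; });
//   for (const KernelDef& k : gpu_kernels.kernel()) {
//     LOG(INFO) << "GPU kernel registered for op " << k.op();
//   }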
1540 
1541 namespace kernel_factory {
1542 
1543 // OpKernelFactory is responsible for creating OpKernels when TensorFlow needs
1544 // them. You register factories with the TensorFlow core by constructing an
1545 // OpKernelRegistrar and passing the factory as a constructor parameter.
1546 class OpKernelFactory {
1547  public:
1548   virtual OpKernel* Create(OpKernelConstruction* context) = 0;
1549   virtual ~OpKernelFactory() = default;
1550 };
1551 
1552 class OpKernelRegistrar {
1553  public:
1554   // Registers the given kernel factory with TensorFlow. TF will call the
1555   // factory Create() method when it determines that a kernel matching the given
1556   // KernelDef is required.
1557   OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
1558                     std::unique_ptr<OpKernelFactory> factory) {
1559     // Perform the check in the header to allow compile-time optimization
1560     // to a no-op, allowing the linker to remove the kernel symbols.
1561     if (kernel_def != nullptr) {
1562       InitInternal(kernel_def, kernel_class_name, std::move(factory));
1563     }
1564   }
1565 
1566   // Registers the given factory function with TensorFlow. This is equivalent
1567   // to registering a factory whose Create function invokes `create_fn`.
1568   OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
1569                     OpKernel* (*create_fn)(OpKernelConstruction*)) {
1570     // Perform the check in the header to allow compile-time optimization
1571     // to a no-op, allowing the linker to remove the kernel symbols.
1572     if (kernel_def != nullptr) {
1573       InitInternal(kernel_def, kernel_class_name,
1574                    absl::make_unique<PtrOpKernelFactory>(create_fn));
1575     }
1576   }
1577 
1578  private:
1579   struct PtrOpKernelFactory : public OpKernelFactory {
1580     explicit PtrOpKernelFactory(OpKernel* (*create_func)(OpKernelConstruction*))
1581         : create_func_(create_func) {}
1582 
1583     OpKernel* Create(OpKernelConstruction* context) override;
1584 
1585     OpKernel* (*create_func_)(OpKernelConstruction*);
1586   };
1587 
1588   void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
1589                     std::unique_ptr<OpKernelFactory> factory);
1590 };
1591 
1592 }  // namespace kernel_factory
1593 
1594 // -----------------------------------------------------------------------------
1595 // Template and inline method implementations, please ignore
1596 
1597 template <class T>
1598 Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const {
1599   return GetNodeAttr(def(), attr_name, value);
1600 }
1601 
1602 inline DataType OpKernelContext::input_dtype(int index) const {
1603   DCHECK_GE(index, 0);
1604   DCHECK_LT(index, num_inputs());
1605   const TensorValue& value((*params_->inputs)[index]);
1606   return value.dtype();
1607 }
1608 
1609 inline MemoryType OpKernelContext::input_memory_type(int index) const {
1610   DCHECK_GE(index, 0);
1611   DCHECK_LT(index, num_inputs());
1612   return op_kernel().input_memory_types()[index];
1613 }
1614 
1615 inline DataType OpKernelContext::expected_output_dtype(int index) const {
1616   DCHECK_GE(index, 0);
1617   DCHECK_LT(index, num_outputs());
1618   return params_->op_kernel->output_type(index);
1619 }
1620 
1621 inline MemoryType OpKernelContext::output_memory_type(int index) const {
1622   DCHECK_GE(index, 0);
1623   DCHECK_LT(index, num_outputs());
1624   return op_kernel().output_memory_types()[index];
1625 }
1626 
1627 inline bool OpKernelContext::input_is_ref(int index) const {
1628   const TensorValue& value((*params_->inputs)[index]);
1629   return value.is_ref();
1630 }
1631 
1632 inline void OpKernelContext::record_tensor_reference(const Tensor& tensor) {
1633   DCHECK_EQ(params_->device->RequiresRecordingAccessedTensors(),
1634             params_->record_tensor_accesses);
1635   if (params_->record_tensor_accesses) {
1636     really_record_tensor_reference(tensor);
1637   }
1638 }
1639 
1640 inline void OpKernelContext::retrieve_accessed_tensors(
1641     TensorReferenceVector* out_vector) {
1642   if (params_->record_tensor_accesses) {
1643     DCHECK(tracking_state_);
1644     mutex_lock l(tracking_state_->mu);
1645     tracking_state_->referenced_tensors.FreezeAndReturnReferences(out_vector);
1646   }
1647 }
1648 
1649 // no input if tensor == nullptr.
1650 inline bool OpKernelContext::has_input(int index) const {
1651   DCHECK_GE(index, 0);
1652   DCHECK_LT(index, num_inputs());
1653   return (*params_->inputs)[index].tensor != nullptr;
1654 }
1655 
1656 inline mutex* OpKernelContext::input_ref_mutex(int index) {
1657   DCHECK_GE(index, 0);
1658   DCHECK_LT(index, num_inputs());
1659   DCHECK(input_is_ref(index));
1660   return (*params_->inputs)[index].mutex_if_ref;
1661 }
1662 
1663 inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) {
1664   if (t.IsInitialized()) {
1665     record_tensor_reference(t);
1666   }
1667 }
1668 
1669 inline Tensor* OpKernelContext::mutable_output(int index) {
1670   DCHECK_GE(index, 0);
1671   DCHECK_LT(index, num_outputs());
1672   // No need to record_tensor_reference since the output must already
1673   // have been set by a call that did so.
1674   return outputs_[index].tensor;
1675 }
1676 
1677 inline TensorValue OpKernelContext::release_output(int index) {
1678   DCHECK_GE(index, 0);
1679   DCHECK_LT(index, num_outputs());
1680   TensorValue value = outputs_[index];
1681   outputs_[index] = TensorValue();
1682   return value;
1683 }
1684 
1685 inline Status OpKernelContext::forward_input_or_allocate_output(
1686     gtl::ArraySlice<int> candidate_input_indices, int output_index,
1687     const TensorShape& output_shape, Tensor** output) {
1688   for (int input_index : candidate_input_indices) {
1689     if (forward_input_to_output_with_shape(input_index, output_index,
1690                                            output_shape, output)) {
1691       return Status::OK();
1692     }
1693   }
1694   return allocate_output(output_index, output_shape, output);
1695 }
1696 
1697 inline Status OpKernelContext::forward_input_or_allocate_output(
1698     gtl::ArraySlice<StringPiece> candidate_input_names, StringPiece output_name,
1699     const TensorShape& output_shape, Tensor** output) {
1700   for (const StringPiece& input_name : candidate_input_names) {
1701     if (forward_input_to_output_with_shape(input_name, output_name,
1702                                            output_shape, output)
1703             .ok()) {
1704       return Status::OK();
1705     }
1706   }
1707   return allocate_output(output_name, output_shape, output);
1708 }
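// Typical use from a kernel that may operate in place on its first input
// (a minimal sketch, from inside Compute(OpKernelContext* context)):
//
//   Tensor* output = nullptr;
//   OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
//                               {0}, 0, context->input(0).shape(), &output));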
1709 
1710 template <typename T>
1711 T* OpKernelContext::op_device_context() {
1712   static_assert(std::is_base_of<DeviceContext, T>::value,
1713                 "T is not a subclass of DeviceContext");
1714   return static_cast<T*>(op_device_context());
1715 }
1716 
1717 inline const Tensor& OpInputList::operator[](int i) const {
1718   DCHECK_GE(i, 0);
1719   DCHECK_LT(i, stop_ - start_);
1720   return ctx_->input(start_ + i);
1721 }
1722 
1723 inline mutex* OpMutableInputList::ref_mutex(int i) {
1724   DCHECK_GE(i, 0);
1725   DCHECK_LT(i, stop_ - start_);
1726   return ctx_->input_ref_mutex(start_ + i);
1727 }
1728 
1729 inline Tensor OpMutableInputList::at(int i, bool lock_held) {
1730   DCHECK_GE(i, 0);
1731   DCHECK_LT(i, stop_ - start_);
1732   return ctx_->mutable_input(start_ + i, lock_held);
1733 }
1734 
1735 inline Tensor* OpOutputList::operator[](int i) {
1736   DCHECK_GE(i, 0);
1737   DCHECK_LT(i, stop_ - start_);
1738   return ctx_->mutable_output(start_ + i);
1739 }
1740 
1741 inline bool OpOutputList::required(int i) const {
1742   DCHECK_GE(i, 0);
1743   DCHECK_LT(i, stop_ - start_);
1744   return ctx_->output_required(start_ + i);
1745 }
1746 
1747 inline DataType OpOutputList::expected_output_dtype(int i) const {
1748   DCHECK_GE(i, 0);
1749   DCHECK_LT(i, stop_ - start_);
1750   return ctx_->expected_output_dtype(start_ + i);
1751 }
1752 
1753 inline Status OpOutputList::allocate(int i, const TensorShape& shape,
1754                                      Tensor** output) {
1755   DCHECK_GE(i, 0);
1756   DCHECK_LT(i, stop_ - start_);
1757   return ctx_->allocate_output(start_ + i, shape, output);
1758 }
1759 
1760 inline void OpOutputList::set(int i, const Tensor& tensor) {
1761   DCHECK_GE(i, 0);
1762   DCHECK_LT(i, stop_ - start_);
1763   ctx_->set_output(start_ + i, tensor);
1764 }
1765 
1766 inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
1767   DCHECK_GE(i, 0);
1768   DCHECK_LT(i, stop_ - start_);
1769   ctx_->set_output_ref(start_ + i, mu, tensor_for_ref);
1770 }
1771 
1772 // Convenience macros for asserting and handling exceptional conditions.
1773 // Analogous to the CHECK* macros provided by logging.h.
1774 //
1775 // Example use:
1776 // void Compute(OpKernelContext* context) {
1777 //   OP_REQUIRES(context, context->num_inputs() == 2,
1778 //               errors::InvalidArgument("FooOp requires 2 arguments"));
1779 //   ...
1780 //   Status status = SomeUncertainMethod();
1781 //   OP_REQUIRES_OK(context, status);
1782 //   ...
1783 // }
1784 
1785 // Generate a fatal error if OP_REQUIRES or OP_REQUIRES_OK are used in
1786 // AsyncOpKernel implementations. If these macros are used and the condition
1787 // does not hold, the `done` callback will never be called and the system will
1788 // deadlock, so a crash failure is preferable. Since the OP_REQUIRES[_OK] macros
1789 // are legal to use in AsyncOpKernel constructors, we use overload resolution
1790 // to distinguish between OpKernelConstruction* and OpKernelContext* context
1791 // types.
1792 class XlaOpKernelContext;
1793 inline void CheckNotInComputeAsync(XlaOpKernelContext*, const char*) {}
1794 inline void CheckNotInComputeAsync(OpKernelConstruction*, const char*) {}
1795 void CheckNotInComputeAsync(OpKernelContext* ctx,
1796                             const char* correct_macro_name);
1797 
1798 #define OP_REQUIRES(CTX, EXP, STATUS)                     \
1799   do {                                                    \
1800     if (!TF_PREDICT_TRUE(EXP)) {                          \
1801       CheckNotInComputeAsync((CTX), "OP_REQUIRES_ASYNC"); \
1802       (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS));    \
1803       return;                                             \
1804     }                                                     \
1805   } while (0)
1806 
1807 #define OP_REQUIRES_OK(CTX, ...)                             \
1808   do {                                                       \
1809     ::tensorflow::Status _s(__VA_ARGS__);                    \
1810     if (!TF_PREDICT_TRUE(_s.ok())) {                         \
1811       CheckNotInComputeAsync((CTX), "OP_REQUIRES_OK_ASYNC"); \
1812       (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s);  \
1813       return;                                                \
1814     }                                                        \
1815   } while (0)
1816 
1817 #define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK)  \
1818   do {                                                 \
1819     if (!TF_PREDICT_TRUE(EXP)) {                       \
1820       (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \
1821       (CALLBACK)();                                    \
1822       return;                                          \
1823     }                                                  \
1824   } while (0)
1825 
1826 #define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK)         \
1827   do {                                                      \
1828     ::tensorflow::Status _s(STATUS);                        \
1829     if (!TF_PREDICT_TRUE(_s.ok())) {                        \
1830       (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
1831       (CALLBACK)();                                         \
1832       return;                                               \
1833     }                                                       \
1834   } while (0)
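// Example use in an AsyncOpKernel (a sketch mirroring the synchronous example
// above; SomeUncertainMethod is the same placeholder):
// void ComputeAsync(OpKernelContext* context, DoneCallback done) {
//   OP_REQUIRES_ASYNC(context, context->num_inputs() == 2,
//                     errors::InvalidArgument("FooAsyncOp requires 2 inputs"),
//                     done);
//   ...
//   OP_REQUIRES_OK_ASYNC(context, SomeUncertainMethod(), done);
//   ...
//   done();
// }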
1835 
1836 }  // namespace tensorflow
1837 
1838 #endif  // TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
1839