/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
#define TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_

#include <atomic>
#include <functional>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"  // TODO(b/62899350): Remove
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"  // TODO(b/62899350): Remove
#include "tensorflow/core/framework/tracking_allocator.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/unique_tensor_references.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/manual_constructor.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"

namespace Eigen {
struct ThreadPoolDevice;
struct GpuDevice;
struct SyclDevice;
}  // end namespace Eigen

namespace tensorflow {

namespace checkpoint {
class TensorSliceReaderCacheWrapper;
}  // namespace checkpoint

class AsyncOpKernel;
class CallFrameInterface;
class DeviceMgr;
class FunctionLibraryRuntime;
class OpKernelConstruction;  // declared below
class OpKernelContext;       // declared below
class OpRegistryInterface;
class ResourceMgr;
class ScopedStepContainer;
class CollectiveExecutor;
class StepStatsCollectorInterface;

class OpKernel {
 public:
  // OpKernel won't be instantiated by the scheduler, so you may perform
  // expensive initialization in the descendant's constructor.
  explicit OpKernel(OpKernelConstruction* context);

  // Specialized constructor that enables the descendant to provide a different
  // `NodeDef` value. For example, this constructor can be used to provide a
  // stripped-down `NodeDef` that does not contain the full set of attrs (such
  // as tensor values) if the descendant stores them in a different form.
  explicit OpKernel(OpKernelConstruction* context,
                    std::unique_ptr<const NodeDef> node_def,
                    bool is_deferred = false);

  // Specialized constructor that allows a kernel implementation to mark itself
  // as a "deferred" op. If true, the executor will provide access to the
  // `OpKernelContext::inc_num_deferred_ops_function()` and
  // `OpKernelContext::dec_num_deferred_ops_function()` methods at run-time.
  explicit OpKernel(OpKernelConstruction* context, bool is_deferred);

  virtual ~OpKernel();

  // An OpKernel's computation can be either synchronous or
  // asynchronous. All OpKernel Compute() methods must be thread-safe as they
  // may be called concurrently (e.g. by multiple executions of the same graph
  // concurrently).
  //
  // Most OpKernels should compute synchronously. They should
  // subclass OpKernel and override the Compute() method and have it
  // return after completing the supplied work.
  //
  // A synchronous OpKernel *MUST NOT* block the calling thread on a
  // synchronization mechanism (condition variable, Notification, etc.) that
  // will be unblocked by the execution of another OpKernel. Execution may
  // deadlock in that case, because the executor may use a bounded number of
  // threads.
  //
  // If an OpKernel must block on the execution of another OpKernel (e.g. a
  // RecvOp, or a DequeueOp), the implementation *MUST* subclass AsyncOpKernel,
  // and override `AsyncOpKernel::ComputeAsync()`. In addition, because the
  // unblocking kernel may never run (due to an error or cancellation), in most
  // cases the AsyncOpKernel should implement cancellation support via
  // `ctx->cancellation_manager()`.
  //
  // In both cases, implementations of Compute() and ComputeAsync()
  // get inputs and write outputs through the given OpKernelContext
  // and return a status via context->SetStatus(). They must be
  // thread-safe.

  // Synchronous compute.
  //
  // "context" is guaranteed to be alive until Compute() returns.
  virtual void Compute(OpKernelContext* context) = 0;
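  // As an illustration only, a minimal synchronous kernel might look like the
  // following sketch. The op name "MyHypotheticalIdentity" is an assumption
  // made for the example (a matching REGISTER_OP is assumed to exist
  // elsewhere); it is not an op defined by this header.
  //
  //   class MyHypotheticalIdentityOp : public OpKernel {
  //    public:
  //     explicit MyHypotheticalIdentityOp(OpKernelConstruction* context)
  //         : OpKernel(context) {}
  //     void Compute(OpKernelContext* context) override {
  //       // Forward input 0 to output 0 unchanged.
  //       context->set_output(0, context->input(0));
  //     }
  //   };
  //   REGISTER_KERNEL_BUILDER(
  //       Name("MyHypotheticalIdentity").Device(DEVICE_CPU),
  //       MyHypotheticalIdentityOp);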

  // Returns nullptr iff this op kernel is synchronous.
  virtual AsyncOpKernel* AsAsync() { return nullptr; }
  virtual const AsyncOpKernel* AsAsync() const { return nullptr; }

  // Initial time (in CPU cycles) we expect an operation to take. Used to
  // determine whether an operation should be placed in a threadpool.
  // Operations start out "expensive".
  static const uint64 kInitialCostEstimateCycles = 100 * 1000 * 1000;
  static const uint64 kOpIsExpensiveThresholdCycles = 5000;
  static const uint64 kCostDecay = 10;

  // Returns true iff this op kernel is considered "expensive". The
  // runtime may use this flag to optimize graph execution, for example
  // to "inline" inexpensive kernels.
  virtual bool IsExpensive() {
    return expensive_ && (cost_estimate_.load(std::memory_order_relaxed) >
                          kOpIsExpensiveThresholdCycles);
  }

  // Returns a pointer to the tensor stored inside constant ops.
  virtual const Tensor* const_tensor() const { return nullptr; }

  // Updates the dynamic cost estimate, which is used to determine whether this
  // op is expensive. The new cost estimate is a weighted average of the old
  // cost estimate and the latest cost.
  void UpdateCostEstimate(uint64 elapsed_cycles) {
    // N.B. Updates to `cost_estimate_` are atomic but unlocked. Simultaneous
    // updates may result in one or more updates being ignored. This does not
    // affect correctness but may slow down the update frequency.
    cost_estimate_.store(
        (kCostDecay - 1) * cost_estimate_.load(std::memory_order_relaxed) /
                kCostDecay +
            (elapsed_cycles / kCostDecay),
        std::memory_order_relaxed);
  }
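  // Note: with kCostDecay == 10, this update is an exponential moving
  // average, approximately new_estimate = 0.9 * old_estimate +
  // 0.1 * elapsed_cycles (computed with integer division).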

  // Accessors.
  const NodeDef& def() const { return *def_; }
  const string& name() const;              // Same as def().name()
  absl::string_view name_view() const { return name_view_; }
  const string& type_string() const;       // Same as def().op()
  absl::string_view type_string_view() const { return type_string_view_; }
  const string& requested_device() const;  // Same as def().device()

  int num_inputs() const { return input_types_.size(); }
  DataType input_type(int i) const { return input_types_[i]; }
  const DataTypeVector& input_types() const { return input_types_; }
  const MemoryTypeVector& input_memory_types() const {
    return input_memory_types_;
  }
  const string& requested_input(int i) const;  // Same as def().input(i)

  int num_outputs() const { return output_types_.size(); }
  DataType output_type(int o) const { return output_types_[o]; }
  const DataTypeVector& output_types() const { return output_types_; }
  const MemoryTypeVector& output_memory_types() const {
    return output_memory_types_;
  }

  Status InputRange(StringPiece input_name, int* start, int* stop) const;
  Status OutputRange(StringPiece output_name, int* start, int* stop) const;

  // Returns `true` if and only if this kernel uses deferred execution.
  bool is_deferred() const { return is_deferred_; }

  // Returns a trace string for the current computation; the op name/type and
  // input tensor shape/dtype are encoded for profiler cost analysis. Most
  // OpKernels should use the default implementation. Override this function
  // to add OpKernel-specific attributes that are necessary for cost analysis.
  virtual string TraceString(OpKernelContext* ctx, bool verbose);

 private:
  const std::unique_ptr<const NodeDef> def_;
  const DataTypeVector input_types_;
  const MemoryTypeVector input_memory_types_;
  const DataTypeVector output_types_;
  const MemoryTypeVector output_memory_types_;
  NameRangeMap input_name_map_;
  NameRangeMap output_name_map_;
  const absl::string_view name_view_;
  const absl::string_view type_string_view_;
  const int graph_def_version_;
  const bool is_deferred_;
  bool expensive_;
  std::atomic_uint_fast64_t cost_estimate_;

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernel);
};

class AsyncOpKernel : public OpKernel {
 public:
  using OpKernel::OpKernel;  // Lift OpKernel constructors.

  // Asynchronous compute.
  //
  // Implementations of ComputeAsync() must ensure that `done` is (eventually)
  // called exactly once to signal the completion of the computation. The
  // implementation of ComputeAsync() must not block on the execution of
  // another OpKernel. `done` may be called by the current thread, or by
  // another thread. `context` is guaranteed to stay alive until the `done`
  // callback starts.
  //
  // Since it is possible that the unblocking kernel may never run (due to an
  // error or cancellation), in most cases the AsyncOpKernel should implement
  // cancellation support via `context->cancellation_manager()`.
  //
  // WARNING: As soon as the `done` callback starts, `context` and `this` may
  // be deleted. No code depending on these objects should execute after the
  // call to `done`.
  typedef std::function<void()> DoneCallback;
  virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0;

  AsyncOpKernel* AsAsync() override { return this; }
  const AsyncOpKernel* AsAsync() const override { return this; }

  void Compute(OpKernelContext* context) override;
};
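
// As an illustration only, an asynchronous kernel might be structured roughly
// as in the following sketch. The kernel name and the `ScheduleOnDevice`
// helper are hypothetical, introduced just for this example:
//
//   class MyHypotheticalAsyncOp : public AsyncOpKernel {
//    public:
//     using AsyncOpKernel::AsyncOpKernel;
//     void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
//       // Hypothetical helper that enqueues work elsewhere and invokes the
//       // given callback once that work has finished.
//       ScheduleOnDevice(context, [done]() {
//         done();  // Must be called exactly once, after the work completes.
//       });
//     }
//   };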

// Wraps a tensor that is held by an Op across calls to Compute(). For memory
// safety when using asynchronous devices like GPUs, the system must be
// notified when a Tensor is used inside an Op execution. The wrapper ensures
// that all uses of the Tensor are tracked, because in order to retrieve the
// Tensor the caller must use AccessTensor which notifies the context.
class PersistentTensor {
 public:
  PersistentTensor() {}
  explicit PersistentTensor(const Tensor& tensor) : tensor_(tensor) {}

  // Caller does not own the returned Tensor*.
  Tensor* AccessTensor(OpKernelConstruction* context);
  // Caller does not own the returned Tensor*.
  Tensor* AccessTensor(OpKernelContext* context);

  // The check for initialization does not need to access the
  // underlying tensor buffer.
  bool IsInitialized() const { return tensor_.IsInitialized(); }

  int64 NumElements() const { return tensor_.NumElements(); }

  int64 AllocatedBytes() const { return tensor_.AllocatedBytes(); }

 private:
  Tensor tensor_;
};

class OpKernelConstruction {
 public:
  OpKernelConstruction(DeviceType device_type, DeviceBase* device,
                       Allocator* allocator, const NodeDef* node_def,
                       const OpDef* op_def, FunctionLibraryRuntime* flib,
                       ResourceMgr* resource_mgr,
                       const DataTypeSlice& input_types,
                       const MemoryTypeSlice& input_memory_types,
                       const DataTypeSlice& output_types,
                       const MemoryTypeSlice& output_memory_types,
                       int graph_def_version, Status* status);

  Env* env() const { return device_->env(); }

  // Allocation of tensors during kernel construction:
  //
  // It is legal to temporarily allocate scratch tensor storage during
  // Op kernel construction. Scratch tensors should be allocated using
  // allocate_temp below. Some kernels need to keep tensors in between
  // invocations. If such a Tensor is allocated during kernel
  // construction this must be done using allocate_persistent, and the
  // Op may only store the returned PersistentTensor object. When the
  // Tensor is needed in a subsequent invocation, it can be retrieved
  // from the PersistentTensor using the AccessTensor method. This
  // ensures that the system is made aware of any use of the tensor's
  // allocated memory, which is needed for correctness on asynchronous
  // devices such as GPUs.

  // Allocates a temporary Tensor of the specified type and shape. The
  // Tensor must not be used after kernel construction is
  // complete. See comment above.
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp);
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp, AllocatorAttributes allocator_attr);

  // Allocates a Tensor of the specified type and shape which the Op
  // plans to maintain as persistent state. out_persistent holds the
  // PersistentTensor which is the object the caller should store. For
  // convenience, if out_tensor is non-null then it will be filled in
  // with a Tensor* pointing to the newly-allocated tensor which the
  // caller can use instead of calling
  // out_persistent->AccessTensor. The caller does not own out_tensor
  // and should not keep a copy of it. See comment above.
  Status allocate_persistent(DataType type, const TensorShape& shape,
                             PersistentTensor* out_persistent,
                             Tensor** out_tensor);
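  // As an illustration only, a kernel that keeps a persistent scratch tensor
  // might allocate it in its constructor roughly as in the following sketch
  // (the names scratch_ and scratch are assumptions for the example):
  //
  //   Tensor* scratch = nullptr;
  //   OP_REQUIRES_OK(context,
  //                  context->allocate_persistent(DT_FLOAT, TensorShape({128}),
  //                                               &scratch_, &scratch));
  //
  // The PersistentTensor scratch_ is stored as a kernel member, and the tensor
  // is retrieved later in Compute() via scratch_.AccessTensor(context).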

  // User-supplied configuration of this operation.
  const NodeDef& def() const { return *def_; }

  // For inspecting the inputs to this operation.
  int num_inputs() const { return input_types_.size(); }
  DataType input_type(int i) const { return input_types_[i]; }
  const DataTypeSlice& input_types() const { return input_types_; }
  const MemoryTypeSlice& input_memory_types() const {
    return input_memory_types_;
  }

  // For inspecting the outputs expected from this operation.
  int num_outputs() const { return output_types_.size(); }
  DataType output_type(int i) const { return output_types_[i]; }
  const DataTypeSlice& output_types() const { return output_types_; }
  const MemoryTypeSlice& output_memory_types() const {
    return output_memory_types_;
  }

  // If expected_inputs == input_types() and expected_outputs ==
  // output_types(), returns OK, else returns INVALID_ARGUMENT with an error
  // message. Recommended for Ops with dynamic signatures.
  Status MatchSignature(const DataTypeSlice expected_inputs,
                        const DataTypeSlice expected_outputs);

  // For recording configuration errors during construction.
  void SetStatus(const Status& status);
  const Status& status() const { return *status_; }

  // Look up the attr with name attr_name and set *value to its value. If no
  // attr with attr_name is found in def(), or the attr does not have
  // a matching type, a non-ok status will be returned.
  template <class T>
  Status GetAttr(StringPiece attr_name, T* value) const;

  // Return true if the attr_name is defined in def().
  bool HasAttr(StringPiece attr_name) const;
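  // As an illustration only, a kernel constructor typically reads its attrs
  // roughly as in the following sketch (the attr name "axis" and the member
  // axis_ are assumptions for the example):
  //
  //   OP_REQUIRES_OK(context, context->GetAttr("axis", &axis_));
  //   OP_REQUIRES(context, axis_ >= 0,
  //               errors::InvalidArgument("axis must be non-negative"));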

  // Return the device type.
  const DeviceType& device_type() const { return device_type_; }

  // If not nullptr, the kernel can instantiate functions defined in
  // the library. E.g.,
  // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...).
  FunctionLibraryRuntime* function_library() const { return flib_; }

  // Shared resources accessible to this kernel.
  ResourceMgr* resource_manager() const { return resource_mgr_; }

  // The GraphDef version whose behavior we should follow.
  int graph_def_version() const { return graph_def_version_; }

  // Helper routines for the OP_REQUIRES macros
  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  // Unrecommended functions: these are functions that have some
  // current uses but are not recommended for use, and may go away at
  // some future major version release.

  // May be used, e.g., to get GPU handles, etc.
  //
  // Currently only used to call MakeTensorFromProto() for
  // implementing ConstantOp for every device. See comments
  // on Device::MakeTensorFromProto for longer-term replacement
  // ideas.
  DeviceBase* device() const { return device_; }

 private:
  const DeviceType device_type_;
  DeviceBase* const device_;
  Allocator* allocator_;
  const NodeDef* def_;
  const OpDef* op_def_;
  FunctionLibraryRuntime* flib_;
  ResourceMgr* const resource_mgr_;
  DataTypeSlice input_types_;
  MemoryTypeSlice input_memory_types_;
  DataTypeSlice output_types_;
  MemoryTypeSlice output_memory_types_;
  const int graph_def_version_;
  Status* status_;

  // Allow access to op_def_ from OpKernel, but not from subclasses.
  // TODO(irving): Remove protos from this header entirely.
  friend class OpKernel;

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction);
};

// TODO(mrry): Consider converting to a random_access_iterator, and upgrading
// tensorflow::gtl::iterator_range to make the below container classes
// unnecessary.
template <typename ListType, typename ElementType>
class OpArgIterator {
 public:
  using iterator_category = std::forward_iterator_tag;
  using value_type = ElementType;
  using pointer = ElementType*;
  using const_pointer = const ElementType*;
  using reference = ElementType&;
  using const_reference = const ElementType&;
  using difference_type = ptrdiff_t;

  OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {}

  bool operator==(const OpArgIterator& rhs) {
    DCHECK(list_ == rhs.list_);
    return i_ == rhs.i_;
  }

  bool operator!=(const OpArgIterator& rhs) {
    DCHECK(list_ == rhs.list_);
    return i_ != rhs.i_;
  }

  OpArgIterator operator++() {  // prefix ++it
    ++i_;
    return *this;
  }

  OpArgIterator operator++(int) {  // postfix it++
    OpArgIterator old_value = *this;
    ++i_;
    return old_value;
  }

  reference operator*() { return (*list_)[i_]; }
  pointer operator->() { return &(*list_)[i_]; }

  const_reference operator*() const { return (*list_)[i_]; }
  const_pointer operator->() const { return &(*list_)[i_]; }

 private:
  const ListType* const list_;
  int i_;
};

// Utility class for representing a list of immutable input tensors
// that are passed to the op as a single named argument.
class OpInputList {
 public:
  typedef OpArgIterator<OpInputList, const Tensor> Iterator;
  OpInputList() : ctx_(nullptr), start_(0), stop_(0) {}
  OpInputList(OpKernelContext* ctx, int start, int stop)
      : ctx_(ctx), start_(start), stop_(stop) {}
  OpInputList& operator=(const OpInputList& other) = default;
  const Tensor& operator[](int i) const;
  int size() const { return stop_ - start_; }
  Iterator begin() const { return Iterator(this, 0); }
  Iterator end() const { return Iterator(this, size()); }

 private:
  OpKernelContext* ctx_;  // not owned
  int start_;
  int stop_;
};

// Utility class for representing a list of mutable ("ref") input tensors
// that are passed to the op as a single named argument.
class OpMutableInputList {
 public:
  typedef OpArgIterator<OpMutableInputList, Tensor*> Iterator;
  OpMutableInputList(OpKernelContext* ctx, int start, int stop)
      : ctx_(ctx), start_(start), stop_(stop) {}
  OpMutableInputList() : ctx_(nullptr), start_(0), stop_(0) {}
  OpMutableInputList& operator=(const OpMutableInputList& other) = default;
  Tensor at(int i, bool lock_held);
  mutex* ref_mutex(int i);
  int size() const { return stop_ - start_; }
  Iterator begin() const { return Iterator(this, 0); }
  Iterator end() const { return Iterator(this, size()); }

 private:
  OpKernelContext* ctx_;  // not owned
  int start_;
  int stop_;
};

// Utility class for representing a list of output tensors that are
// grouped as a single named output.
class OpOutputList {
 public:
  typedef OpArgIterator<OpOutputList, const Tensor*> Iterator;
  OpOutputList() : ctx_(nullptr), start_(0), stop_(0) {}
  OpOutputList(OpKernelContext* ctx, int start, int stop)
      : ctx_(ctx), start_(start), stop_(stop) {}
  OpOutputList& operator=(const OpOutputList& other) = default;
  Tensor* operator[](int i);
  bool required(int i) const;
  DataType expected_output_dtype(int i) const;
  Status allocate(int i, const TensorShape& shape, Tensor** output);
  void set(int i, const Tensor& tensor);
  void set_ref(int i, mutex* mu, Tensor* tensor_for_ref);
  int size() const { return stop_ - start_; }
  Iterator begin() const { return Iterator(this, 0); }
  Iterator end() const { return Iterator(this, size()); }

 private:
  OpKernelContext* ctx_;  // not owned
  int start_;
  int stop_;
};

// Holds a tensor or tensor reference. For tensor references, we need
// a mutex to prevent concurrent access to the tensor.
struct TensorValue {
  TensorValue() : mutex_if_ref(nullptr), tensor(nullptr) {}
  explicit TensorValue(Tensor* t) : mutex_if_ref(nullptr), tensor(t) {}
  TensorValue(mutex* mu, Tensor* t) : mutex_if_ref(mu), tensor(t) {}
  Tensor* operator->() const { return tensor; }
  bool is_ref() const { return mutex_if_ref != nullptr; }

  // Return the dtype of the Tensor. For references, return the corresponding
  // ref type.
  DataType dtype() const {
    if (is_ref()) {
      return MakeRefType(tensor->dtype());
    } else {
      return tensor->dtype();
    }
  }

  // Return the dtype of the Tensor. For references, return the corresponding
  // ref type. This variation on dtype() acquires the lock for references.
  //
  // TODO(b/133843385): Disallow dtype modifications
  DataType dtype_safe() const {
    if (is_ref()) {
      tf_shared_lock ml(*mutex_if_ref);
      return MakeRefType(tensor->dtype());
    } else {
      return tensor->dtype();
    }
  }

  mutex* mutex_if_ref;  // nullptr if not a ref, != nullptr if a ref
  Tensor* tensor;
};

// Used to store partitioned graphs from function-calling ops.
struct GraphCollector {
  mutex mu;
  std::vector<GraphDef> partitioned_graphs GUARDED_BY(mu);
  GraphDef raw_graph GUARDED_BY(mu);
  GraphDef optimized_graph GUARDED_BY(mu);

  bool dirty GUARDED_BY(mu);

  GraphCollector() : dirty(false) {}

  void CollectRawGraph(const GraphDef& graph) {
    mutex_lock ml(mu);
    raw_graph.MergeFrom(graph);
    dirty = true;
  }

  void CollectOptimizedGraph(const GraphDef& graph) {
    mutex_lock ml(mu);
    optimized_graph.MergeFrom(graph);
    dirty = true;
  }

  void CollectPartitionedGraph(const GraphDef& graph) {
    mutex_lock ml(mu);
    partitioned_graphs.push_back(graph);
    dirty = true;
  }

  void ClearGraphs() EXCLUSIVE_LOCKS_REQUIRED(mu) {
    raw_graph.Clear();
    optimized_graph.Clear();
    partitioned_graphs.clear();
    dirty = false;
  }

  bool HasUpdatedGraphs() {
    mutex_lock ml(mu);
    return dirty;
  }
};

class OpKernelContext {
 public:
  // The first element of a WrappedAllocator is a "base" Allocator and
  // the second element is that Allocator wrapped by a
  // TrackingAllocator.
  typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator;

  // TODO(zhifengc): Do some cleanup of Params.
  // The Params struct is passed in to initialize an OpKernelContext,
  // and must outlive the OpKernelContext.
  struct Params {
    ~Params() { delete eigen_gpu_device; }

    // The step being executed.
    int64 step_id = 0;

    // True if the op is created by eager runtime.
    bool is_eager = false;

    // The op kernel being computed.
    OpKernel* op_kernel = nullptr;

    // The device on which the kernel is running.
    DeviceBase* device = nullptr;

    // The Eigen GPU device wrapper, which may include a per-op
    // wrapped allocator. The concrete type of this object depends on
    // the type of this->device, so eigen_gpu_device can't be an
    // inline member and must be heap allocated. However, we don't
    // want to allocate a new eigen_gpu_device for every Op that is
    // executed. Instead this member is allocated on first use using
    // ensure_eigen_gpu_device, and then if the Params structure is
    // re-used for subsequent Ops, the eigen_gpu_device is
    // ReInitialized in the OpKernelContext constructor. Unlike the
    // other pointers in Params, this one is owned by Params.
    PerOpGpuDevice* eigen_gpu_device = nullptr;

    inline void ensure_eigen_gpu_device() {
      DCHECK(device);
      if (nullptr == eigen_gpu_device) {
        // Surprisingly, MakeGpuDevice will return nullptr if the
        // device is not a GPU device. This is ok, since those devices
        // will never use eigen_gpu_device. It seems better to have
        // ensure_eigen_gpu_device fall through and regenerate the
        // nullptr every time an OpKernelContext is instantiated, than
        // to do an unnecessary allocation of a dummy eigen GPU
        // device for CPU device Ops.
        eigen_gpu_device = device->MakeGpuDevice();
      }
    }

    bool track_allocations = false;
    bool log_memory = false;
    bool record_tensor_accesses = false;

    // Array indexed by output number for this node
    const AllocatorAttributes* output_attr_array = nullptr;

    // Shared resources accessible by this op kernel invocation.
    ResourceMgr* resource_manager = nullptr;

    // Per-step resources accessible by this op kernel invocation should be
    // stored in this container.
    ScopedStepContainer* step_container = nullptr;

    // Mechanism used by this op kernel invocation to communicate with
    // computations running on other devices.
    RendezvousInterface* rendezvous = nullptr;
    const std::function<Status(const int64, const DeviceMgr*, Rendezvous** r)>*
        create_rendezvous;

    // Mechanism for executing a collective op that needs to coordinate
    // with parallel instances running on other devices.
    CollectiveExecutor* collective_executor = nullptr;

    // The session state for this op.
    SessionState* session_state = nullptr;

    // Unique session identifier. Can be empty.
    string session_handle;

    // Metadata about the session. Can be nullptr.
    const SessionMetadata* session_metadata = nullptr;

    // The tensor store for this op.
    TensorStore* tensor_store = nullptr;

    // Mechanism used by this op kernel invocation to register a callback
    // for its cancellation.
    CancellationManager* cancellation_manager = nullptr;

    // Inputs to this op kernel.
    const gtl::InlinedVector<TensorValue, 4>* inputs = nullptr;
    bool is_input_dead = false;

    const gtl::InlinedVector<AllocatorAttributes, 4>* input_alloc_attrs =
        nullptr;

    // Device context.
    DeviceContext* op_device_context = nullptr;

    // Control-flow op supports.
    FrameAndIter frame_iter;

    // Function call supports.
    CallFrameInterface* call_frame = nullptr;
    FunctionLibraryRuntime* function_library = nullptr;
    std::function<void(std::function<void()>)>* runner = nullptr;
    StepStatsCollectorInterface* stats_collector = nullptr;
    GraphCollector* graph_collector = nullptr;

    // TensorSliceReaderCache support.
    checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;

    // Support for forwarding reservations (used by ScopedAllocator).
    static const int kNeverForward = -2;
    static const int kNoReservation = -1;
    // Values in [0,...) represent reservations for the indexed output.
    const int* forward_from_array = nullptr;

    // For tracking actively running deferred ops.
    std::function<void()> inc_num_deferred_ops_function;
    std::function<void()> dec_num_deferred_ops_function;
  };

  // params must outlive the OpKernelContext.
  explicit OpKernelContext(Params* params);
  OpKernelContext(Params* params, int num_outputs);
  ~OpKernelContext();

  Env* env() const { return params_->device->env(); }

  int64 step_id() const { return params_->step_id; }

  bool is_eager() const { return params_->is_eager; }

  const OpKernel& op_kernel() const { return *params_->op_kernel; }

  // Input/output signature.

  int num_inputs() const { return params_->inputs->size(); }
  DataType input_dtype(int index) const;
  Status input_dtype(StringPiece name, DataType* dtype) const;
  MemoryType input_memory_type(int index) const;

  int num_outputs() const { return outputs_.size(); }
  DataType expected_output_dtype(int index) const;
  MemoryType output_memory_type(int index) const;

  // Input

  // Returns an immutable input tensor. May only be used for non-Ref
  // inputs. For Ref inputs use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  // TODO(mrry): Convert this to return Status.
  const Tensor& input(int index);

  // Returns the named immutable input tensor in "tensor", as defined
  // in the OpDef. May only be used for non-Ref inputs. For Ref inputs
  // use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  // REQUIRES: the named input must not be a list.
  Status input(StringPiece name, const Tensor** tensor);

  // Returns the named list-valued immutable input in "list", as
  // defined in the OpDef. If the named input is not list-valued,
  // returns a one-element list. May only be used for non-Ref
  // inputs. For Ref inputs use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  Status input_list(StringPiece name, OpInputList* list);
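  // As an illustration only, a kernel reading a list-valued input might do
  // something like the following sketch (the argument name "values" is an
  // assumption for the example):
  //
  //   OpInputList values;
  //   OP_REQUIRES_OK(context, context->input_list("values", &values));
  //   for (const Tensor& t : values) {
  //     // ... use t ...
  //   }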

  // For mutable inputs, use the following together to make sure there
  // is no concurrent access to mutable_input(), e.g.:
  // {
  //   mutex_lock lock(*context->input_ref_mutex(index));
  //   Tensor t = context->mutable_input(index, true /* lock_held */);
  //   // modify the values in t
  // }
  // REQUIRES: IsRefType(input_dtype(index))
  Status input_ref_mutex(StringPiece name, mutex** out_mutex);

  // Returns a mutable input tensor. Must be used to access Ref
  // inputs. REQUIRES: IsRefType(input_dtype(index)). The caller may
  // modify the values stored in the Tensor buffer, and modifications
  // will be visible to other Ops reading the same ref tensor. If
  // !lock_held the input mutex will be acquired before returning the
  // Tensor.
  // TODO(mrry): Convert this to return Status.
  Tensor mutable_input(int index, bool lock_held);

  // Returns the named mutable input tensor in "tensor", as defined in
  // the OpDef. Must be used to access Ref inputs. The values stored
  // in the Tensor buffer may be modified, and modifications will be
  // visible to other Ops reading the same ref tensor. If !lock_held
  // the input mutex will be acquired before returning the Tensor.
  // REQUIRES: the named input must not be a list.
  // REQUIRES: the named input must be a ref tensor.
  Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held);

  // Returns the named list-valued mutable input in "list", as defined
  // in the OpDef. If the named input is not list-valued, returns a
  // one-element list. Must be used to access Ref inputs. The values
  // stored in the Tensor buffer may be modified, and modifications
  // will be visible to other Ops reading the same ref tensor.
  // REQUIRES: the named input must be a ref tensor.
  Status mutable_input_list(StringPiece name, OpMutableInputList* list);

  // Replace the corresponding Ref Input to use the storage buffer
  // used by tensor. If !lock_held the input mutex will be acquired
  // before returning the Tensor.
  // REQUIRES: IsRefType(input_dtype(index)).
  void replace_ref_input(int index, const Tensor& tensor, bool lock_held);

  // Replace the corresponding named Ref Input to use the storage
  // buffer used by tensor. If !lock_held the input mutex will be
  // acquired before returning the Tensor.
  // REQUIRES: IsRefType(input_dtype(index)).
  Status replace_ref_input(StringPiece name, const Tensor& tensor,
                           bool lock_held);

  // Deletes the Tensor object used as the Ref Input at
  // input_index. This is not usually necessary and should be used
  // with caution. If !lock_held the input mutex will be acquired
  // before returning the Tensor.
  // REQUIRES: IsRefType(input_dtype(input_index)).
  void delete_ref_input(int input_index, bool lock_held);

  // Return true if there is input at the given index. An operator has no
  // input at index if its tensor is null. This is primarily used by the
  // merge operator.
  // TODO(mrry): Convert this to return Status.
  bool has_input(int index) const;

  // Returns true if all inputs are the same shape, otherwise sets the
  // status to a non-OK value and returns false.
  // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
  bool ValidateInputsAreSameShape(OpKernel* op);

  // If non-null, kernels should populate with any partition subgraphs created.
  GraphCollector* graph_collector() { return params_->graph_collector; }

  // Input to output forwarding.

  // Set the output Ref Tensor at output_index to be an alias of the
  // input Ref Tensor at input_index.
  // REQUIRES: IsRefType(input_dtype(input_index)).
  // REQUIRES: IsRefType(output_dtype(output_index)).
  void forward_ref_input_to_ref_output(int input_index, int output_index);

  // Returns true when an alias to input[input_index], reshaped to
  // output_shape, which is safe to use for in-place computation, was written
  // to *output. Returns false if input[input_index] has a refcount greater
  // than one, if its type does not match the expected output type of
  // output[output_index], or if the number of elements in input[input_index]
  // does not equal the number of elements in output_shape.
  bool forward_input_to_output_with_shape(int input_index, int output_index,
                                          const TensorShape& output_shape,
                                          Tensor** output) TF_MUST_USE_RESULT;
  Status forward_input_to_output_with_shape(StringPiece input_name,
                                            StringPiece output_name,
                                            const TensorShape& output_shape,
                                            Tensor** output) TF_MUST_USE_RESULT;

  // Returns a pointer to a Tensor aliasing the underlying buffer backing
  // input[input_index] iff
  //   * input[input_index] is not a ref,
  //   * the data type, shape, memory type, and allocator attributes of
  //     input[input_index] are compatible with those given in dtype, shape,
  //     memory_type, and attr,
  //   * refcount on the underlying buffer is one.
  //   * Either there is no forwarding reservation for either input_index
  //     or output_index or the specified input is reserved for the specified
  //     output. More precisely:
  //
  //     These cases mean neither input nor output has a reservation:
  //        forward_from_array = nullptr
  //     OR (input_index is not in forward_from_array AND
  //         (output_index == kNoReservation OR
  //          forward_from_array[output_index] == kNoReservation))
  //
  //     This case means that input_index is reserved for output_index:
  //        forward_from_array[output_index] == input_index
  //
  //     This case means the output is reserved to always be allocated,
  //     never assigned a forwarded input:
  //        forward_from_array[output_index] == kNeverForward
  //
  // Otherwise returns nullptr.
  // NOTE: For CUDA kernels that read inputs using the __ldg() intrinsic,
  // forwarding is only safe if there are no reads via __ldg() after writes
  // to the same address.
  std::unique_ptr<Tensor> forward_input(
      int input_index, int output_index, DataType output_dtype,
      const TensorShape& output_shape, MemoryType output_memory_type,
      const AllocatorAttributes& output_attr) TF_MUST_USE_RESULT;

  // Tries to forward one of the inputs given in input_indices to
  // output[output_index]. If none of the given inputs can be forwarded, calls
  // allocate_output() to allocate a new output buffer.
  Status forward_input_or_allocate_output(
      gtl::ArraySlice<int> candidate_input_indices, int output_index,
      const TensorShape& output_shape, Tensor** output) TF_MUST_USE_RESULT;
  Status forward_input_or_allocate_output(
      gtl::ArraySlice<StringPiece> candidate_input_names,
      StringPiece output_name, const TensorShape& output_shape,
      Tensor** output) TF_MUST_USE_RESULT;
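  // As an illustration only, an element-wise kernel that can safely operate
  // in place might use this roughly as in the following sketch:
  //
  //   Tensor* output = nullptr;
  //   OP_REQUIRES_OK(context,
  //                  context->forward_input_or_allocate_output(
  //                      {0}, 0, context->input(0).shape(), &output));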

  // Tries to reuse one of the inputs given in input_indices as a temporary.
  // If none of the given inputs can be forwarded, calls
  // allocate_temp() to allocate a new temporary buffer.
  Status forward_input_or_allocate_temp(
      gtl::ArraySlice<int> candidate_input_indices, DataType type,
      const TensorShape& shape, const AllocatorAttributes& allocator_attr,
      Tensor* out_temp) TF_MUST_USE_RESULT;

  Status forward_input_or_allocate_temp(
      gtl::ArraySlice<int> candidate_input_indices, DataType type,
      const TensorShape& shape, Tensor* out_temp) TF_MUST_USE_RESULT {
    return forward_input_or_allocate_temp(candidate_input_indices, type, shape,
                                          AllocatorAttributes(), out_temp);
  }

  // Output

  // Returns the named list-valued output in "list", as defined in the OpDef.
  // If the named output is not list-valued, returns a one-element list.
  Status output_list(StringPiece name, OpOutputList* list);

  // If output_required(index) returns true, the OpKernel's Compute() method
  // should call allocate_output(index, ...), set_output(index, ...),
  // set_output_ref(index, ...), or set the status to a non-ok value.
  // If it returns false, it may output, but is not required to do so.
  // TODO(mrry): Convert this to return Status, and implement a string
  // name version.
  bool output_required(int index) const {
    return true;  // TODO(josh11b): implement
  }

  // Allocation of tensors during kernel execution inside the Compute
  // method:
  //
  // There are three methods to allocate Tensors when an Op kernel
  // executes.
  //
  // 1) allocate_persistent. This is only needed for Tensors that will
  // be stored by the Op between invocations, and it *must* be used
  // for those Tensors. The call returns a PersistentTensor, and that
  // is the only object the Op is allowed to hold on to between
  // invocations. When the Tensor is needed in a subsequent
  // invocation, it can be retrieved from the PersistentTensor using
  // the AccessTensor method. This ensures that the system is made
  // aware of any use of the tensor's allocated memory, which is
  // needed for correctness on asynchronous devices such as GPUs.
  //
  // 2) allocate_output. This should be used to allocate any tensor
  // that is going to be used as an output from the Op at the end of
  // the current execution. The caller indicates which output the
  // Tensor will be assigned to, and the call returns the
  // newly-allocated Tensor. The Tensor can subsequently be assigned
  // to during kernel execution, and will be used as the designated
  // output when the kernel execution completes.
  //
  // 3) allocate_temp. This should be used to allocate any scratch
  // storage that is needed while the kernel is executing, and will
  // not be retained by the Op.
  //
  // In some cases a Tensor needs to be used as an output even though
  // it was previously allocated elsewhere. The Tensor may have been
  // passed as an input, or stored in a PersistentTensor during a
  // previous kernel execution, or allocated earlier in the kernel
  // execution at a time when it was not known which output it would
  // be assigned to. In this case the kernel can use set_output or
  // set_output_ref to indicate that the tensor should be used as the
  // designated output. It is legal to use any previously-allocated
  // Tensor as an argument to set_output or set_output_ref, including
  // Tensors allocated via allocate_temp. There may be a performance
  // penalty to using a Tensor that was not allocated using
  // allocate_output. This is because allocate_output uses the
  // AllocatorAttributes stored in output_attr_array for the
  // designated output. In some cases, using the wrong attributes may
  // cause an extra copy of the Tensor's buffer.

  // Allocates output for the specified output index with shape.
  // OpKernelContext retains ownership of the returned pointer. See
  // comment above.
  //
  // If memory allocation fails, returns an error status.
  //
  // REQUIRES: !IsRefType(expected_output_dtype(index))
  Status allocate_output(int index, const TensorShape& shape,
                         Tensor** tensor) TF_MUST_USE_RESULT;
  Status allocate_output(StringPiece name, const TensorShape& shape,
                         Tensor** tensor) TF_MUST_USE_RESULT;
  // The following methods use the supplied attributes instead of
  // those in output_attr_array. The caller is responsible for
  // ensuring that the attributes are "compatible" with the
  // output_attr_array, e.g. the tensor is allocated on the correct
  // device. See comment above.
  Status allocate_output(int index, const TensorShape& shape, Tensor** tensor,
                         AllocatorAttributes attr) TF_MUST_USE_RESULT;
  Status allocate_output(StringPiece name, const TensorShape& shape,
                         Tensor** tensor,
                         AllocatorAttributes attr) TF_MUST_USE_RESULT;

  // Allocates a temporary Tensor of the specified type and
  // shape. Devices such as GPUs that enqueue Ops for lazy execution
  // may retain references to the temporary tensors after the Op's
  // Compute method has run. See comment above.
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp, AllocatorAttributes allocator_attr,
                       const AllocationAttributes& allocation_attr);
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp, AllocatorAttributes allocator_attr) {
    return allocate_temp(type, shape, out_temp, allocator_attr,
                         AllocationAttributes());
  }
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp) {
    return allocate_temp(type, shape, out_temp, AllocatorAttributes());
  }
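  // As an illustration only, a typical Compute() body allocates its output and
  // any scratch space roughly as in the following sketch:
  //
  //   const Tensor& input = context->input(0);
  //   Tensor* output = nullptr;
  //   OP_REQUIRES_OK(context,
  //                  context->allocate_output(0, input.shape(), &output));
  //   Tensor scratch;
  //   OP_REQUIRES_OK(context, context->allocate_temp(DT_FLOAT, input.shape(),
  //                                                  &scratch));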

  // Allocates a Tensor of the specified type and shape which the Op
  // plans to maintain as persistent state. out_persistent holds the
  // PersistentTensor which is the object the caller should store. For
  // convenience, if out_tensor is non-null then it will be filled in
  // with a Tensor* pointing to the newly-allocated tensor which the
  // caller can use instead of calling
  // out_persistent->AccessTensor. The caller does not own out_tensor
  // and should not keep a copy of it. See comment above.
  Status allocate_persistent(DataType type, const TensorShape& shape,
                             PersistentTensor* out_persistent,
                             Tensor** out_tensor, AllocatorAttributes attr);
  Status allocate_persistent(DataType type, const TensorShape& shape,
                             PersistentTensor* out_persistent,
                             Tensor** out_tensor) {
    return allocate_persistent(type, shape, out_persistent, out_tensor,
                               AllocatorAttributes());
  }

  // Copies a tensor (allocated by the caller) to the specified output
  // index. REQUIRES: !IsRefType(expected_output_dtype(index))
  // REQUIRES: 'tensor' must have the same MemoryType as
  // output_memory_types[index]. See comment above.
  Status set_output(StringPiece name, const Tensor& tensor);

  // To output a reference. Caller retains ownership of mu and tensor_for_ref,
  // and they must outlive all uses within the step. See comment above.
  // REQUIRES: IsRefType(expected_output_dtype(index))
  Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref);

  // Returns nullptr if allocate_output() or set_output() have not been called.
  Status mutable_output(StringPiece name, Tensor** tensor);

  // Return the DeviceContext that should be used for this Op.
  //
  // If using the templated function, the type must be a subclass
  // of DeviceContext.
  //
  // Returns nullptr if the device did not provide one.
  template <typename T>
  T* op_device_context();
  DeviceContext* op_device_context() {
    DeviceContext* ret = params_->op_device_context;
    if (ret == nullptr) {
      auto* dev_info = device()->tensorflow_gpu_device_info();
      if (dev_info) ret = dev_info->default_context;
    }
    return ret;
  }

  AllocatorAttributes input_alloc_attr(int index) const {
    if (params_->input_alloc_attrs == nullptr) {
      return AllocatorAttributes();
    } else {
      DCHECK_GE(index, 0);
      DCHECK_LT(index, params_->input_alloc_attrs->size());
      return (*params_->input_alloc_attrs)[index];
    }
  }

  AllocatorAttributes output_alloc_attr(int index) const {
    return params_->output_attr_array[index];
  }

  gtl::InlinedVector<WrappedAllocator, 4> ConsumeWrappedAllocators() {
    gtl::InlinedVector<WrappedAllocator, 4> retrieved;
    if (tracking_state_) {
      mutex_lock lock(tracking_state_->mu);
      retrieved.swap(tracking_state_->wrapped_allocators);
    }
    return retrieved;
  }

  // Communication.
  //
  // An op kernel communicates with outside environment through
  // Rendezvous Send() and Recv().
  RendezvousInterface* rendezvous() const { return params_->rendezvous; }
  Status create_rendezvous(const int64 step_id, const DeviceMgr* device_mgr,
                           Rendezvous** r) const {
    return (*params_->create_rendezvous)(step_id, device_mgr, r);
  }

  CollectiveExecutor* collective_executor() const {
    return params_->collective_executor;
  }

  // An op kernel can access the session state it belongs to.
  SessionState* session_state() const { return params_->session_state; }

  // Unique identifier of the session it belongs to. Can be empty.
  string session_handle() const { return params_->session_handle; }

  // Metadata about the session. Can be nullptr.
  const SessionMetadata* session_metadata() const {
    return params_->session_metadata;
  }

  // An op kernel can access the tensor store of the run it belongs to.
  TensorStore* tensor_store() const { return params_->tensor_store; }

  // Function call support.
  //
  // If this kernel invocation is within a function execution,
  // call_frame() returns the call frame for the function call.
  CallFrameInterface* call_frame() const { return params_->call_frame; }

  // If not nullptr, the kernel can invoke functions defined in the
  // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...).
  FunctionLibraryRuntime* function_library() const {
    return params_->function_library;
  }

  std::function<void(std::function<void()>)>* runner() const {
    return params_->runner;
  }
  StepStatsCollectorInterface* stats_collector() const {
    return params_->stats_collector;
  }

  // Shared resources accessible to this kernel.
  ResourceMgr* resource_manager() const { return params_->resource_manager; }

  checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const {
    return params_->slice_reader_cache;
  }

  // Execution.
  //
  // OpKernels can use these eigen devices to carry out their
  // numerical computation.
  const Eigen::ThreadPoolDevice& eigen_cpu_device() const {
    return *device()->eigen_cpu_device();
  }
  const Eigen::GpuDevice& eigen_gpu_device() const {
    return params_->eigen_gpu_device->device();
  }
#ifdef TENSORFLOW_USE_SYCL
  const Eigen::SyclDevice& eigen_sycl_device() const {
    return *device()->eigen_sycl_device();
  }
#endif
  template <typename EigenDeviceType>
  const EigenDeviceType& eigen_device() const;
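  // As an illustration only, a kernel templated on a `Device` type typically
  // dispatches Eigen expressions onto that device roughly as in the following
  // sketch (`Device`, `input`, and `output` are assumptions for the example):
  //
  //   output->flat<float>().device(context->eigen_device<Device>()) =
  //       input.flat<float>() * 2.0f;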

  // Error handling.

  // If expected_inputs == inputs() and expected_outputs == output_types(),
  // returns OK, else returns INVALID_ARGUMENT with an error message.
  // Recommended for Ops with dynamic signatures, where validation can only
  // be performed at runtime.
  Status MatchSignature(const DataTypeSlice expected_inputs,
                        const DataTypeSlice expected_outputs);

  // An OpKernel should call SetStatus() if Compute() encounters an
  // error.
  void SetStatus(const Status& status);
  const Status& status() const { return status_; }

  // Cancellation.
  //
  // EXPERIMENTAL. See the implementation in tensorflow::FIFOQueue for an
  // example of how to use this API.
  CancellationManager* cancellation_manager() const {
    return params_->cancellation_manager;
  }
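  // As an illustration only, a cancellable kernel might hook into this
  // roughly as in the following sketch (error handling and deregistration on
  // every exit path are elided):
  //
  //   CancellationManager* cm = context->cancellation_manager();
  //   CancellationToken token = cm->get_cancellation_token();
  //   const bool already_cancelled =
  //       !cm->RegisterCallback(token, [this]() { /* unblock pending work */ });
  //   ...
  //   cm->DeregisterCallback(token);  // Before completing, e.g. calling done().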

  // Other accessors.

  // For control flow.
  FrameAndIter frame_iter() const { return params_->frame_iter; }
  bool is_input_dead() const { return params_->is_input_dead; }

  // May be used, e.g., to get GPU handles, etc.
  // TODO(tucker): Add example usage.
  DeviceBase* device() const { return params_->device; }

  // Retrieve list of referenced tensors in out_vector. Once this is
  // called, it is not legal to reference any more tensors. Should
  // not be called from Op kernels.
  void retrieve_accessed_tensors(TensorReferenceVector* out_vector);

  // Per-step container for use by white-listed internal ops.
  ScopedStepContainer* step_container() const {
    return params_->step_container;
  }

  // Helper routines for the OP_REQUIRES macros
  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  // Unrecommended functions: these are functions that have some
  // current uses but are not recommended for use, and may go away at
  // some future major version release.
  //
  // The following functions all have versions that return Status
  // to capture error conditions, and are strongly preferred.
  Tensor* mutable_output(int index);
  void set_output(int index, const Tensor& tensor);
  mutex* input_ref_mutex(int index);
  void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref);
  TensorValue release_output(int index);

  bool track_allocations() const { return params_->track_allocations; }

  // Records temp memory allocation. Tensor object is recorded to identify the
  // case where temp memory is used as output memory.
  void record_temp_memory_allocation(int64 size, const Tensor& t)
      LOCKS_EXCLUDED(tracking_state_->stats_mu);

  // Returns recorded size of temporary memory.
  int64 temp_memory_allocated() const
      LOCKS_EXCLUDED(tracking_state_->stats_mu);

  // Records persistent memory allocation; size can be negative, indicating
  // deallocation.
  void record_persistent_memory_allocation(int64 size, int64 alloc_id = -1)
      LOCKS_EXCLUDED(tracking_state_->stats_mu);

  // Returns recorded size and ids of persistent memory.
  int64 persistent_memory_allocated() const
      LOCKS_EXCLUDED(tracking_state_->stats_mu);

  std::vector<int64> persistent_alloc_ids() const
      LOCKS_EXCLUDED(tracking_state_->stats_mu);

  // Resets counters for temp and persistent memory and recorded ids.
  void clear_recorded_memory() LOCKS_EXCLUDED(tracking_state_->stats_mu);

  bool input_is_ref(int index) const;

  void set_record_memory_consumption(bool v);
1265
  // Used by OpKernel implementations to track actively running deferred ops.
  //
  // A deferred op is one whose Compute method returns (or whose ComputeAsync
  // method invokes the callback) when work is scheduled onto a device. At that
  // point, we don't know when the work will actually complete (or if it has
  // already completed) on the device. These functions allow the executor to
  // track the status of deferred ops and act accordingly.
  //
  // Deferred OpKernel implementations must use these methods to obtain the two
  // functions, and must then call them in pairs: the increment function before
  // handing work to the device, and the decrement function after that work
  // completes.
  TF_MUST_USE_RESULT std::function<void()> inc_num_deferred_ops_function() {
    DCHECK(params_->op_kernel->is_deferred());
    return params_->inc_num_deferred_ops_function
               ? params_->inc_num_deferred_ops_function
               : []() {};
  }
  TF_MUST_USE_RESULT std::function<void()> dec_num_deferred_ops_function() {
    DCHECK(params_->op_kernel->is_deferred());
    return params_->dec_num_deferred_ops_function
               ? params_->dec_num_deferred_ops_function
               : []() {};
  }
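
  // Example (a sketch only; `MyDeferredOp`, `ScheduleOnDevice` and the `done`
  // callback are hypothetical and not part of this header):
  //
  //   void MyDeferredOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
  //     auto inc = ctx->inc_num_deferred_ops_function();
  //     auto dec = ctx->dec_num_deferred_ops_function();
  //     inc();  // Before the work is handed to the device.
  //     ScheduleOnDevice(ctx, [dec, done]() {
  //       dec();  // After the device work has completed.
  //       done();
  //     });
  //   }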

  Allocator* get_allocator(AllocatorAttributes attr);

 private:
  bool record_memory_consumption_ = false;

  // Internal method to add a tensor's buffer to the list of buffers
  // referenced during the execution of the Op, so that GPUs may
  // accurately track the memory that may not be reused until the Op
  // execution completes.
  void record_tensor_reference(const Tensor& tensor);
  void really_record_tensor_reference(const Tensor& tensor);

  // Internal common method used when allocating tensor memory
  Status allocate_tensor(DataType type, const TensorShape& shape,
                         Tensor* out_tensor,
                         AllocatorAttributes allocator_attr) {
    return allocate_tensor(type, shape, out_tensor, allocator_attr,
                           AllocationAttributes());
  }

  Status allocate_tensor(DataType type, const TensorShape& shape,
                         Tensor* out_tensor, AllocatorAttributes allocator_attr,
                         const AllocationAttributes& allocation_attr);

  // Initialize the allocated_scope_ids_ set the first time this method is
  // called.
  void maybe_initialize_scope_id_set();

  // This is called by PersistentTensor::AccessTensor whenever the
  // wrapped tensor is retrieved, to ensure the runtime knows that the
  // Tensor is being accessed within an Op. This is necessary for
  // memory safety of devices like GPUs that queue Ops for
  // asynchronous execution after the Compute() method completes.
  friend class PersistentTensor;
  void NotifyUseOfPersistentTensor(const Tensor& tensor);

  Status status_;
  friend class CollectiveExecutor;  // for access to params_
  Params* params_;                  // not owned
  gtl::InlinedVector<TensorValue, 4> outputs_;

  // Keep track of calls to ScopedAllocator.
  // TODO(ayushd): change to absl::flat_hash_set.
  std::unique_ptr<std::unordered_set<int32>> allocated_scope_ids_;

  // The following data members are only used when allocation tracking is
  // enabled, memory consumption is being recorded, or tensor access is being
  // recorded.
  struct TrackingState {
    mutable mutex mu;
    gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators GUARDED_BY(mu);

    UniqueTensorReferences referenced_tensors GUARDED_BY(mu);

    mutable mutex stats_mu;
    int64 temp_memory_allocated GUARDED_BY(stats_mu) = 0;

    int64 persistent_memory_allocated GUARDED_BY(stats_mu) = 0;
    gtl::InlinedVector<std::pair<const void*, int64>, 2>
        temp_tensor_buffer_and_size GUARDED_BY(stats_mu);
    gtl::InlinedVector<int64, 2> persistent_alloc_ids GUARDED_BY(stats_mu);
  };
  std::unique_ptr<TrackingState> tracking_state_;

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext);
};

template <>
const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const;

template <>
const Eigen::GpuDevice& OpKernelContext::eigen_device() const;

#ifdef TENSORFLOW_USE_SYCL
template <>
const Eigen::SyclDevice& OpKernelContext::eigen_device() const;
#endif
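
// A minimal usage sketch for the eigen_device<>() specializations above
// (`context`, `input` and `output` are a hypothetical kernel's
// OpKernelContext* and float tensors, not names defined in this header):
//
//   const Eigen::ThreadPoolDevice& d =
//       context->eigen_device<Eigen::ThreadPoolDevice>();
//   output->flat<float>().device(d) = input.flat<float>() * 2.0f;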

// Register your OpKernel by specifying the Op's name, the device the
// kernel runs on, any type attr constraints for this kernel, any
// host-memory args, and the class to instantiate. Examples:
//
//  // A kernel that supports all types.
//  REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp);
//
//  // The following are equivalent ways of specifying that the kernel only
//  // works if the "T" type attr is set to DT_FLOAT.
//  REGISTER_KERNEL_BUILDER(
//      Name("Sub").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//      SubOp<float>);
//  // (You would then repeat this for every type supported by "Sub".)
//
//  // This form allows you to specify a list of types as the constraint.
//  REGISTER_KERNEL_BUILDER(Name("Sub")
//                              .Device(DEVICE_CPU)
//                              .TypeConstraint("T", {DT_FLOAT}),
//                          SubOp<float>);
//
//  // A kernel that expects one of the input tensors in host memory.
//  REGISTER_KERNEL_BUILDER(
//      Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp);
//
// See kernel_def_builder for details.

// Instantiates an OpKernel that has been registered. Returns nullptr if there
// is no kernel registered for that device type / input signature combination
// (and sets *status to NOT_FOUND), or if there is an error during construction
// (and sets *status to INVALID_ARGUMENT). Otherwise, the caller takes
// ownership of the returned pointer.
// EXPECTED USAGE: unique_ptr<OpKernel> op = CreateOpKernel(...);
// REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
                                         DeviceBase* device,
                                         Allocator* allocator,
                                         const NodeDef& def,
                                         int graph_def_version, Status* status);
Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      const NodeDef& def, int graph_def_version,
                      OpKernel** kernel);
Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      ResourceMgr* resource_mgr, const NodeDef& def,
                      int graph_def_version, OpKernel** kernel);
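
// A slightly fuller sketch of the EXPECTED USAGE above (`device`, `allocator`,
// `node_def` and `graph_def_version` are assumed to be supplied by the
// caller):
//
//   Status status;
//   std::unique_ptr<OpKernel> op =
//       CreateOpKernel(DeviceType(DEVICE_CPU), device, allocator, node_def,
//                      graph_def_version, &status);
//   if (!status.ok()) { /* kernel lookup or construction failed */ }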

// Returns into 'device_types' the subset of prioritized_types that this
// binary has registered for the given NodeDef.
//
// REQUIRES: * 'device_types' is not nullptr.
//           * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
Status SupportedDeviceTypesForNode(
    const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
    PrioritizedDeviceTypeVector* device_types,
    const DeviceNameUtils::ParsedName* local_address_spec = nullptr);

// Returns a message with a description of the kernels registered for op
// `op_name`.
string KernelsRegisteredForOp(StringPiece op_name);

// Call once after Op registration has completed.
Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry);

// -----------------------------------------------------------------------------
// OpKernel registration implementation follows, please ignore.

// Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax.
namespace register_kernel {

class Name : public KernelDefBuilder {
 public:
  // With selective registration, kernels whose implementation class is not
  // used by any kernel are disabled with the SHOULD_REGISTER_OP_KERNEL call in
  // REGISTER_KERNEL_BUILDER_UNIQ. However, an unused kernel that shares an
  // implementation class with a used kernel would get through that mechanism.
  //
  // This mechanism stops that registration by changing the name of the kernel
  // for the unused op to one that is ignored by
  // OpKernelRegistrar::InitInternal. Note that this method alone is
  // not sufficient - the compiler can't evaluate the entire KernelDefBuilder at
  // compilation time, so this method doesn't actually reduce code size.
  explicit Name(const char* op)
      : KernelDefBuilder(SHOULD_REGISTER_OP(op) ? op : "_no_register") {}
};

namespace system {

class Name : public KernelDefBuilder {
 public:
  // For system kernels, we ignore selective registration and
  // unconditionally register the kernel.
  explicit Name(const char* op) : KernelDefBuilder(op) {}
};

}  // namespace system

}  // namespace register_kernel

#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
  REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)

#define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
  REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)

#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)        \
  constexpr bool should_register_##ctr##__flag =                      \
      SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__);                        \
  static ::tensorflow::kernel_factory::OpKernelRegistrar              \
      registrar__body__##ctr##__object(                               \
          should_register_##ctr##__flag                               \
              ? ::tensorflow::register_kernel::kernel_builder.Build() \
              : nullptr,                                              \
          #__VA_ARGS__,                                               \
          [](::tensorflow::OpKernelConstruction* context)             \
              -> ::tensorflow::OpKernel* {                            \
            return new __VA_ARGS__(context);                          \
          });

// The `REGISTER_SYSTEM_KERNEL_BUILDER()` macro acts like
// `REGISTER_KERNEL_BUILDER()`, except that the kernel is registered
// unconditionally even when selective registration is used.
#define REGISTER_SYSTEM_KERNEL_BUILDER(kernel_builder, ...)               \
  REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, \
                                             __VA_ARGS__)

#define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
  REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)

#define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)    \
  static ::tensorflow::kernel_factory::OpKernelRegistrar                 \
      registrar__body__##ctr##__object(                                  \
          ::tensorflow::register_kernel::system::kernel_builder.Build(), \
          #__VA_ARGS__,                                                  \
          [](::tensorflow::OpKernelConstruction* context)                \
              -> ::tensorflow::OpKernel* {                               \
            return new __VA_ARGS__(context);                             \
          });

// Checks whether a given kernel is registered on device_type.
bool KernelDefAvailable(const DeviceType& device_type, const NodeDef& node_def);

// If a node with the given node_name, experimental_debug_info, node_op,
// node_device and node_attrs has a corresponding kernel registered on
// device_type, returns OK and fills in the kernel def and kernel_class_name.
// <def> and <kernel_class_name> may be null.
Status FindKernelDef(
    const DeviceType& device_type, StringPiece node_name,
    bool has_experimental_debug_info,
    const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
    StringPiece node_op, StringPiece node_device, AttrSlice node_attrs,
    const KernelDef** def, string* kernel_class_name);

// If node_def has a corresponding kernel registered on device_type,
// returns OK and fills in the kernel def and kernel_class_name. <def> and
// <kernel_class_name> may be null.
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
                     const KernelDef** def, string* kernel_class_name);

// Writes a list of all registered kernels to LOG(INFO), to help users debug
// missing kernel errors.
void LogAllRegisteredKernels();

// Gets a list of all registered kernels.
KernelList GetAllRegisteredKernels();

// Gets a list of all registered kernels for which predicate returns true
KernelList GetFilteredRegisteredKernels(
    const std::function<bool(const KernelDef&)>& predicate);
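
// For example (a sketch; the filter shown is arbitrary), the kernels
// registered for the GPU device could be listed with:
//
//   KernelList gpu_kernels = GetFilteredRegisteredKernels(
//       [](const KernelDef& k) { return k.device_type() == "GPU"; });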

// Gets a list of all registered kernels for a given op
KernelList GetRegisteredKernelsForOp(StringPiece op_name);

namespace kernel_factory {

// OpKernelFactory is responsible for creating OpKernels when TensorFlow needs
// them. You register factories with the TensorFlow core by constructing an
// OpKernelRegistrar and passing the factory as a constructor parameter.
class OpKernelFactory {
 public:
  virtual OpKernel* Create(OpKernelConstruction* context) = 0;
  virtual ~OpKernelFactory() = default;
};

class OpKernelRegistrar {
 public:
  // Registers the given kernel factory with TensorFlow. TF will call the
  // factory Create() method when it determines that a kernel matching the
  // given KernelDef is required.
  OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    std::unique_ptr<OpKernelFactory> factory) {
    // Perform the check in the header to allow compile-time optimization
    // to a no-op, allowing the linker to remove the kernel symbols.
    if (kernel_def != nullptr) {
      InitInternal(kernel_def, kernel_class_name, std::move(factory));
    }
  }

  // Registers the given factory function with TensorFlow. This is equivalent
  // to registering a factory whose Create function invokes `create_fn`.
  OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    OpKernel* (*create_fn)(OpKernelConstruction*)) {
    // Perform the check in the header to allow compile-time optimization
    // to a no-op, allowing the linker to remove the kernel symbols.
    if (kernel_def != nullptr) {
      InitInternal(kernel_def, kernel_class_name,
                   absl::make_unique<PtrOpKernelFactory>(create_fn));
    }
  }

 private:
  struct PtrOpKernelFactory : public OpKernelFactory {
    explicit PtrOpKernelFactory(OpKernel* (*create_func)(OpKernelConstruction*))
        : create_func_(create_func) {}

    OpKernel* Create(OpKernelConstruction* context) override;

    OpKernel* (*create_func_)(OpKernelConstruction*);
  };

  void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    std::unique_ptr<OpKernelFactory> factory);
};

}  // namespace kernel_factory

// -----------------------------------------------------------------------------
// Template and inline method implementations, please ignore

template <class T>
Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const {
  return GetNodeAttr(def(), attr_name, value);
}
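
// A typical use of GetAttr (a sketch; `MyOp`, the "multiplier" attr and the
// `multiplier_` member are hypothetical) is to read attrs in a kernel
// constructor and validate the result with OP_REQUIRES_OK:
//
//   explicit MyOp(OpKernelConstruction* context) : OpKernel(context) {
//     OP_REQUIRES_OK(context, context->GetAttr("multiplier", &multiplier_));
//   }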

inline DataType OpKernelContext::input_dtype(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  const TensorValue& value((*params_->inputs)[index]);
  return value.dtype();
}

inline MemoryType OpKernelContext::input_memory_type(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  return op_kernel().input_memory_types()[index];
}

inline DataType OpKernelContext::expected_output_dtype(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  return params_->op_kernel->output_type(index);
}

inline MemoryType OpKernelContext::output_memory_type(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  return op_kernel().output_memory_types()[index];
}

inline bool OpKernelContext::input_is_ref(int index) const {
  const TensorValue& value((*params_->inputs)[index]);
  return value.is_ref();
}

inline void OpKernelContext::record_tensor_reference(const Tensor& tensor) {
  DCHECK_EQ(params_->device->RequiresRecordingAccessedTensors(),
            params_->record_tensor_accesses);
  if (params_->record_tensor_accesses) {
    really_record_tensor_reference(tensor);
  }
}

inline void OpKernelContext::retrieve_accessed_tensors(
    TensorReferenceVector* out_vector) {
  if (params_->record_tensor_accesses) {
    DCHECK(tracking_state_);
    mutex_lock l(tracking_state_->mu);
    tracking_state_->referenced_tensors.FreezeAndReturnReferences(out_vector);
  }
}

// no input if tensor == nullptr.
inline bool OpKernelContext::has_input(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  return (*params_->inputs)[index].tensor != nullptr;
}

inline mutex* OpKernelContext::input_ref_mutex(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  DCHECK(input_is_ref(index));
  return (*params_->inputs)[index].mutex_if_ref;
}

inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) {
  if (t.IsInitialized()) {
    record_tensor_reference(t);
  }
}

inline Tensor* OpKernelContext::mutable_output(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  // No need to record_tensor_reference since the output must already
  // have been set by a call that did so.
  return outputs_[index].tensor;
}

inline TensorValue OpKernelContext::release_output(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  TensorValue value = outputs_[index];
  outputs_[index] = TensorValue();
  return value;
}

inline Status OpKernelContext::forward_input_or_allocate_output(
    gtl::ArraySlice<int> candidate_input_indices, int output_index,
    const TensorShape& output_shape, Tensor** output) {
  for (int input_index : candidate_input_indices) {
    if (forward_input_to_output_with_shape(input_index, output_index,
                                           output_shape, output)) {
      return Status::OK();
    }
  }
  return allocate_output(output_index, output_shape, output);
}

inline Status OpKernelContext::forward_input_or_allocate_output(
    gtl::ArraySlice<StringPiece> candidate_input_names, StringPiece output_name,
    const TensorShape& output_shape, Tensor** output) {
  for (const StringPiece& input_name : candidate_input_names) {
    if (forward_input_to_output_with_shape(input_name, output_name,
                                           output_shape, output)
            .ok()) {
      return Status::OK();
    }
  }
  return allocate_output(output_name, output_shape, output);
}
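
// Example (a sketch; a hypothetical element-wise kernel): try to reuse the
// buffer of input 0 for output 0, falling back to a fresh allocation if the
// input cannot be forwarded:
//
//   Tensor* output = nullptr;
//   OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
//                               {0}, 0, context->input(0).shape(), &output));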

template <typename T>
T* OpKernelContext::op_device_context() {
  static_assert(std::is_base_of<DeviceContext, T>::value,
                "T is not a subclass of DeviceContext");
  return static_cast<T*>(op_device_context());
}

inline const Tensor& OpInputList::operator[](int i) const {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->input(start_ + i);
}

inline mutex* OpMutableInputList::ref_mutex(int i) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->input_ref_mutex(start_ + i);
}

inline Tensor OpMutableInputList::at(int i, bool lock_held) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->mutable_input(start_ + i, lock_held);
}

inline Tensor* OpOutputList::operator[](int i) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->mutable_output(start_ + i);
}

inline bool OpOutputList::required(int i) const {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->output_required(start_ + i);
}

inline DataType OpOutputList::expected_output_dtype(int i) const {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->expected_output_dtype(start_ + i);
}

inline Status OpOutputList::allocate(int i, const TensorShape& shape,
                                     Tensor** output) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->allocate_output(start_ + i, shape, output);
}

inline void OpOutputList::set(int i, const Tensor& tensor) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  ctx_->set_output(start_ + i, tensor);
}

inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  // Like the other OpOutputList accessors, convert the list-relative index
  // into the context-level output index.
  ctx_->set_output_ref(start_ + i, mu, tensor_for_ref);
}
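
// A brief usage sketch for OpOutputList (hypothetical kernel code; the output
// list name "output" must match a list declared by the op, and the scalar
// shape is arbitrary):
//
//   OpOutputList outputs;
//   OP_REQUIRES_OK(context, context->output_list("output", &outputs));
//   for (int i = 0; i < outputs.size(); ++i) {
//     Tensor* out = nullptr;
//     OP_REQUIRES_OK(context, outputs.allocate(i, TensorShape({}), &out));
//   }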

// Convenience macros for asserting and handling exceptional conditions.
// Analogous to the CHECK* macros provided by logging.h.
//
// Example use:
// void Compute(OpKernelContext* context) {
//   OP_REQUIRES(context, context->num_inputs() == 2,
//               errors::InvalidArgument("FooOp requires 2 arguments"));
//   ...
//   Status status = SomeUncertainMethod();
//   OP_REQUIRES_OK(context, status);
//   ...
// }

// Generate a fatal error if OP_REQUIRES or OP_REQUIRES_OK are used in
// AsyncOpKernel implementations. If these macros are used and the condition
// does not hold, the `done` callback will never be called and the system will
// deadlock, so a crash failure is preferable. Since the OP_REQUIRES[_OK] macros
// are legal to use in AsyncOpKernel constructors, we use overload resolution
// to distinguish between OpKernelConstruction* and OpKernelContext* context
// types.
class XlaOpKernelContext;
inline void CheckNotInComputeAsync(XlaOpKernelContext*, const char*) {}
inline void CheckNotInComputeAsync(OpKernelConstruction*, const char*) {}
void CheckNotInComputeAsync(OpKernelContext* ctx,
                            const char* correct_macro_name);

#define OP_REQUIRES(CTX, EXP, STATUS)                     \
  do {                                                    \
    if (!TF_PREDICT_TRUE(EXP)) {                          \
      CheckNotInComputeAsync((CTX), "OP_REQUIRES_ASYNC"); \
      (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS));    \
      return;                                             \
    }                                                     \
  } while (0)

#define OP_REQUIRES_OK(CTX, ...)                             \
  do {                                                       \
    ::tensorflow::Status _s(__VA_ARGS__);                    \
    if (!TF_PREDICT_TRUE(_s.ok())) {                         \
      CheckNotInComputeAsync((CTX), "OP_REQUIRES_OK_ASYNC"); \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s);  \
      return;                                                \
    }                                                        \
  } while (0)

#define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK)  \
  do {                                                 \
    if (!TF_PREDICT_TRUE(EXP)) {                       \
      (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \
      (CALLBACK)();                                    \
      return;                                          \
    }                                                  \
  } while (0)

#define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK)         \
  do {                                                      \
    ::tensorflow::Status _s(STATUS);                        \
    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
      (CALLBACK)();                                         \
      return;                                               \
    }                                                       \
  } while (0)
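
// Example (a sketch; `MyAsyncOp` and `SomeUncertainMethod` are hypothetical):
// inside ComputeAsync, the *_ASYNC variants must be used so that `done` is
// still invoked on failure:
//
//   void MyAsyncOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
//     OP_REQUIRES_OK_ASYNC(ctx, SomeUncertainMethod(), done);
//     ...
//     done();
//   }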

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_