/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
#define TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_

#include <atomic>
#include <functional>

#include <utility>
#include <vector>
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"  // TODO(b/62899350): Remove
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"  // TODO(b/62899350): Remove
#include "tensorflow/core/framework/tracking_allocator.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/unique_tensor_references.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/manual_constructor.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace Eigen {
struct ThreadPoolDevice;
struct GpuDevice;
struct SyclDevice;
}  // end namespace Eigen

namespace tensorflow {

namespace checkpoint {
class TensorSliceReaderCacheWrapper;
}  // namespace checkpoint

class AsyncOpKernel;
class CallFrameInterface;
class FunctionLibraryRuntime;
class OpKernelConstruction;  // declared below
class OpKernelContext;       // declared below
class OpRegistryInterface;
class ResourceMgr;
class ScopedStepContainer;
class CollectiveExecutor;
class StepStatsCollectorInterface;

class OpKernel {
 public:
  // OpKernel won't be instantiated by the scheduler, so you may perform
  // expensive initialization in the descendant's constructor.
  explicit OpKernel(OpKernelConstruction* context);

  // Specialized constructor that enables the descendant to provide a different
  // `NodeDef` value. For example, this constructor can be used to provide a
  // stripped-down `NodeDef` that does not contain the full set of attrs (such
  // as tensor values) if the descendant stores them in a different form.
  explicit OpKernel(OpKernelConstruction* context,
                    std::unique_ptr<const NodeDef> node_def);

  virtual ~OpKernel();

  // An OpKernel's computation can be either synchronous or
  // asynchronous. All OpKernel Compute() methods must be thread-safe as they
  // may be called concurrently (e.g. by multiple executions of the same graph
  // concurrently).
  //
  // Most OpKernels should compute synchronously. They should
  // subclass OpKernel and override the Compute() method and have it
  // return after completing the supplied work.
  //
  // A few special kernels might need to be asynchronous to bound the
  // number of threads (e.g., network receive operations). These
  // kernels must subclass AsyncOpKernel and override
  // AsyncOpKernel::ComputeAsync().
  //
  // In both cases, implementations of Compute() and ComputeAsync()
  // get inputs and write outputs through the given OpKernelContext
  // and return a status via context->SetStatus(). They must be
  // thread-safe.

  // Synchronous compute.
  //
  // "context" is guaranteed to be alive until Compute() returns.
  virtual void Compute(OpKernelContext* context) = 0;
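
  // A minimal sketch of a synchronous kernel (not part of this header; the
  // op name "MyAddOne" and class AddOneOp are hypothetical):
  //
  //   class AddOneOp : public OpKernel {
  //    public:
  //     explicit AddOneOp(OpKernelConstruction* context) : OpKernel(context) {}
  //
  //     void Compute(OpKernelContext* context) override {
  //       const Tensor& input = context->input(0);
  //       Tensor* output = nullptr;
  //       OP_REQUIRES_OK(context,
  //                      context->allocate_output(0, input.shape(), &output));
  //       auto in = input.flat<float>();
  //       auto out = output->flat<float>();
  //       for (int64 i = 0; i < in.size(); ++i) out(i) = in(i) + 1.0f;
  //     }
  //   };
  //   REGISTER_KERNEL_BUILDER(Name("MyAddOne").Device(DEVICE_CPU), AddOneOp);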

  // Returns nullptr iff this op kernel is synchronous.
  virtual AsyncOpKernel* AsAsync() { return nullptr; }
  virtual const AsyncOpKernel* AsAsync() const { return nullptr; }

  // Initial time (in CPU cycles) we expect an operation to take. Used to
  // determine whether an operation should be placed in a threadpool.
  // Operations start out "expensive".
  static const uint64 kInitialCostEstimateCycles = 100 * 1000 * 1000;
  static const uint64 kOpIsExpensiveThresholdCycles = 5000;
  static const uint64 kCostDecay = 10;

  // Returns true iff this op kernel is considered "expensive". The
  // runtime may use this flag to optimize graph execution, for example
  // to "inline" inexpensive kernels.
  virtual bool IsExpensive() {
    return expensive_ && (cost_estimate_.load(std::memory_order_relaxed) >
                          kOpIsExpensiveThresholdCycles);
  }

  // Updates the dynamic cost estimate, which is used to determine whether this
  // op is expensive. The new cost estimate is a weighted average of the old
  // cost estimate and the latest cost.
  void UpdateCostEstimate(uint64 elapsed_cycles) {
    // N.B. Updates to `cost_estimate_` are atomic but unlocked. Simultaneous
    // updates may result in one or more updates being ignored. This does not
    // affect correctness but may slow down the update frequency.
    cost_estimate_.store(
        (kCostDecay - 1) * cost_estimate_.load(std::memory_order_relaxed) /
                kCostDecay +
            (elapsed_cycles / kCostDecay),
        std::memory_order_relaxed);
  }

  // Accessors.
  const NodeDef& def() const { return *def_; }
  const string& name() const;              // Same as def().name()
  const string& type_string() const;       // Same as def().op()
  const string& requested_device() const;  // Same as def().device()
  bool is_internal() const { return is_internal_; }

  int num_inputs() const { return input_types_.size(); }
  DataType input_type(int i) const { return input_types_[i]; }
  const DataTypeVector& input_types() const { return input_types_; }
  const MemoryTypeVector& input_memory_types() const {
    return input_memory_types_;
  }
  const string& requested_input(int i) const;  // Same as def().input(i)

  int num_outputs() const { return output_types_.size(); }
  DataType output_type(int o) const { return output_types_[o]; }
  const DataTypeVector& output_types() const { return output_types_; }
  const MemoryTypeVector& output_memory_types() const {
    return output_memory_types_;
  }

  Status InputRange(StringPiece input_name, int* start, int* stop) const;
  Status OutputRange(StringPiece output_name, int* start, int* stop) const;

  // We allow legacy scalars within Google up until GraphDef version 6.
  // TODO(irving): Remove when we can drop support for GraphDef version 5.
  bool allow_legacy_scalars() const {
#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID)
    return graph_def_version_ < 6;
#else
    return false;
#endif
  }

  // Allow either scalars or (if allowing legacy scalars) shape (1,).
  bool IsLegacyScalar(const TensorShape& shape) const {
    return shape.dims() == 0 || (allow_legacy_scalars() && shape.dims() == 1 &&
                                 shape.dim_size(0) == 1);
  }

  // Allow rank 1 or (if allowing legacy scalars) rank 0.
  bool IsLegacyVector(const TensorShape& shape) const {
    return shape.dims() == 1 || (allow_legacy_scalars() && shape.dims() == 0);
  }

  // Turn a shape Tensor into a TensorShape
  // TODO(irving): Move to TensorShapeUtils once !allow_legacy_scalars
  Status MakeShape(const Tensor& shape, TensorShape* out) const;

  static int DeviceNumaNode(const DeviceBase* device);

 private:
  const std::unique_ptr<const NodeDef> def_;
  const DataTypeVector input_types_;
  const MemoryTypeVector input_memory_types_;
  const DataTypeVector output_types_;
  const MemoryTypeVector output_memory_types_;
  const int graph_def_version_;
  const bool is_internal_;  // True if this is an internal operation
  NameRangeMap input_name_map_;
  NameRangeMap output_name_map_;
  bool expensive_;
  std::atomic_uint_fast64_t cost_estimate_;

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernel);
};

class AsyncOpKernel : public OpKernel {
 public:
  using OpKernel::OpKernel;  // Lift OpKernel constructors.

  // Asynchronous compute.
  //
  // Implementations of ComputeAsync() must run "done" to signal the
  // completion of the computation. "context" is guaranteed to be
  // alive until the "done" callback starts.
  typedef std::function<void()> DoneCallback;
  virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0;

  AsyncOpKernel* AsAsync() final { return this; }
  const AsyncOpKernel* AsAsync() const final { return this; }

  void Compute(OpKernelContext* context) final;
};
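
// A minimal sketch of an asynchronous kernel (not part of this header; the
// class name and the StartNetworkReceive helper are hypothetical). The key
// contract is that "done" must be invoked exactly once, and "context" must
// not be used after "done" has been called:
//
//   class RecvLikeOp : public AsyncOpKernel {
//    public:
//     explicit RecvLikeOp(OpKernelConstruction* context)
//         : AsyncOpKernel(context) {}
//
//     void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
//       StartNetworkReceive(context, [context, done](const Status& s) {
//         context->SetStatus(s);
//         done();
//       });
//     }
//   };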

// Wraps a tensor that is held by an Op across calls to Compute(). For
// memory safety when using asynchronous devices like GPUs, the system
// must be notified when a Tensor is used inside an Op execution. The
// wrapper ensures that all uses of the Tensor are tracked, because in
// order to retrieve the Tensor the caller must use AccessTensor which
// notifies the context.
class PersistentTensor {
 public:
  PersistentTensor() {}
  explicit PersistentTensor(const Tensor& tensor) : tensor_(tensor) {}

  // Caller does not own the returned Tensor*.
  Tensor* AccessTensor(OpKernelConstruction* context);
  // Caller does not own the returned Tensor*.
  Tensor* AccessTensor(OpKernelContext* context);

  // The check for initialization does not need to access the
  // underlying tensor buffer.
  bool IsInitialized() const { return tensor_.IsInitialized(); }

  int64 NumElements() const { return tensor_.NumElements(); }

  int64 AllocatedBytes() const { return tensor_.AllocatedBytes(); }

 private:
  Tensor tensor_;
};

class OpKernelConstruction {
 public:
  OpKernelConstruction(DeviceType device_type, DeviceBase* device,
                       Allocator* allocator, const NodeDef* node_def,
                       const OpDef* op_def, FunctionLibraryRuntime* flib,
                       const DataTypeSlice& input_types,
                       const MemoryTypeSlice& input_memory_types,
                       const DataTypeSlice& output_types,
                       const MemoryTypeSlice& output_memory_types,
                       int graph_def_version, Status* status);

  Env* env() const { return device_->env(); }

  // Allocation of tensors during kernel construction:
  //
  // It is legal to temporarily allocate scratch tensor storage during
  // Op kernel construction. Scratch tensors should be allocated using
  // allocate_temp below. Some kernels need to keep tensors in between
  // invocations. If such a Tensor is allocated during kernel
  // construction this must be done using allocate_persistent, and the
  // Op may only store the returned PersistentTensor object. When the
  // Tensor is needed in a subsequent invocation, it can be retrieved
  // from the PersistentTensor using the AccessTensor method. This
  // ensures that the system is made aware of any use of the tensor's
  // allocated memory, which is needed for correctness on asynchronous
  // devices such as GPUs.

  // Allocates a temporary Tensor of the specified type and shape. The
  // Tensor must not be used after kernel construction is
  // complete. See comment above.
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp);

  // Allocates a Tensor of the specified type and shape which the Op
  // plans to maintain as persistent state. out_persistent holds the
  // PersistentTensor which is the object the caller should store. For
  // convenience, if out_tensor is non-null then it will be filled in
  // with a Tensor* pointing to the newly-allocated tensor which the
  // caller can use instead of calling
  // out_persistent->AccessTensor. The caller does not own out_tensor
  // and should not keep a copy of it. See comment above.
  Status allocate_persistent(DataType type, const TensorShape& shape,
                             PersistentTensor* out_persistent,
                             Tensor** out_tensor);
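
  // Sketch of construction-time persistent allocation (the kernel class and
  // member names are hypothetical, not part of this header):
  //
  //   explicit MyCachingOp(OpKernelConstruction* context) : OpKernel(context) {
  //     Tensor* init = nullptr;  // Owned by cached_; do not store this pointer.
  //     OP_REQUIRES_OK(context,
  //                    context->allocate_persistent(DT_FLOAT, TensorShape({}),
  //                                                 &cached_, &init));
  //     init->scalar<float>()() = 0.0f;
  //   }
  //   ...
  //   PersistentTensor cached_;  // Retrieved later via AccessTensor(context).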

  // User-supplied configuration of this operation.
  const NodeDef& def() const { return *def_; }

  // For inspecting the inputs to this operation.
  int num_inputs() const { return input_types_.size(); }
  DataType input_type(int i) const { return input_types_[i]; }
  const DataTypeSlice& input_types() const { return input_types_; }
  const MemoryTypeSlice& input_memory_types() const {
    return input_memory_types_;
  }

  // For inspecting the outputs expected from this operation.
  int num_outputs() const { return output_types_.size(); }
  DataType output_type(int i) const { return output_types_[i]; }
  const DataTypeSlice& output_types() const { return output_types_; }
  const MemoryTypeSlice& output_memory_types() const {
    return output_memory_types_;
  }

  // If expected_inputs == inputs() and expected_outputs == output_types(),
  // returns OK, else returns INVALID_ARGUMENT with an error message.
  // Recommended for Ops with dynamic signatures.
  Status MatchSignature(const DataTypeSlice expected_inputs,
                        const DataTypeSlice expected_outputs);

  // For recording configuration errors during construction.
  void SetStatus(const Status& status);
  const Status& status() const { return *status_; }

  // Look up the attr with name attr_name and set *value to its value. If no
  // attr with attr_name is found in def(), or the attr does not have
  // a matching type, a non-ok status will be returned.
  template <class T>
  Status GetAttr(StringPiece attr_name, T* value) const;

  // Return true if the attr_name is defined in def().
  bool HasAttr(StringPiece attr_name) const;
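
  // Typical constructor-time usage (sketch; the attr names "axis" and
  // "keep_dims" are illustrative, not attrs defined by this header):
  //
  //   int32 axis;
  //   OP_REQUIRES_OK(context, context->GetAttr("axis", &axis));
  //   if (context->HasAttr("keep_dims")) {
  //     OP_REQUIRES_OK(context, context->GetAttr("keep_dims", &keep_dims_));
  //   }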

  // Return the device type.
  const DeviceType& device_type() const { return device_type_; }

  // If not nullptr, the kernel can instantiate functions defined in
  // the library. E.g.,
  // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...).
  FunctionLibraryRuntime* function_library() const { return flib_; }

  // The GraphDef version whose behavior we should follow.
  int graph_def_version() const { return graph_def_version_; }

  // Helper routines for the OP_REQUIRES macros
  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  // Unrecommended functions: these are functions that have some
  // current uses but are not recommended for use, and may go away at
  // some future major version release.

  // May be used, e.g., to get GPU handles, etc.
  //
  // Currently only used to call MakeTensorFromProto() for
  // implementing ConstantOp for every device. See comments
  // on Device::MakeTensorFromProto for longer-term replacement
  // ideas.
  DeviceBase* device() const { return device_; }

 private:
  const DeviceType device_type_;
  DeviceBase* const device_;
  Allocator* allocator_;
  const NodeDef* def_;
  const OpDef* op_def_;
  FunctionLibraryRuntime* flib_;
  DataTypeSlice input_types_;
  MemoryTypeSlice input_memory_types_;
  DataTypeSlice output_types_;
  MemoryTypeSlice output_memory_types_;
  const int graph_def_version_;
  Status* status_;

  // Allow op_def_ across from OpKernel, but not from subclasses.
  // TODO(irving): Remove protos from this header entirely.
  friend class OpKernel;

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction);
};

// TODO(mrry): Consider converting to a random_access_iterator, and upgrading
// tensorflow::gtl::iterator_range to make the below container classes
// unnecessary.
template <typename ListType, typename ElementType>
class OpArgIterator {
 public:
  using iterator_category = std::forward_iterator_tag;
  using value_type = ElementType;
  using pointer = ElementType*;
  using const_pointer = const ElementType*;
  using reference = ElementType&;
  using const_reference = const ElementType&;
  using difference_type = ptrdiff_t;

  OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {}

  bool operator==(const OpArgIterator& rhs) {
    DCHECK(list_ == rhs.list_);
    return i_ == rhs.i_;
  }

  bool operator!=(const OpArgIterator& rhs) {
    DCHECK(list_ == rhs.list_);
    return i_ != rhs.i_;
  }

  OpArgIterator operator++() {  // prefix ++it
    ++i_;
    return *this;
  }

  OpArgIterator operator++(int) {  // postfix it++
    OpArgIterator old_value = *this;
    ++i_;
    return old_value;
  }

  reference operator*() { return (*list_)[i_]; }
  pointer operator->() { return &(*list_)[i_]; }

  const_reference operator*() const { return (*list_)[i_]; }
  const_pointer operator->() const { return &(*list_)[i_]; }

 private:
  const ListType* const list_;
  int i_;
};

// Utility class for representing a list of immutable input tensors
// that are passed to the op as a single named argument.
class OpInputList {
 public:
  typedef OpArgIterator<OpInputList, const Tensor> Iterator;
  OpInputList() : ctx_(nullptr), start_(0), stop_(0) {}
  OpInputList(OpKernelContext* ctx, int start, int stop)
      : ctx_(ctx), start_(start), stop_(stop) {}
  OpInputList& operator=(const OpInputList& other) = default;
  const Tensor& operator[](int i) const;
  int size() const { return stop_ - start_; }
  Iterator begin() const { return Iterator(this, 0); }
  Iterator end() const { return Iterator(this, size()); }

 private:
  OpKernelContext* ctx_;  // not owned
  int start_;
  int stop_;
};

// Utility class for representing a list of mutable ("ref") input tensors
// that are passed to the op as a single named argument.
class OpMutableInputList {
 public:
  typedef OpArgIterator<OpMutableInputList, Tensor*> Iterator;
  OpMutableInputList(OpKernelContext* ctx, int start, int stop)
      : ctx_(ctx), start_(start), stop_(stop) {}
  OpMutableInputList() : ctx_(nullptr), start_(0), stop_(0) {}
  OpMutableInputList& operator=(const OpMutableInputList& other) = default;
  Tensor at(int i, bool lock_held);
  mutex* ref_mutex(int i);
  int size() const { return stop_ - start_; }
  Iterator begin() const { return Iterator(this, 0); }
  Iterator end() const { return Iterator(this, size()); }

 private:
  OpKernelContext* ctx_;  // not owned
  int start_;
  int stop_;
};

// Utility class for representing a list of output tensors that are
// grouped as a single named output.
class OpOutputList {
 public:
  typedef OpArgIterator<OpOutputList, const Tensor*> Iterator;
  OpOutputList() : ctx_(nullptr), start_(0), stop_(0) {}
  OpOutputList(OpKernelContext* ctx, int start, int stop)
      : ctx_(ctx), start_(start), stop_(stop) {}
  OpOutputList& operator=(const OpOutputList& other) = default;
  Tensor* operator[](int i);
  bool required(int i) const;
  DataType expected_output_dtype(int i) const;
  Status allocate(int i, const TensorShape& shape, Tensor** output);
  void set(int i, const Tensor& tensor);
  void set_ref(int i, mutex* mu, Tensor* tensor_for_ref);
  int size() const { return stop_ - start_; }
  Iterator begin() const { return Iterator(this, 0); }
  Iterator end() const { return Iterator(this, size()); }

 private:
  OpKernelContext* ctx_;  // not owned
  int start_;
  int stop_;
};

// Holds a tensor or tensor reference. For tensor references, we need
// a mutex to prevent concurrent access to the tensor.
struct TensorValue {
  TensorValue() : mutex_if_ref(nullptr), tensor(nullptr) {}
  TensorValue(Tensor* t)  // NOLINT(runtime/explicit)
      : mutex_if_ref(nullptr), tensor(t) {}
  TensorValue(mutex* mu, Tensor* t) : mutex_if_ref(mu), tensor(t) {}
  Tensor* operator->() const { return tensor; }
  bool is_ref() const { return mutex_if_ref != nullptr; }

  mutex* mutex_if_ref;  // nullptr if not a ref, != nullptr if a ref
  Tensor* tensor;
};

// Used to store partitioned graphs from function-calling ops.
struct GraphCollector {
  mutex mu;
  std::vector<GraphDef> partitioned_graphs GUARDED_BY(mu);
  GraphDef raw_graph GUARDED_BY(mu);
  GraphDef optimized_graph GUARDED_BY(mu);

  bool dirty GUARDED_BY(mu);

  GraphCollector() : dirty(false) {}

  void CollectRawGraph(const GraphDef& graph) {
    mutex_lock ml(mu);
    raw_graph.MergeFrom(graph);
    dirty = true;
  }

  void CollectOptimizedGraph(const GraphDef& graph) {
    mutex_lock ml(mu);
    optimized_graph.MergeFrom(graph);
    dirty = true;
  }

  void CollectPartitionedGraph(const GraphDef& graph) {
    mutex_lock ml(mu);
    partitioned_graphs.push_back(graph);
    dirty = true;
  }

  void ClearGraphs() EXCLUSIVE_LOCKS_REQUIRED(mu) {
    raw_graph.Clear();
    optimized_graph.Clear();
    partitioned_graphs.clear();
    dirty = false;
  }

  bool HasUpdatedGraphs() {
    mutex_lock ml(mu);
    return dirty;
  }
};

class OpKernelContext {
 public:
  // The first element of a WrappedAllocator is a "base" Allocator and
  // the second element is that Allocator wrapped by a
  // TrackingAllocator
  typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator;

  // TODO(zhifengc): Do some cleanup of Params.
  // The Params struct is passed in to initialize an OpKernelContext,
  // and must outlive the OpKernelContext.
  struct Params {
    ~Params() { delete eigen_gpu_device; }

    // The step being executed.
    int64 step_id = 0;

    // The op kernel being computed.
    OpKernel* op_kernel = nullptr;

    // The device on which the kernel is running.
    DeviceBase* device = nullptr;

    // The Eigen GPU device wrapper, which may include a per-op
    // wrapped allocator. The concrete type of this object depends on
    // the type of this->device, so eigen_gpu_device can't be an
    // inline member and must be heap allocated. However, we don't
    // want to allocate a new eigen_gpu_device for every Op that is
    // executed. Instead this member is allocated on first use using
    // ensure_eigen_gpu_device, and then if the Params structure is
    // re-used for subsequent Ops, the eigen_gpu_device is
    // ReInitialized in the OpKernelContext constructor. Unlike the
    // other pointers in Params, this one is owned by Params.
    PerOpGpuDevice* eigen_gpu_device = nullptr;

    inline void ensure_eigen_gpu_device() {
      DCHECK(device);
      if (nullptr == eigen_gpu_device) {
        // Surprisingly, MakeGpuDevice will return nullptr if the
        // device is not a GPU device. This is ok, since those devices
        // will never use eigen_gpu_device. It seems better to have
        // ensure_eigen_gpu_device fall through and regenerate the
        // nullptr every time an OpKernelContext is instantiated, than
        // to do an unnecessary allocation of a dummy eigen GPU
        // device for CPU device Ops.
        eigen_gpu_device = device->MakeGpuDevice();
      }
    }

    bool track_allocations = false;
    bool log_memory = false;
    bool record_tensor_accesses = false;

    // Array indexed by output number for this node
    const AllocatorAttributes* output_attr_array = nullptr;

    // Shared resources accessible by this op kernel invocation.
    ResourceMgr* resource_manager = nullptr;

    // Per-step resources accessible by this op kernel invocation should be
    // stored in this container.
    ScopedStepContainer* step_container = nullptr;

    // Mechanism used by this op kernel invocation to communicate with
    // computations running on other devices.
    Rendezvous* rendezvous = nullptr;

    // Mechanism for executing a collective op that needs to coordinate
    // with parallel instances running on other devices.
    CollectiveExecutor* collective_executor = nullptr;

    // The session state for this op.
    SessionState* session_state = nullptr;

    // Unique session identifier. Can be empty.
    string session_handle;

    // The tensor store for this op.
    TensorStore* tensor_store = nullptr;

    // Mechanism used by this op kernel invocation to register a callback
    // for its cancellation.
    CancellationManager* cancellation_manager = nullptr;

    // Inputs to this op kernel.
    const gtl::InlinedVector<TensorValue, 4>* inputs = nullptr;
    bool is_input_dead = false;

    const gtl::InlinedVector<AllocatorAttributes, 4>* input_alloc_attrs =
        nullptr;

    // Device contexts.
    const gtl::InlinedVector<DeviceContext*, 4>* input_device_contexts =
        nullptr;
    DeviceContext* op_device_context = nullptr;

    // Control-flow op support.
    FrameAndIter frame_iter;

    // Function call support.
    CallFrameInterface* call_frame = nullptr;
    FunctionLibraryRuntime* function_library = nullptr;
    std::function<void(std::function<void()>)>* runner = nullptr;
    StepStatsCollectorInterface* stats_collector = nullptr;
    GraphCollector* graph_collector = nullptr;

    // TensorSliceReaderCache support.
    checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;

    // Support for forwarding reservations (used by ScopedAllocator).
    static const int kNeverForward = -2;
    static const int kNoReservation = -1;
    // Values in [0,...) represent reservations for the indexed output.
    const int* forward_from_array = nullptr;

    // For tracking actively running deferred ops.
    std::function<void()> inc_num_deferred_ops_function = []() {};
    std::function<void()> dec_num_deferred_ops_function = []() {};
  };

  // params must outlive the OpKernelContext.
  explicit OpKernelContext(Params* params);
  OpKernelContext(Params* params, int noutputs);
  ~OpKernelContext();

  Env* env() const { return params_->device->env(); }

  int64 step_id() const { return params_->step_id; }

  const OpKernel& op_kernel() const { return *params_->op_kernel; }

  // Input/output signature.

  int num_inputs() const { return params_->inputs->size(); }
  DataType input_dtype(int index) const;
  Status input_dtype(StringPiece name, DataType* dtype) const;
  MemoryType input_memory_type(int index) const;

  int num_outputs() const { return outputs_.size(); }
  DataType expected_output_dtype(int index) const;
  MemoryType output_memory_type(int index) const;

  // Input

  // Returns an immutable input tensor. May only be used for non-Ref
  // inputs. For Ref inputs use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  // TODO(mrry): Convert this to return Status.
  const Tensor& input(int index);

  // Returns the named immutable input tensor in "tensor", as defined
  // in the OpDef. May only be used for non-Ref inputs. For Ref inputs
  // use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  // REQUIRES: the named input must not be a list.
  Status input(StringPiece name, const Tensor** tensor);

  // Returns the named list-valued immutable input in "list", as
  // defined in the OpDef. If the named output is not list-valued,
  // returns a one-element list. May only be used for non-Ref
  // inputs. For Ref inputs use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  Status input_list(StringPiece name, OpInputList* list);

  // For mutable inputs, use the following together to make sure there
  // is no concurrent access to mutable_input(), e.g.:
  // {
  //   Tensor& t = context->mutable_input(index);
  //   mutex_lock lock(*context->input_ref_mutex(index));
  //   // modify the values in t
  // }
  // REQUIRES: IsRefType(input_dtype(index))
  Status input_ref_mutex(StringPiece name, mutex** out_mutex);

  // Returns a mutable input tensor. Must be used to access Ref
  // inputs. REQUIRES: IsRefType(input_dtype(index)). The caller may
  // modify the values stored in the Tensor buffer, and modifications
  // will be visible to other Ops reading the same ref tensor. If
  // !lock_held the input mutex will be acquired before returning the
  // Tensor.
  // TODO(mrry): Convert this to return Status.
  Tensor mutable_input(int index, bool lock_held);

  // Returns the named mutable input tensor in "tensor", as defined in
  // the OpDef. Must be used to access Ref inputs. The values stored
  // in the Tensor buffer may be modified, and modifications will be
  // visible to other Ops reading the same ref tensor. If !lock_held
  // the input mutex will be acquired before returning the Tensor.
  // REQUIRES: the named input must not be a list.
  // REQUIRES: the named input must be a ref tensor.
  Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held);

  // Returns the named list-valued mutable input in "list", as defined
  // in the OpDef. If the named input is not list-valued, returns a
  // one-element list. Must be used to access Ref inputs. The values
  // stored in the Tensor buffer may be modified, and modifications
  // will be visible to other Ops reading the same ref tensor.
  // REQUIRES: the named input must be a ref tensor.
  Status mutable_input_list(StringPiece name, OpMutableInputList* list);

  // Replace the corresponding Ref Input to use the storage buffer
  // used by tensor. If !lock_held the input mutex will be acquired
  // before returning the Tensor.
  // REQUIRES: IsRefType(input_dtype(index)).
  void replace_ref_input(int index, const Tensor& tensor, bool lock_held);

  // Replace the corresponding named Ref Input to use the storage
  // buffer used by tensor. If !lock_held the input mutex will be
  // acquired before returning the Tensor.
  // REQUIRES: IsRefType(input_dtype(index)).
  Status replace_ref_input(StringPiece name, const Tensor& tensor,
                           bool lock_held);

  // Deletes the Tensor object used as the Ref Input at
  // input_index. This is not usually necessary and should be used
  // with caution. If !lock_held the input mutex will be acquired
  // before returning the Tensor.
  // REQUIRES: IsRefType(input_dtype(input_index)).
  void delete_ref_input(int input_index, bool lock_held);

  // Return true if there is input at the given index. An operator has no
  // input at index if its tensor is null. This is primarily used by the
  // merge operator.
  // TODO(mrry): Convert this to return Status.
  bool has_input(int index) const;

  // Returns true if all inputs are the same shape, otherwise sets the
  // status to a non-OK value and returns false.
  // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
  bool ValidateInputsAreSameShape(OpKernel* op);

  // If non-null, kernels should populate with any partition subgraphs created.
  GraphCollector* graph_collector() { return params_->graph_collector; }

  // Input to output forwarding.

  // Set the output Ref Tensor at output_index to be an alias of the
  // input Ref Tensor at input_index.
  // REQUIRES: IsRefType(input_dtype(input_index)).
  // REQUIRES: IsRefType(output_dtype(output_index)).
  void forward_ref_input_to_ref_output(int input_index, int output_index);

  // Returns true when an alias to input[input_index], reshaped to
  // output_shape, which is safe to use for in-place computation, was written
  // to *output. Returns false if input[input_index] has a refcount greater
  // than one, or if its type does not match the expected output type of
  // output[output_index], or the number of elements in input[input_index]
  // does not equal the number of elements in output_shape.
  bool forward_input_to_output_with_shape(int input_index, int output_index,
                                          const TensorShape& output_shape,
                                          Tensor** output) TF_MUST_USE_RESULT;
  Status forward_input_to_output_with_shape(StringPiece input_name,
                                            StringPiece output_name,
                                            const TensorShape& output_shape,
                                            Tensor** output) TF_MUST_USE_RESULT;

  // Returns a pointer to a Tensor aliasing the underlying buffer backing
  // input[input_index] iff
  //   * input[input_index] is not a ref,
  //   * the data type, shape, memory type, and allocator attributes of
  //     input[input_index] are compatible with those given in dtype, shape,
  //     memory_type, and attr,
  //   * refcount on the underlying buffer is one.
  //   * Either there is no forwarding reservation for either input_index
  //     or output_index or the specified input is reserved for the specified
  //     output. More precisely:
  //
  //     These cases mean neither input nor output has a reservation:
  //        forward_from_array = nullptr
  //     OR (input_index is not in forward_from_array AND
  //         (output_index == kNoReservation OR
  //          forward_from_array[output_index] == kNoReservation))
  //
  //     This case means that input_index is reserved for output_index:
  //        forward_from_array[output_index] == input_index
  //
  //     This case means the output is reserved to always be allocated,
  //     never assigned a forwarded input:
  //        forward_from_array[output_index] == kNeverForward
  //
  // Otherwise returns nullptr.
  // NOTE: For CUDA kernels that read inputs using the __ldg() intrinsic,
  // forwarding is only safe if there are no reads via __ldg() after writes
  // to the same address.
  std::unique_ptr<Tensor> forward_input(
      int input_index, int output_index, DataType output_dtype,
      const TensorShape& output_shape, MemoryType output_memory_type,
      const AllocatorAttributes& output_attr) TF_MUST_USE_RESULT;

  // Tries to forward one of the inputs given in input_indices to
  // output[output_index]. If none of the given inputs can be forwarded, calls
  // allocate_output() to allocate a new output buffer.
  Status forward_input_or_allocate_output(
      gtl::ArraySlice<int> candidate_input_indices, int output_index,
      const TensorShape& output_shape, Tensor** output) TF_MUST_USE_RESULT;
  Status forward_input_or_allocate_output(
      gtl::ArraySlice<StringPiece> candidate_input_names,
      StringPiece output_name, const TensorShape& output_shape,
      Tensor** output) TF_MUST_USE_RESULT;
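
  // Sketch of typical usage inside Compute() for an element-wise op that can
  // run in place when the input buffer is exclusively owned (illustrative):
  //
  //   Tensor* output = nullptr;
  //   OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
  //                               {0}, 0, context->input(0).shape(), &output));
  //   // "output" now aliases input 0 if forwarding succeeded; otherwise it
  //   // points to a freshly allocated buffer.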

  // Tries to reuse one of the inputs given in input_indices as a temporary.
  // If none of the given inputs can be forwarded, calls
  // allocate_temp() to allocate a new temporary buffer.
  Status forward_input_or_allocate_temp(
      gtl::ArraySlice<int> candidate_input_indices, DataType type,
      const TensorShape& shape, const AllocatorAttributes& allocator_attr,
      Tensor* out_temp) TF_MUST_USE_RESULT;

  Status forward_input_or_allocate_temp(
      gtl::ArraySlice<int> candidate_input_indices, DataType type,
      const TensorShape& shape, Tensor* out_temp) TF_MUST_USE_RESULT {
    return forward_input_or_allocate_temp(candidate_input_indices, type, shape,
                                          AllocatorAttributes(), out_temp);
  }

  // Output

  // Returns the named list-valued output in "list", as defined in the OpDef.
  // If the named output is not list-valued, returns a one-element list.
  Status output_list(StringPiece name, OpOutputList* list);
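
  // Sketch of allocating every tensor in a list-valued output (the output
  // name "values" and the shape are illustrative):
  //
  //   OpOutputList outputs;
  //   OP_REQUIRES_OK(context, context->output_list("values", &outputs));
  //   for (int i = 0; i < outputs.size(); ++i) {
  //     Tensor* out = nullptr;
  //     OP_REQUIRES_OK(context, outputs.allocate(i, TensorShape({2}), &out));
  //   }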

  // If output_required(index) returns true, the OpKernel's Compute() method
  // should call allocate_output(index, ...), set_output(index, ...),
  // set_output_ref(index, ...), or set the status to a non-ok value.
  // If it returns false, it may output, but is not required to do so.
  // TODO(mrry): Convert this to return Status, and implement a string
  // name version.
  bool output_required(int index) const {
    return true;  // TODO(josh11b): implement
  }

  // Allocation of tensors during kernel execution inside the Compute
  // method:
  //
  // There are three methods to allocate Tensors when an Op kernel
  // executes.
  //
  // 1) allocate_persistent. This is only needed for Tensors that will
  // be stored by the Op between invocations, and it *must* be used
  // for those Tensors. The call returns a PersistentTensor, and that
  // is the only object the Op is allowed to hold on to between
  // invocations. When the Tensor is needed in a subsequent
  // invocation, it can be retrieved from the PersistentTensor using
  // the AccessTensor method. This ensures that the system is made
  // aware of any use of the tensor's allocated memory, which is
  // needed for correctness on asynchronous devices such as GPUs.
  //
  // 2) allocate_output. This should be used to allocate any tensor
  // that is going to be used as an output from the Op at the end of
  // the current execution. The caller indicates which output the
  // Tensor will be assigned to, and the call returns the
  // newly-allocated Tensor. The Tensor can subsequently be assigned
  // to during kernel execution, and will be used as the designated
  // output when the kernel execution completes.
  //
  // 3) allocate_temp. This should be used to allocate any scratch
  // storage that is needed while the kernel is executing, and will
  // not be retained by the Op.
  //
  // In some cases a Tensor needs to be used as an output even though
  // it was previously allocated elsewhere. The Tensor may have been
  // passed as an input, or stored in a PersistentTensor during a
  // previous kernel execution, or allocated earlier in the kernel
  // execution at a time when it was not known which output it would
  // be assigned to. In this case the kernel can use set_output or
  // set_output_ref to indicate that the tensor should be used as the
  // designated output. It is legal to use any previously-allocated
  // Tensor as an argument to set_output or set_output_ref, including
  // Tensors allocated via allocate_temp. There may be a performance
  // penalty to using a Tensor that was not allocated using
  // allocate_output. This is because allocate_output uses the
  // AllocatorAttributes stored in output_attr_array for the
  // designated output. In some cases, using the wrong attributes may
  // cause an extra copy of the Tensor's buffer.
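
  // A condensed sketch of allocate_output, allocate_temp, and set_output as
  // described above, from inside a hypothetical Compute() (the shapes and
  // output indices are illustrative):
  //
  //   // (2) Allocate the tensor that will be returned as output 0.
  //   Tensor* output = nullptr;
  //   OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
  //
  //   // (3) Allocate scratch space used only while Compute() runs.
  //   Tensor scratch;
  //   OP_REQUIRES_OK(context,
  //                  context->allocate_temp(DT_FLOAT, scratch_shape, &scratch));
  //
  //   // Reuse an already-allocated tensor as output 1 instead of allocating.
  //   context->set_output(1, scratch);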

  // Allocates output for the specified output index with shape.
  // OpKernelContext retains ownership of the returned pointer. See
  // comment above.
  //
  // If memory allocation fails, returns an error status.
  //
  // REQUIRES: !IsRefType(expected_output_dtype(index))
  Status allocate_output(int index, const TensorShape& shape,
                         Tensor** tensor) TF_MUST_USE_RESULT;
  Status allocate_output(StringPiece name, const TensorShape& shape,
                         Tensor** tensor) TF_MUST_USE_RESULT;
  // The following methods use the supplied attributes instead of
  // those in output_attr_array. The caller is responsible for
  // ensuring that the attributes are "compatible" with the
  // output_attr_array, e.g. the tensor is allocated on the correct
  // device. See comment above.
  Status allocate_output(int index, const TensorShape& shape, Tensor** tensor,
                         AllocatorAttributes attr) TF_MUST_USE_RESULT;
  Status allocate_output(StringPiece name, const TensorShape& shape,
                         Tensor** tensor,
                         AllocatorAttributes attr) TF_MUST_USE_RESULT;

  // Allocates a temporary Tensor of the specified type and
  // shape. Devices such as GPUs that enqueue Ops for lazy execution
  // may retain references to the temporary tensors after the Op's
  // Compute method has run. See comment above.
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp, AllocatorAttributes allocator_attr,
                       const AllocationAttributes& allocation_attr);
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp, AllocatorAttributes allocator_attr) {
    return allocate_temp(type, shape, out_temp, allocator_attr,
                         AllocationAttributes());
  }
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp) {
    return allocate_temp(type, shape, out_temp, AllocatorAttributes());
  }

  // Allocates a Tensor of the specified type and shape which the Op
  // plans to maintain as persistent state. out_persistent holds the
  // PersistentTensor which is the object the caller should store. For
  // convenience, if out_tensor is non-null then it will be filled in
  // with a Tensor* pointing to the newly-allocated tensor which the
  // caller can use instead of calling
  // out_persistent->AccessTensor. The caller does not own out_tensor
  // and should not keep a copy of it. See comment above.
  Status allocate_persistent(DataType type, const TensorShape& shape,
                             PersistentTensor* out_persistent,
                             Tensor** out_tensor, AllocatorAttributes attr);
  Status allocate_persistent(DataType type, const TensorShape& shape,
                             PersistentTensor* out_persistent,
                             Tensor** out_tensor) {
    return allocate_persistent(type, shape, out_persistent, out_tensor,
                               AllocatorAttributes());
  }

  // Copies a tensor (allocated by the caller) to the specified output
  // index. REQUIRES: !IsRefType(expected_output_dtype(index))
  // REQUIRES: 'tensor' must have the same MemoryType as
  // output_memory_types[index]. See comment above.
  Status set_output(StringPiece name, const Tensor& tensor);

  // To output a reference. Caller retains ownership of mu and tensor_for_ref,
  // and they must outlive all uses within the step. See comment above.
  // REQUIRES: IsRefType(expected_output_dtype(index))
  Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref);

  // Returns nullptr if allocate_output() or set_output() have not been called.
  Status mutable_output(StringPiece name, Tensor** tensor);

  // Records device specific state about how the input tensors were
  // computed.
  //
  // If using the templated function, the type must be a subclass
  // of DeviceContext.
  //
  // Get the DeviceContext used for the index input. Returns nullptr
  // if no DeviceContext was provided.
  template <typename T>
  T* input_device_context(int index);
  DeviceContext* input_device_context(int index);

  // Return the DeviceContext that should be used for this Op.
  //
  // If using the templated function, the type must be a subclass
  // of DeviceContext.
  //
  // Returns nullptr if the device did not provide one.
  template <typename T>
  T* op_device_context();
  DeviceContext* op_device_context() {
    DeviceContext* ret = params_->op_device_context;
    if (ret == nullptr) {
      auto* dev_info = device()->tensorflow_gpu_device_info();
      if (dev_info) ret = dev_info->default_context;
    }
    return ret;
  }

  AllocatorAttributes input_alloc_attr(int index) const {
    if (params_->input_alloc_attrs == nullptr) {
      return AllocatorAttributes();
    } else {
      DCHECK_GE(index, 0);
      DCHECK_LT(index, params_->input_alloc_attrs->size());
      return (*params_->input_alloc_attrs)[index];
    }
  }

  AllocatorAttributes output_alloc_attr(int index) const {
    return params_->output_attr_array[index];
  }

  gtl::InlinedVector<WrappedAllocator, 4> ConsumeWrappedAllocators() {
    mutex_lock lock(mu_);
    gtl::InlinedVector<WrappedAllocator, 4> retrieved;
    retrieved.swap(wrapped_allocators_);
    return retrieved;
  }

  // Communication.
  //
  // An op kernel communicates with outside environment through
  // Rendezvous Send() and Recv().
  Rendezvous* rendezvous() const { return params_->rendezvous; }

  CollectiveExecutor* collective_executor() const {
    return params_->collective_executor;
  }

  // An op kernel can access the session state it belongs to.
  SessionState* session_state() const { return params_->session_state; }

  // Unique identifier of the session it belongs to. Can be empty.
  string session_handle() const { return params_->session_handle; }

  // An op kernel can access the tensor store of the run it belongs to.
  TensorStore* tensor_store() const { return params_->tensor_store; }

  // Function call support.
  //
  // If this kernel invocation is within a function execution,
  // call_frame() returns the call frame for the function call.
  CallFrameInterface* call_frame() const { return params_->call_frame; }

  // If not nullptr, the kernel can invoke functions defined in the
  // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...).
  FunctionLibraryRuntime* function_library() const {
    return params_->function_library;
  }

  std::function<void(std::function<void()>)>* runner() const {
    return params_->runner;
  }
  StepStatsCollectorInterface* stats_collector() const {
    return params_->stats_collector;
  }

  // Shared resources accessible to this kernel.
  ResourceMgr* resource_manager() const { return params_->resource_manager; }

  checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const {
    return params_->slice_reader_cache;
  }

  // Execution.
  //
  // OpKernels can use these eigen devices to carry out their
  // numerical computation.
  const Eigen::ThreadPoolDevice& eigen_cpu_device() const {
    return *device()->eigen_cpu_device();
  }
  const Eigen::GpuDevice& eigen_gpu_device() const {
    return params_->eigen_gpu_device->device();
  }
#ifdef TENSORFLOW_USE_SYCL
  const Eigen::SyclDevice& eigen_sycl_device() const {
    return *device()->eigen_sycl_device();
  }
#endif
  template <typename EigenDeviceType>
  const EigenDeviceType& eigen_device() const;

  // Error handling.

  // If expected_inputs == inputs() and expected_outputs == output_types(),
  // returns OK, else returns INVALID_ARGUMENT with an error message.
  // Recommended for Ops with dynamic signatures, where validation can only
  // be performed at runtime.
  Status MatchSignature(const DataTypeSlice expected_inputs,
                        const DataTypeSlice expected_outputs);

  // An OpKernel should call SetStatus() if Compute() encounters an
  // error.
  void SetStatus(const Status& status);
  const Status& status() const { return status_; }

  // Cancellation.
  //
  // EXPERIMENTAL. See the implementation in tensorflow::FIFOQueue for an
  // example of how to use this API.
  CancellationManager* cancellation_manager() const {
    return params_->cancellation_manager;
  }

  // Other accessors.

  // For control flow.
  FrameAndIter frame_iter() const { return params_->frame_iter; }
  bool is_input_dead() const { return params_->is_input_dead; }

  // May be used, e.g., to get GPU handles, etc.
  // TODO(tucker): Add example usage.
  DeviceBase* device() const { return params_->device; }

  // Retrieve list of referenced tensors in out_vector. Once this is
  // called, it is not legal to reference any more tensors. Should
  // not be called from Op kernels.
  void retrieve_accessed_tensors(TensorReferenceVector* out_vector);

  // Per-step container for use by white-listed internal ops.
  ScopedStepContainer* step_container() const {
    return params_->step_container;
  }

  // Helper routines for the OP_REQUIRES macros
  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  // Unrecommended functions: these are functions that have some
  // current uses but are not recommended for use, and may go away at
  // some future major version release.
  //
  // The following functions all have versions that return Status
  // to capture error conditions, and are strongly preferred.
  Tensor* mutable_output(int index);
  void set_output(int index, const Tensor& tensor);
  mutex* input_ref_mutex(int index);
  void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref);
  TensorValue release_output(int index);

  bool track_allocations() const { return params_->track_allocations; }

  // Records temp memory allocation. Tensor object is recorded to identify the
  // case where temp memory is used as output memory.
  void record_temp_memory_allocation(int64 size, const Tensor& t)
      LOCKS_EXCLUDED(stats_mu_);

  // Returns recorded size of temporary memory.
  int64 temp_memory_allocated() const LOCKS_EXCLUDED(stats_mu_);

  // Records persistent memory allocation, size can be negative indicating
  // deallocation.
  void record_persistent_memory_allocation(int64 size, int64 alloc_id = -1)
      LOCKS_EXCLUDED(stats_mu_);

  // Returns recorded size and ids of persistent memory.
  int64 persistent_memory_allocated() const LOCKS_EXCLUDED(stats_mu_);

  std::vector<int64> persistent_alloc_ids() const LOCKS_EXCLUDED(stats_mu_);

  // Resets counters for temp and persistent memory and recorded ids.
  void clear_recorded_memory() LOCKS_EXCLUDED(stats_mu_);

  bool input_is_ref(int index) const;

  // Used by OpKernel implementations to track actively running deferred ops.
  //
  // A deferred op is one whose Compute method returns (or whose ComputeAsync
  // method invokes the callback) when work is scheduled onto a device. At that
  // point, we don't know when the work will actually complete (or if it has
  // already completed) on the device. These functions allow the executor to
  // track the status of deferred ops and act accordingly.
  //
  // Deferred OpKernel implementations must use these methods to get two
  // functions. They must then call these two functions in pairs, before and
  // after device execution, respectively.
  TF_MUST_USE_RESULT std::function<void()> inc_num_deferred_ops_function() {
    return params_->inc_num_deferred_ops_function;
  }
  TF_MUST_USE_RESULT std::function<void()> dec_num_deferred_ops_function() {
    return params_->dec_num_deferred_ops_function;
  }
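
  // Sketch of the intended pairing in a deferred kernel (EnqueueOnDevice is a
  // hypothetical stand-in for however the kernel schedules device work):
  //
  //   void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
  //     auto inc = context->inc_num_deferred_ops_function();
  //     auto dec = context->dec_num_deferred_ops_function();
  //     inc();  // Call before handing work to the device.
  //     EnqueueOnDevice(context, [dec, done]() {
  //       dec();  // Call after the device-side work has finished.
  //       done();
  //     });
  //   }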
1227
1228 private:
1229 Allocator* get_allocator(AllocatorAttributes attr);
1230
1231 // Internal method to add a tensor's buffer to the list of buffers
1232 // referenced during the execution of the Op, so that GPUs may
1233 // accurately track the memory that may not be reused until the Op
1234 // execution completes.
1235 void record_tensor_reference(const Tensor& tensor);
1236 void really_record_tensor_reference(const Tensor& tensor);
1237
1238 // Internal common method used when allocating tensor memory
allocate_tensor(DataType type,const TensorShape & shape,Tensor * out_tensor,AllocatorAttributes allocator_attr)1239 Status allocate_tensor(DataType type, const TensorShape& shape,
1240 Tensor* out_tensor,
1241 AllocatorAttributes allocator_attr) {
1242 return allocate_tensor(type, shape, out_tensor, allocator_attr,
1243 AllocationAttributes());
1244 }
1245
1246 Status allocate_tensor(DataType type, const TensorShape& shape,
1247 Tensor* out_tensor, AllocatorAttributes allocator_attr,
1248 const AllocationAttributes& allocation_attr);
1249
1250 // This is called by PersistentTensor::AccessTensor whenever the
1251 // wrapped tensor is retrieved, to ensure the runtime knows that the
1252 // Tensor is being accessed within an Op. This is necessary for
1253 // memory safety of devices like GPUs that queue Ops for
1254 // asynchronous execution after the Compute() method completes.
1255 friend class PersistentTensor;
1256 void NotifyUseOfPersistentTensor(const Tensor& tensor);
1257
1258 Status status_;
1259 friend class CollectiveExecutor; // for access to params_
1260 Params* params_; // not owned
1261 mutable mutex mu_; // mutable so const accessors can acquire the lock
1262 gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_);
1263 gtl::InlinedVector<TensorValue, 4> outputs_;
1264
1265 // Constructed only if <params->record_tensor_accesses>.
1266 ManualConstructor<UniqueTensorReferences> referenced_tensors_ GUARDED_BY(mu_);
1267
1268 // The following data members are only used when allocation tracking is
1269 // enabled.
1270 mutable mutex stats_mu_;
1271 int64 temp_memory_allocated_ GUARDED_BY(stats_mu_);
1272 int64 persistent_memory_allocated_ GUARDED_BY(stats_mu_);
1273 std::unique_ptr<gtl::InlinedVector<std::pair<const void*, int64>, 2>>
1274 temp_tensor_buffer_and_size_ GUARDED_BY(stats_mu_);
1275 std::unique_ptr<gtl::InlinedVector<int64, 2>> persistent_alloc_ids_
1276 GUARDED_BY(stats_mu_);
1277
1278 TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext);
1279 };
1280
// Register your OpKernel by specifying the Op's name, the device the
// kernel runs on, any type attr constraints for this kernel, any
// host-memory args, and the class to instantiate.  Examples:
//
//  // A kernel that supports all types.
//  REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp);
//
//  // The following are equivalent ways of specifying that the kernel only
//  // works if the "T" type attr is set to DT_FLOAT.
//  REGISTER_KERNEL_BUILDER(
//      Name("Sub").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//      SubOp<float>);
//  // (You would then repeat this for every type supported by "Sub".)
//
//  // This form allows you to specify a list of types as the constraint.
//  REGISTER_KERNEL_BUILDER(Name("Sub")
//                              .Device(DEVICE_CPU)
//                              .TypeConstraint("T", {DT_FLOAT}),
//                          SubOp<float>);
//
//  // A kernel that expects one of the input tensors in host memory.
//  REGISTER_KERNEL_BUILDER(
//      Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp);
//
// See kernel_def_builder for details.

// Instantiate an OpKernel that has been registered.  Returns nullptr
// if no kernel is registered for that combination of device type and input
// signature (and sets a NOT_FOUND *status), or if there is an error in
// construction (and sets an INVALID_ARGUMENT *status).  Otherwise, the
// caller takes ownership of the returned pointer.
// EXPECTED USAGE: unique_ptr<OpKernel> op = CreateOpKernel(...);
// REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
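//
// A slightly fuller sketch (assuming `device`, `allocator`, and a fully
// specified `node_def` already exist; the names here are illustrative):
//
//   Status status;
//   std::unique_ptr<OpKernel> op = CreateOpKernel(
//       DeviceType(DEVICE_CPU), device, allocator, node_def,
//       TF_GRAPH_DEF_VERSION, &status);
//   if (!status.ok()) { /* handle NOT_FOUND or INVALID_ARGUMENT */ }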
std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
                                         DeviceBase* device,
                                         Allocator* allocator,
                                         const NodeDef& def,
                                         int graph_def_version, Status* status);
Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      const NodeDef& def, int graph_def_version,
                      OpKernel** kernel);

// Returns into 'device_types' the subset of prioritized_types that this
// binary has registered for the given NodeDef.
//
// REQUIRES: * 'device_types' is not nullptr.
//           * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
Status SupportedDeviceTypesForNode(
    const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
    PrioritizedDeviceTypeVector* device_types);

// Returns a message with a description of the kernels registered for op
// `op_name`.
string KernelsRegisteredForOp(StringPiece op_name);

// Call once after Op registration has completed.
Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry);

// -----------------------------------------------------------------------------
// OpKernel registration implementation follows, please ignore.

// Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax.
namespace register_kernel {

class Name : public KernelDefBuilder {
 public:
  // With selective registration, kernels whose implementation class is not
  // used by any kernel are disabled with the SHOULD_REGISTER_OP_KERNEL call in
  // REGISTER_KERNEL_BUILDER_UNIQ. However, an unused kernel that shares an
  // implementation class with a used kernel would get through that mechanism.
  //
  // This constructor stops such a registration by changing the name of the
  // kernel for the unused op to one that is ignored by
  // OpKernelRegistrar::InitInternal. Note that this step alone is not
  // sufficient - the compiler can't evaluate the entire KernelDefBuilder at
  // compilation time, so it doesn't actually reduce code size on its own.
  explicit Name(const char* op)
      : KernelDefBuilder(SHOULD_REGISTER_OP(op) ? op : "_no_register") {}
};

namespace system {

class Name : public KernelDefBuilder {
 public:
  // For system kernels, we ignore selective registration and
  // unconditionally register the kernel.
  explicit Name(const char* op) : KernelDefBuilder(op) {}
};

}  // namespace system

}  // namespace register_kernel

#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
  REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)

#define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
  REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)

#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)        \
  constexpr bool should_register_##ctr##__flag =                      \
      SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__);                        \
  static ::tensorflow::kernel_factory::OpKernelRegistrar              \
      registrar__body__##ctr##__object(                               \
          should_register_##ctr##__flag                               \
              ? ::tensorflow::register_kernel::kernel_builder.Build() \
              : nullptr,                                              \
          #__VA_ARGS__,                                               \
          [](::tensorflow::OpKernelConstruction* context)             \
              -> ::tensorflow::OpKernel* {                            \
            return new __VA_ARGS__(context);                          \
          });

// The `REGISTER_SYSTEM_KERNEL_BUILDER()` macro acts like
// `REGISTER_KERNEL_BUILDER()` except that the kernel is registered
// unconditionally even when selective registration is used.
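//
// For example (a purely hypothetical op and kernel class, shown only to
// illustrate the syntax):
//
//   REGISTER_SYSTEM_KERNEL_BUILDER(
//       Name("MySystemOp").Device(DEVICE_CPU), MySystemOpKernel);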
#define REGISTER_SYSTEM_KERNEL_BUILDER(kernel_builder, ...)               \
  REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, \
                                             __VA_ARGS__)

#define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
  REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)

#define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)    \
  static ::tensorflow::kernel_factory::OpKernelRegistrar                 \
      registrar__body__##ctr##__object(                                  \
          ::tensorflow::register_kernel::system::kernel_builder.Build(), \
          #__VA_ARGS__,                                                  \
          [](::tensorflow::OpKernelConstruction* context)                \
              -> ::tensorflow::OpKernel* {                               \
            return new __VA_ARGS__(context);                             \
          });

// Checks whether a given kernel is registered on device_type.
bool KernelDefAvailable(const DeviceType& device_type, const NodeDef& node_def);

// If node_def has a corresponding kernel registered on device_type, returns OK
// and fills in the kernel def and kernel_class_name. <def> and
// <kernel_class_name> may be null.
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
                     const KernelDef** def, string* kernel_class_name);

// Writes a list of all registered kernels to LOG(INFO), to help users debug
// missing kernel errors.
void LogAllRegisteredKernels();

// Gets a list of all registered kernels.
KernelList GetAllRegisteredKernels();

// Gets a list of all registered kernels for which predicate returns true
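//
// For example, an illustrative sketch that keeps only the kernels registered
// for a particular op (similar in spirit to GetRegisteredKernelsForOp below):
//
//   KernelList matmul_kernels = GetFilteredRegisteredKernels(
//       [](const KernelDef& k) { return k.op() == "MatMul"; });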
KernelList GetFilteredRegisteredKernels(
    const std::function<bool(const KernelDef&)>& predicate);

// Gets a list of all registered kernels for a given op
KernelList GetRegisteredKernelsForOp(StringPiece op_name);

namespace kernel_factory {

// OpKernelFactory is responsible for creating OpKernels when TensorFlow needs
// them. You register factories with the TensorFlow core by constructing an
// OpKernelRegistrar and passing the factory as a constructor parameter.
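//
// A minimal sketch of a hand-written factory (the class and kernel names here
// are hypothetical; most code uses the REGISTER_KERNEL_BUILDER macro instead):
//
//   class MyKernelFactory : public OpKernelFactory {
//    public:
//     OpKernel* Create(OpKernelConstruction* context) override {
//       return new MyKernel(context);
//     }
//   };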
class OpKernelFactory {
 public:
  virtual OpKernel* Create(OpKernelConstruction* context) = 0;
  virtual ~OpKernelFactory() = default;
};

class OpKernelRegistrar {
 public:
  // Registers the given kernel factory with TensorFlow. TF will call the
  // factory Create() method when it determines that a kernel matching the
  // given KernelDef is required.
  OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    std::unique_ptr<OpKernelFactory> factory) {
    // Perform the check in the header to allow compile-time optimization
    // to a no-op, allowing the linker to remove the kernel symbols.
    if (kernel_def != nullptr) {
      InitInternal(kernel_def, kernel_class_name, std::move(factory));
    }
  }

  // Registers the given factory function with TensorFlow. This is equivalent
  // to registering a factory whose Create function invokes `create_fn`.
  OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    OpKernel* (*create_fn)(OpKernelConstruction*)) {
    // Perform the check in the header to allow compile-time optimization
    // to a no-op, allowing the linker to remove the kernel symbols.
    if (kernel_def != nullptr) {
      InitInternal(kernel_def, kernel_class_name,
                   absl::make_unique<PtrOpKernelFactory>(create_fn));
    }
  }

 private:
  struct PtrOpKernelFactory : public OpKernelFactory {
    explicit PtrOpKernelFactory(OpKernel* (*create_func)(OpKernelConstruction*))
        : create_func_(create_func) {}

    OpKernel* Create(OpKernelConstruction* context) override;

    OpKernel* (*create_func_)(OpKernelConstruction*);
  };

  void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    std::unique_ptr<OpKernelFactory> factory);
};

}  // namespace kernel_factory

// -----------------------------------------------------------------------------
// Template and inline method implementations, please ignore

template <class T>
Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const {
  return GetNodeAttr(def(), attr_name, value);
}

inline DataType OpKernelContext::input_dtype(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  const TensorValue& value((*params_->inputs)[index]);
  if (value.is_ref()) {
    return MakeRefType(value->dtype());
  } else {
    return value->dtype();
  }
}

inline MemoryType OpKernelContext::input_memory_type(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  return op_kernel().input_memory_types()[index];
}

inline DataType OpKernelContext::expected_output_dtype(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  return params_->op_kernel->output_type(index);
}

inline MemoryType OpKernelContext::output_memory_type(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  return op_kernel().output_memory_types()[index];
}

inline bool OpKernelContext::input_is_ref(int index) const {
  const TensorValue& value((*params_->inputs)[index]);
  return value.is_ref();
}

inline void OpKernelContext::record_tensor_reference(const Tensor& tensor) {
  DCHECK_EQ(params_->device->RequiresRecordingAccessedTensors(),
            params_->record_tensor_accesses);
  if (params_->record_tensor_accesses) {
    really_record_tensor_reference(tensor);
  }
}

inline void OpKernelContext::retrieve_accessed_tensors(
    TensorReferenceVector* out_vector) {
  if (params_->record_tensor_accesses) {
    mutex_lock l(mu_);
    referenced_tensors_->FreezeAndReturnReferences(out_vector);
  }
}

// no input if tensor == nullptr.
inline bool OpKernelContext::has_input(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  return (*params_->inputs)[index].tensor != nullptr;
}

inline mutex* OpKernelContext::input_ref_mutex(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  DCHECK(input_is_ref(index));
  return (*params_->inputs)[index].mutex_if_ref;
}

inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) {
  if (t.IsInitialized()) {
    record_tensor_reference(t);
  }
}

inline Tensor* OpKernelContext::mutable_output(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  // No need to record_tensor_reference since the output must already
  // have been set by a call that did so.
  return outputs_[index].tensor;
}

inline TensorValue OpKernelContext::release_output(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  TensorValue value = outputs_[index];
  outputs_[index] = TensorValue();
  return value;
}

inline Status OpKernelContext::forward_input_or_allocate_output(
    gtl::ArraySlice<int> candidate_input_indices, int output_index,
    const TensorShape& output_shape, Tensor** output) {
  for (int input_index : candidate_input_indices) {
    if (forward_input_to_output_with_shape(input_index, output_index,
                                           output_shape, output)) {
      return Status::OK();
    }
  }
  return allocate_output(output_index, output_shape, output);
}

inline Status OpKernelContext::forward_input_or_allocate_output(
    gtl::ArraySlice<StringPiece> candidate_input_names, StringPiece output_name,
    const TensorShape& output_shape, Tensor** output) {
  for (const StringPiece& input_name : candidate_input_names) {
    if (forward_input_to_output_with_shape(input_name, output_name,
                                           output_shape, output)
            .ok()) {
      return Status::OK();
    }
  }
  return allocate_output(output_name, output_shape, output);
}

template <typename T>
T* OpKernelContext::op_device_context() {
  static_assert(std::is_base_of<DeviceContext, T>::value,
                "T is not a subclass of DeviceContext");
  return static_cast<T*>(op_device_context());
}

template <typename T>
T* OpKernelContext::input_device_context(int index) {
  DCHECK_NE(params_->input_device_contexts, nullptr);
  DCHECK_GE(index, 0);
  DCHECK_LT(index, params_->input_device_contexts->size());
  static_assert(std::is_base_of<DeviceContext, T>::value,
                "T is not a subclass of DeviceContext");
  return static_cast<T*>((*params_->input_device_contexts)[index]);
}

inline DeviceContext* OpKernelContext::input_device_context(int index) {
  DCHECK_NE(params_->input_device_contexts, nullptr);
  DCHECK_GE(index, 0);
  DCHECK_LT(index, params_->input_device_contexts->size());
  return (*params_->input_device_contexts)[index];
}

inline const Tensor& OpInputList::operator[](int i) const {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->input(start_ + i);
}

inline mutex* OpMutableInputList::ref_mutex(int i) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->input_ref_mutex(start_ + i);
}

inline Tensor OpMutableInputList::at(int i, bool lock_held) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->mutable_input(start_ + i, lock_held);
}

inline Tensor* OpOutputList::operator[](int i) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->mutable_output(start_ + i);
}

inline bool OpOutputList::required(int i) const {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->output_required(start_ + i);
}

inline DataType OpOutputList::expected_output_dtype(int i) const {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->expected_output_dtype(start_ + i);
}

inline Status OpOutputList::allocate(int i, const TensorShape& shape,
                                     Tensor** output) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->allocate_output(start_ + i, shape, output);
}

inline void OpOutputList::set(int i, const Tensor& tensor) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  ctx_->set_output(start_ + i, tensor);
}

inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  ctx_->set_output_ref(start_ + i, mu, tensor_for_ref);
}

// Convenience macros for asserting and handling exceptional conditions.
// Analogous to the CHECK* macros provided by logging.h.
//
// Example use:
//  void Compute(OpKernelContext* context) {
//    OP_REQUIRES(context, context->num_inputs() == 2,
//                errors::InvalidArgument("FooOp requires 2 arguments"));
//    ...
//    Status status = SomeUncertainMethod();
//    OP_REQUIRES_OK(context, status);
//    ...
//  }

// Generate a fatal error if OP_REQUIRES or OP_REQUIRES_OK are used in
// AsyncOpKernel implementations. If these macros are used and the condition
// does not hold, the `done` callback will never be called and the system will
// deadlock, so a crash failure is preferable. Since the OP_REQUIRES[_OK] macros
// are legal to use in AsyncOpKernel constructors, we use overload resolution
// to distinguish between OpKernelConstruction* and OpKernelContext* context
// types.
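//
// In AsyncOpKernel::ComputeAsync, use the *_ASYNC macro variants defined below
// instead, so the `done` callback still runs on failure. A brief sketch
// (`done` is the callback passed to ComputeAsync; the check shown is
// illustrative only):
//
//   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
//     OP_REQUIRES_ASYNC(ctx, ctx->num_inputs() == 2,
//                       errors::InvalidArgument("expected 2 inputs"), done);
//     ...
//   }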
class XlaOpKernelContext;
inline void CheckNotInComputeAsync(XlaOpKernelContext*, const char*) {}
inline void CheckNotInComputeAsync(OpKernelConstruction*, const char*) {}
void CheckNotInComputeAsync(OpKernelContext* ctx,
                            const char* correct_macro_name);

#define OP_REQUIRES(CTX, EXP, STATUS)                     \
  do {                                                    \
    if (!TF_PREDICT_TRUE(EXP)) {                          \
      CheckNotInComputeAsync((CTX), "OP_REQUIRES_ASYNC"); \
      (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS));    \
      return;                                             \
    }                                                     \
  } while (0)

#define OP_REQUIRES_OK(CTX, ...)                             \
  do {                                                       \
    ::tensorflow::Status _s(__VA_ARGS__);                    \
    if (!TF_PREDICT_TRUE(_s.ok())) {                         \
      CheckNotInComputeAsync((CTX), "OP_REQUIRES_OK_ASYNC"); \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s);  \
      return;                                                \
    }                                                        \
  } while (0)

#define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK)  \
  do {                                                 \
    if (!TF_PREDICT_TRUE(EXP)) {                       \
      (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \
      (CALLBACK)();                                    \
      return;                                          \
    }                                                  \
  } while (0)

#define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK)         \
  do {                                                      \
    ::tensorflow::Status _s(STATUS);                        \
    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
      (CALLBACK)();                                         \
      return;                                               \
    }                                                       \
  } while (0)

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_