1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
18
19 #include <functional>
20 #include <unordered_set>
21 #include <utility>
22 #include <vector>
23
24 #include "tensorflow/core/framework/allocator.h"
25 #include "tensorflow/core/framework/cancellation.h"
26 #include "tensorflow/core/framework/control_flow.h"
27 #include "tensorflow/core/framework/device_base.h"
28 #include "tensorflow/core/framework/graph.pb.h"
29 #include "tensorflow/core/framework/kernel_def.pb.h"
30 #include "tensorflow/core/framework/kernel_def_builder.h"
31 #include "tensorflow/core/framework/node_def.pb.h"
32 #include "tensorflow/core/framework/node_def_util.h"
33 #include "tensorflow/core/framework/node_properties.h"
34 #include "tensorflow/core/framework/op.h" // TODO(b/62899350): Remove
35 #include "tensorflow/core/framework/op_requires.h"
36 #include "tensorflow/core/framework/registration/registration.h"
37 #include "tensorflow/core/framework/rendezvous.h"
38 #include "tensorflow/core/framework/session_state.h"
39 #include "tensorflow/core/framework/tensor.h"
40 #include "tensorflow/core/framework/tensor_shape.h"
41 #include "tensorflow/core/framework/tensor_shape.pb.h" // TODO(b/62899350): Remove
42 #include "tensorflow/core/framework/tracking_allocator.h"
43 #include "tensorflow/core/framework/types.h"
44 #include "tensorflow/core/framework/types.pb.h"
45 #include "tensorflow/core/lib/core/errors.h"
46 #include "tensorflow/core/lib/core/status.h"
47 #include "tensorflow/core/lib/gtl/array_slice.h"
48 #include "tensorflow/core/lib/gtl/manual_constructor.h"
49 #include "tensorflow/core/platform/env.h"
50 #include "tensorflow/core/platform/logging.h"
51 #include "tensorflow/core/platform/macros.h"
52 #include "tensorflow/core/platform/mutex.h"
53 #include "tensorflow/core/platform/profile_utils/cpu_utils.h"
54 #include "tensorflow/core/platform/thread_annotations.h"
55 #include "tensorflow/core/platform/types.h"
56 #include "tensorflow/core/protobuf/config.pb.h"
57 #include "tensorflow/core/util/managed_stack_trace.h"
58
59 namespace Eigen {
60 struct ThreadPoolDevice;
61 struct GpuDevice;
62 } // end namespace Eigen
63
64 namespace tensorflow {
65
66 namespace checkpoint {
67 class TensorSliceReaderCacheWrapper;
68 } // namespace checkpoint
69
70 class AsyncOpKernel;
71 class CallFrameInterface;
72 class DeviceMgr;
73 class FunctionLibraryRuntime;
74 class OpKernelConstruction; // declared below
75 class OpKernelContext; // declared below,
76 class OpRegistryInterface;
77 class ResourceMgr;
78 class ScopedStepContainer;
79 class CollectiveExecutor;
80 class StepStatsCollectorInterface;
81 class CoordinationServiceAgent;
82
83 class OpKernel {
84 public:
85 // OpKernel won't be instantiated by the scheduler, so you may perform
86 // expensive initialization in the descendant's constructor.
87 explicit OpKernel(OpKernelConstruction* context);
88
89 // Specialized constructor that allows a kernel implementation to mark itself
90 // as a "deferred" op. If true, the executor will provide access to the
91 // `OpKernelContext::inc_num_deferred_ops_function()` and
92 // `OpKernelContext::dec_num_deferred_ops_function()` methods at run-time.
93 OpKernel(OpKernelConstruction* context, bool is_deferred);
94
95 // Specialized constructor that enables the descendant to provide a custom
96 // `NodeDef` value. For example, this constructor can be used to provide a
97 // stripped-down `NodeDef` that does not contain the full set of attrs (such
98 // as tensor values) if the descendant stores them in a different form.
99 OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
100 bool is_deferred);
101
102 virtual ~OpKernel();
103
104 // An OpKernel's computation can be either synchronous or
105 // asynchronous. All OpKernel Compute() methods must be thread-safe as they
106 // may be called concurrently (e.g. by multiple executions of the same graph
107 // concurrently).
108 //
109 // Most OpKernels should compute synchronously. They should
110 // subclass OpKernel and override the Compute() method and have it
111 // return after completing the supplied work.
112 //
113 // A synchronous OpKernel *MUST NOT* block the calling thread on a
114 // synchronization mechanism (condition variable, Notification, etc.) that
115 // will be unblocked by the execution of another OpKernel. Execution may
116 // deadlock in that case, because the executor may use a bounded number of
117 // threads.
118 //
119 // If an OpKernel must block on the execution of another OpKernel (e.g. a
120 // RecvOp, or a DequeueOp), the implementation *MUST* subclass AsyncOpKernel,
121 // and override `AsyncOpKernel::ComputeAsync()`. In addition, because the
122 // unblocking kernel may never run (due to an error or cancellation), in most
123 // cases the AsyncOpKernel should implement cancellation support via
124 // `ctx->cancellation_manager()`.
125 //
126 // In both cases, implementations of Compute() and ComputeAsync()
127 // get inputs and write outputs through the given OpKernelContext
128 // and returns a status via context->SetStatus(). They must be
129 // thread-safe.
130
131 // Synchronous compute.
132 //
133 // "context" is guaranteed to be alive until Compute() returns.
134 virtual void Compute(OpKernelContext* context) = 0;
135
136 // Returns nullptr iff this op kernel is synchronous.
AsAsync()137 virtual AsyncOpKernel* AsAsync() { return nullptr; }
138
139 // Returns true iff this op kernel is considered "expensive". The
140 // runtime may use this flag to optimize graph execution for example
141 // to "inline" inexpensive kernels.
IsExpensive()142 virtual bool IsExpensive() { return expensive_; }
143
144 // Returns a pointer to the tensor stored inside constant ops.
const_tensor()145 virtual const Tensor* const_tensor() const { return nullptr; }
146
147 // Accessors.
def()148 const NodeDef& def() const { return props_->node_def; }
name()149 const std::string& name() const { return props_->node_def.name(); }
name_view()150 absl::string_view name_view() const { return name_view_; }
type_string()151 const std::string& type_string() const { return props_->node_def.op(); }
type_string_view()152 absl::string_view type_string_view() const { return type_string_view_; }
requested_input(int i)153 const std::string& requested_input(int i) const {
154 return props_->node_def.input(i);
155 }
requested_device()156 const std::string& requested_device() const {
157 return props_->node_def.device();
158 }
159
num_inputs()160 int num_inputs() const { return props_->input_types.size(); }
input_type(int i)161 DataType input_type(int i) const { return props_->input_types[i]; }
input_types()162 const DataTypeVector& input_types() const { return props_->input_types; }
input_memory_types()163 const MemoryTypeVector& input_memory_types() const {
164 return input_memory_types_;
165 }
166
num_outputs()167 int num_outputs() const { return props_->output_types.size(); }
output_type(int o)168 DataType output_type(int o) const { return props_->output_types[o]; }
output_types()169 const DataTypeVector& output_types() const { return props_->output_types; }
output_memory_types()170 const MemoryTypeVector& output_memory_types() const {
171 return output_memory_types_;
172 }
173
174 Status InputRange(StringPiece input_name, int* start, int* stop) const;
175 Status OutputRange(StringPiece output_name, int* start, int* stop) const;
176
177 // Returns `true` if and only if this kernel uses deferred execution.
is_deferred()178 bool is_deferred() const { return is_deferred_; }
179
180 // Returns a trace string for current computation, op name/type and input
181 // tensor shape/dtype are encoded for profiler cost analysis. Most OpKernel
182 // should use the default implementation.
183 virtual std::string TraceString(const OpKernelContext& ctx,
184 bool verbose) const;
185
186 protected:
187 std::string ShapeTraceString(const OpKernelContext& ctx) const;
188
189 private:
190 const std::shared_ptr<const NodeProperties> props_;
191 const MemoryTypeVector input_memory_types_;
192 const MemoryTypeVector output_memory_types_;
193 NameRangeMap input_name_map_;
194 NameRangeMap output_name_map_;
195 const absl::string_view name_view_;
196 const absl::string_view type_string_view_;
197 const int graph_def_version_;
198 const bool is_deferred_;
199 bool expensive_;
200
201 TF_DISALLOW_COPY_AND_ASSIGN(OpKernel);
202 };
203
204 class AsyncOpKernel : public OpKernel {
205 public:
206 using OpKernel::OpKernel; // Lift OpKernel constructors.
207
208 // Asynchronous compute.
209 //
210 // Implementations of ComputeAsync() must ensure that `done` is (eventually)
211 // called exactly once to signal the completion of the computation. The
212 // implementation of ComputeAsync() must not block on the execution of another
213 // OpKernel. `done` may be called by the current thread, or by another thread.
214 // `context` is guaranteed to stay alive until the `done` callback starts.
215 //
216 // Since it is possible that the unblocking kernel may never run (due to an
217 // error or cancellation), in most cases the AsyncOpKernel should implement
218 // cancellation support via `context->cancellation_manager()`.
219 //
220 // WARNING: As soon as the `done` callback starts, `context` and `this` may be
221 // deleted. No code depending on these objects should execute after the call
222 // to `done`.
223 typedef std::function<void()> DoneCallback;
224 virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0;
225
AsAsync()226 AsyncOpKernel* AsAsync() override { return this; }
227
228 void Compute(OpKernelContext* context) override;
229 };
230
231 class OpKernelConstruction {
232 public:
233 OpKernelConstruction(DeviceType device_type, DeviceBase* device,
234 Allocator* allocator, FunctionLibraryRuntime* flib,
235 ResourceMgr* resource_mgr,
236 const std::shared_ptr<const NodeProperties>& props,
237 const MemoryTypeSlice& input_memory_types,
238 const MemoryTypeSlice& output_memory_types,
239 int graph_def_version, Status* status);
240
env()241 Env* env() const { return device_->env(); }
242
243 // Allocation of tensors during kernel construction:
244 //
245 // It is legal to temporarily allocate scratch tensor storage during
246 // Op kernel construction. Scratch tensors should be allocated using
247 // allocate_temp below. Some kernels need to keep tensors in between
248 // invocations. If such a Tensor is allocated during kernel
249 // construction this also must be done using allocate_temp, and the
250 // Op may only store the returned Tensor object.
251
252 // Allocates a temporary Tensor of the specified type and shape. The
253 // Tensor must not be used after kernel construction is
254 // complete. See comment above.
255 Status allocate_temp(DataType type, const TensorShape& shape,
256 Tensor* out_temp);
257 Status allocate_temp(DataType type, const TensorShape& shape,
258 Tensor* out_temp, AllocatorAttributes allocator_attr);
259
260 // User-supplied configuration of this operation.
def()261 const NodeDef& def() const { return props_->node_def; }
262
263 // For inspecting the inputs to this operation.
num_inputs()264 int num_inputs() const { return props_->input_types.size(); }
input_type(int i)265 DataType input_type(int i) const { return props_->input_types[i]; }
input_types()266 const DataTypeSlice& input_types() const { return props_->input_types_slice; }
input_memory_types()267 const MemoryTypeSlice& input_memory_types() const {
268 return input_memory_types_;
269 }
270
271 // For inspecting the outputs expected from this operation.
num_outputs()272 int num_outputs() const { return props_->output_types.size(); }
output_type(int i)273 DataType output_type(int i) const { return props_->output_types[i]; }
output_types()274 const DataTypeSlice& output_types() const {
275 return props_->output_types_slice;
276 }
output_memory_types()277 const MemoryTypeSlice& output_memory_types() const {
278 return output_memory_types_;
279 }
280
281 // If expected_inputs == inputs() and expected_outputs == output_types(),
282 // returns OK, else returns INVALID_ARGUMENT with an error message.
283 // Recommended for Ops with dynamic signatures.
284 Status MatchSignature(const DataTypeSlice expected_inputs,
285 const DataTypeSlice expected_outputs);
286
287 // For recording configuration errors during construction.
288 void SetStatus(const Status& status);
status()289 const Status& status() const { return *status_; }
290
291 // Look up the attr with name attr_name and set *value to its value. If no
292 // attr with attr_name is found in def(), or the attr does not have
293 // a matching type, a non-ok status will be returned.
294 template <class T>
295 Status GetAttr(StringPiece attr_name, T* value) const;
296
297 // Return true if the attr_name is defined in def().
298 bool HasAttr(StringPiece attr_name) const;
299
300 // Return the device type.
device_type()301 const DeviceType& device_type() const { return device_type_; }
302
303 // If not nullptr, the kernel can instantiate functions defined in
304 // the library. E.g.,
305 // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...).
function_library()306 FunctionLibraryRuntime* function_library() const { return flib_; }
307
308 // Shared resources accessible to this kernel.
resource_manager()309 ResourceMgr* resource_manager() const { return resource_mgr_; }
310
311 // The GraphDef version whose behavior we should follow.
graph_def_version()312 int graph_def_version() const { return graph_def_version_; }
313
314 // Helper routines for the OP_REQUIRES macros
315 void CtxFailure(const Status& s);
316 void CtxFailureWithWarning(const Status& s);
317 void CtxFailure(const char* file, int line, const Status& s);
318 void CtxFailureWithWarning(const char* file, int line, const Status& s);
319
320 // Unrecommended functions: these are functions that have some
321 // current uses but are not recommended for use, and may go away at
322 // some future major version release.
323
324 // May be used, e.g., to get GPU handles, etc.
325 //
326 // Currently only used to call MakeTensorFromProto() for
327 // implementing ConstantOp for every device. See comments
328 // on Device::MakeTensorFromProto for longer-term replacement
329 // ideas.
device()330 DeviceBase* device() const { return device_; }
331
332 private:
333 const DeviceType device_type_;
334 DeviceBase* const device_;
335 Allocator* allocator_;
336 FunctionLibraryRuntime* flib_;
337 ResourceMgr* const resource_mgr_;
338 std::shared_ptr<const NodeProperties> props_;
339 MemoryTypeSlice input_memory_types_;
340 MemoryTypeSlice output_memory_types_;
341 const int graph_def_version_;
342 Status* status_;
343
344 // Allow access from OpKernel ctor.
345 friend class OpKernel;
346
347 TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction);
348 };
349
350 // TODO(mrry): Consider converting to a random_access_iterator, and upgrading
351 // tensorflow::gtl::iterator_range to make the below container classes
352 // unnecessary.
353 template <typename ListType, typename ElementType>
354 class OpArgIterator {
355 public:
356 using iterator_category = std::forward_iterator_tag;
357 using value_type = ElementType;
358 using pointer = ElementType*;
359 using const_pointer = const ElementType*;
360 using reference = ElementType&;
361 using const_reference = const ElementType&;
362 using difference_type = ptrdiff_t;
363
OpArgIterator(const ListType * list,int i)364 OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {}
365
366 bool operator==(const OpArgIterator& rhs) {
367 DCHECK(list_ == rhs.list_);
368 return i_ == rhs.i_;
369 }
370
371 bool operator!=(const OpArgIterator& rhs) {
372 DCHECK(list_ == rhs.list_);
373 return i_ != rhs.i_;
374 }
375
376 OpArgIterator operator++() { // prefix ++it
377 ++i_;
378 return *this;
379 }
380
381 OpArgIterator operator++(int) { // postfix it++
382 OpArgIterator old_value = *this;
383 ++i_;
384 return old_value;
385 }
386
387 reference operator*() { return (*list_)[i_]; }
388 pointer operator->() { return &(*list_)[i_]; }
389
390 const_reference operator*() const { return (*list_)[i_]; }
391 const_pointer operator->() const { return &(*list_)[i_]; }
392
393 private:
394 const ListType* const list_;
395 int i_;
396 };
397
398 // Utility class for representing a list of immutable input tensors
399 // that are passed to the op as a single named argument.
400 class OpInputList {
401 public:
402 typedef OpArgIterator<OpInputList, const Tensor> Iterator;
OpInputList()403 OpInputList() : ctx_(nullptr), start_(0), stop_(0) {}
OpInputList(OpKernelContext * ctx,int start,int stop)404 OpInputList(OpKernelContext* ctx, int start, int stop)
405 : ctx_(ctx), start_(start), stop_(stop) {}
406 OpInputList& operator=(const OpInputList& other) = default;
407 const Tensor& operator[](int i) const;
size()408 int size() const { return stop_ - start_; }
begin()409 Iterator begin() const { return Iterator(this, 0); }
end()410 Iterator end() const { return Iterator(this, size()); }
411
412 private:
413 OpKernelContext* ctx_; // not owned
414 int start_;
415 int stop_;
416 };
417
418 // Utility class for representing a list of mutable ("ref") input tensors
419 // that are passed to the op as a single named argument.
420 class OpMutableInputList {
421 public:
422 typedef OpArgIterator<OpMutableInputList, Tensor*> Iterator;
OpMutableInputList(OpKernelContext * ctx,int start,int stop)423 OpMutableInputList(OpKernelContext* ctx, int start, int stop)
424 : ctx_(ctx), start_(start), stop_(stop) {}
OpMutableInputList()425 OpMutableInputList() : ctx_(nullptr), start_(0), stop_(0) {}
426 OpMutableInputList& operator=(const OpMutableInputList& other) = default;
427 Tensor at(int i, bool lock_held);
428 mutex* ref_mutex(int i);
size()429 int size() const { return stop_ - start_; }
begin()430 Iterator begin() const { return Iterator(this, 0); }
end()431 Iterator end() const { return Iterator(this, size()); }
432
433 private:
434 OpKernelContext* ctx_; // not owned
435 int start_;
436 int stop_;
437 };
438
439 // Utility class for representing a list of output tensors that are
440 // grouped as a single named output.
441 class OpOutputList {
442 public:
443 typedef OpArgIterator<OpOutputList, const Tensor*> Iterator;
OpOutputList()444 OpOutputList() : ctx_(nullptr), start_(0), stop_(0) {}
OpOutputList(OpKernelContext * ctx,int start,int stop)445 OpOutputList(OpKernelContext* ctx, int start, int stop)
446 : ctx_(ctx), start_(start), stop_(stop) {}
447 OpOutputList& operator=(const OpOutputList& other) = default;
448 Tensor* operator[](int i);
449 bool required(int i) const;
450 DataType expected_output_dtype(int i) const;
451 Status allocate(int i, const TensorShape& shape, Tensor** output);
452 void set(int i, const Tensor& tensor);
453 void set(int i, Tensor&& tensor);
454 void set_ref(int i, mutex* mu, Tensor* tensor_for_ref);
size()455 int size() const { return stop_ - start_; }
begin()456 Iterator begin() const { return Iterator(this, 0); }
end()457 Iterator end() const { return Iterator(this, size()); }
458
459 private:
460 OpKernelContext* ctx_; // not owned
461 int start_;
462 int stop_;
463 };
464
465 // Holds a tensor or tensor reference. For tensor references, we need
466 // a mutex to prevent concurrent access to the tensor.
467 struct TensorValue {
TensorValueTensorValue468 TensorValue() : mutex_if_ref(nullptr), tensor(nullptr) {}
TensorValueTensorValue469 explicit TensorValue(Tensor* t) : mutex_if_ref(nullptr), tensor(t) {}
TensorValueTensorValue470 TensorValue(mutex* mu, Tensor* t) : mutex_if_ref(mu), tensor(t) {}
471 Tensor* operator->() const { return tensor; }
is_refTensorValue472 bool is_ref() const { return mutex_if_ref != nullptr; }
473
474 // Return the dtype of the Tensor. For references, return the underlying type.
dtypeTensorValue475 DataType dtype() const {
476 if (is_ref()) {
477 return MakeRefType(tensor->dtype());
478 } else {
479 return tensor->dtype();
480 }
481 }
482
483 // Return the dtype of the Tensor. For references, return the underlying type.
484 // This variation on the dtype() acquires the lock for references.
485 //
486 // TODO(b/133843385): Disallow dtype modifications
dtype_safeTensorValue487 DataType dtype_safe() const {
488 if (is_ref()) {
489 tf_shared_lock ml(*mutex_if_ref);
490 return MakeRefType(tensor->dtype());
491 } else {
492 return tensor->dtype();
493 }
494 }
495
496 mutex* mutex_if_ref; // nullptr if not a ref, != nullptr if a ref
497 Tensor* tensor;
498 };
499
500 // Used to store partitioned graphs from function-calling ops.
501 struct GraphCollector {
502 mutex mu;
503 std::vector<GraphDef> partitioned_graphs TF_GUARDED_BY(mu);
504 GraphDef raw_graph TF_GUARDED_BY(mu);
505 GraphDef optimized_graph TF_GUARDED_BY(mu);
506
507 bool dirty TF_GUARDED_BY(mu);
508
GraphCollectorGraphCollector509 GraphCollector() : dirty(false) {}
510
CollectRawGraphGraphCollector511 void CollectRawGraph(const GraphDef& graph) {
512 mutex_lock ml(mu);
513 raw_graph.MergeFrom(graph);
514 dirty = true;
515 }
516
CollectOptimizedGraphGraphCollector517 void CollectOptimizedGraph(const GraphDef& graph) {
518 mutex_lock ml(mu);
519 optimized_graph.MergeFrom(graph);
520 dirty = true;
521 }
522
CollectPartitionedGraphGraphCollector523 void CollectPartitionedGraph(const GraphDef& graph) {
524 mutex_lock ml(mu);
525 partitioned_graphs.push_back(graph);
526 dirty = true;
527 }
528
ClearGraphsGraphCollector529 void ClearGraphs() TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
530 raw_graph.Clear();
531 optimized_graph.Clear();
532 partitioned_graphs.clear();
533 dirty = false;
534 }
535
HasUpdatedGraphsGraphCollector536 bool HasUpdatedGraphs() {
537 mutex_lock ml(mu);
538 return dirty;
539 }
540 };
541
542 class OpKernelContext {
543 public:
544 // The first element of a WrappedAllocator is a "base" Allocator and
545 // the second element is that Allocator wrapped by a
546 // TrackingAllocator
547 typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator;
548
549 // TODO(zhifengc): Do some cleanup of Params.
550 // The Params struct is passed in to initialize an OpKernelContext,
551 // and must outlive the OpKernelContext.
552 struct Params {
~ParamsParams553 ~Params() { delete eigen_gpu_device; }
554
555 // The step being executed.
556 int64 step_id = 0;
557
558 // Timestamp for the start of graph execution. Used for latency metrics.
559 int64 start_time_usecs = 0;
560
561 // The op kernel being computed.
562 OpKernel* op_kernel = nullptr;
563
564 // The device on which the kernel is running.
565 DeviceBase* device = nullptr;
566
567 // The Eigen GPU device wrapper, which may include a per-op
568 // wrapped allocator. The concrete type of this object depends on
569 // the type of this->device, so eigen_gpu_device can't be an
570 // inline member and must be heap allocated. However, we don't
571 // want to allocate a new eigen_gpu_device for every Op that is
572 // executed. Instead this member is allocated on first use using
573 // ensure_eigen_gpu_device, and then if the Params structure is
574 // re-used for subsequent Ops, the eigen_gpu_device is
575 // ReInitialized in the OpKernelContext constructor. Unlike the
576 // other pointers in Params, this one is owned by Params.
577 PerOpGpuDevice* eigen_gpu_device = nullptr;
578
ensure_eigen_gpu_deviceParams579 inline void ensure_eigen_gpu_device() {
580 DCHECK(device);
581 if (nullptr == eigen_gpu_device) {
582 // Surprisingly, MakeGpuDevice will return nullptr if the
583 // device is not a GPU device. This is ok, since those devices
584 // will never use eigen_gpu_device. It seems better to have
585 // ensure_eigen_gpu_device fall through and regenerate the
586 // nullptr every time an OpKernelContext is instantiated, than
587 // to do an unnecessary allocation of a dummy eigen GPU
588 // device for CPU device Ops.
589 eigen_gpu_device = device->MakeGpuDevice();
590 }
591 }
592
593 bool track_allocations = false;
594 bool log_memory = false;
595
596 // Array indexed by output number for this node
597 const AllocatorAttributes* output_attr_array = nullptr;
598
599 // Shared resources accessible by this op kernel invocation.
600 ResourceMgr* resource_manager = nullptr;
601
602 // Per-step resources accessible by this op kernel invocation should be
603 // stored in this container..
604 ScopedStepContainer* step_container = nullptr;
605
606 // Mechanism used by this op kernel invocation to communicate with
607 // computations running on other devices.
608 RendezvousInterface* rendezvous = nullptr;
609
610 // Mechanism for executing a collective op that needs to coordinate
611 // with parallel instances running on other devices.
612 CollectiveExecutor* collective_executor = nullptr;
613
614 // The session state for this op.
615 SessionState* session_state = nullptr;
616
617 // Unique session identifier. Can be empty.
618 std::string session_handle;
619
620 // Metadata about the session. Can be nullptr.
621 const SessionMetadata* session_metadata = nullptr;
622
623 // The tensor store for this op.
624 TensorStore* tensor_store = nullptr;
625
626 // Mechanism used by this op kernel invocation to register a callback
627 // for its cancellation.
628 CancellationManager* cancellation_manager = nullptr;
629
630 // Inputs to this op kernel.
631 const gtl::InlinedVector<TensorValue, 4>* inputs = nullptr;
632 bool is_input_dead = false;
633
634 const gtl::InlinedVector<AllocatorAttributes, 4>* input_alloc_attrs =
635 nullptr;
636
637 // Device context.
638 DeviceContext* op_device_context = nullptr;
639
640 // Control-flow op supports.
641 FrameAndIter frame_iter;
642
643 // Function call supports.
644 CallFrameInterface* call_frame = nullptr;
645 FunctionLibraryRuntime* function_library = nullptr;
646 std::function<void(std::function<void()>)>* runner = nullptr;
647 StepStatsCollectorInterface* stats_collector = nullptr;
648 GraphCollector* graph_collector = nullptr;
649 bool run_all_kernels_inline = false;
650 const std::string* executor_type = nullptr;
651
652 // TensorSliceReaderCache support.
653 checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
654
655 // Support for forwarding reservations (used by ScopedAllocator).
656 static constexpr int kNeverForward = -2;
657 static constexpr int kNoReservation = -1;
658 // Values in [0,...) represent reservations for the indexed output.
659 const int* forward_from_array = nullptr;
660
661 // For tracking actively running deferred ops.
662 std::function<void()> inc_num_deferred_ops_function;
663 std::function<void()> dec_num_deferred_ops_function;
664
665 absl::optional<ManagedStackTrace> stack_trace = {};
666
667 // For implementing `OpKernelContext::output_required()`. If null, all
668 // outputs are required.
669 bool* outputs_required_array = nullptr;
670
671 // For access to distributed coordination service.
672 CoordinationServiceAgent* coordination_service_agent = nullptr;
673 };
674
675 // params must outlive the OpKernelContext.
676 explicit OpKernelContext(Params* params);
677 OpKernelContext(Params* params, int num_outputs);
678 ~OpKernelContext();
679
env()680 Env* env() const { return params_->device->env(); }
681
step_id()682 int64 step_id() const { return params_->step_id; }
683
start_time_usecs()684 int64 start_time_usecs() const { return params_->start_time_usecs; }
685
op_kernel()686 const OpKernel& op_kernel() const { return *params_->op_kernel; }
687
688 // Stack trace of where the op was defined (if defined in eager mode).
stack_trace()689 const absl::optional<ManagedStackTrace>& stack_trace() const {
690 return params_->stack_trace;
691 }
692
693 // Input/output signature.
694
num_inputs()695 int num_inputs() const { return params_->inputs->size(); }
696 DataType input_dtype(int index) const;
697 Status input_dtype(StringPiece name, DataType* dtype) const;
698 MemoryType input_memory_type(int index) const;
699
num_outputs()700 int num_outputs() const { return outputs_.size(); }
701 DataType expected_output_dtype(int index) const;
702 MemoryType output_memory_type(int index) const;
703
704 // Input
705
706 // Returns an immutable input tensor. May only be used for non-Ref
707 // inputs. For Ref inputs use mutable_input below.
708 // REQUIRES: !IsRefType(input_dtype(index))
709 // TODO(mrry): Convert this to return Status.
710 const Tensor& input(int index) const;
711
712 // Returns the named immutable input tensor in "tensor", as defined
713 // in the OpDef. May only be used for non-Ref inputs. For Ref inputs
714 // use mutable_input below.
715 // REQUIRES: !IsRefType(input_dtype(index))
716 // REQUIRES: the named input must not be a list.
717 Status input(StringPiece name, const Tensor** tensor);
718
719 // Returns the named list-valued immutable input in "list", as
720 // defined in the OpDef. If the named output is not list-valued,
721 // returns a one-element list. May only be used for non-Ref
722 // inputs. For Ref inputs use mutable_input below.
723 // REQUIRES: !IsRefType(input_dtype(index))
724 Status input_list(StringPiece name, OpInputList* list);
725
726 // For mutable inputs, use the following together to make sure there
727 // is no concurrent access to mutable_input(), e.g.:
728 // {
729 // Tensor& t = context->mutable_input(index);
730 // mutex_lock lock(*context->input_ref_mutex(index));
731 // // modify the values in t
732 // }
733 // REQUIRES: IsRefType(input_dtype(index))
734 Status input_ref_mutex(StringPiece name, mutex** out_mutex);
735
736 // Returns a mutable input tensor. Must be used to access Ref
737 // inputs. REQUIRES: IsRefType(input_dtype(index)). The caller may
738 // modify the values stored in the Tensor buffer, and modifications
739 // will be visible to other Ops reading the same ref tensor. If
740 // !lock_held the input mutex will be acquired before returning the
741 // Tensor.
742 // TODO(mrry): Convert this to return Status.
743 Tensor mutable_input(int index, bool lock_held);
744
745 // Returns the named mutable input tensor in "tensor", as defined in
746 // the OpDef. Must be used to access Ref inputs. The values stored
747 // in the Tensor buffer may be modified, and modifications will be
748 // visible to other Ops reading the same ref tensor. If !lock_held
749 // the input mutex will be acquired before returning the Tensor.
750 // REQUIRES: the named input must not be a list.
751 // REQUIRES: the named input must be a ref tensor.
752 Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held);
753
754 // Returns the named list-valued mutable input in "list", as defined
755 // in the OpDef. If the named input is not list-valued, returns a
756 // one-element list. Must be used to access Ref inputs. The values
757 // stored in the Tensor buffer may be modified, and modifications
758 // will be visible to other Ops reading the same ref tensor.
759 // REQUIRES: the named input must be a ref tensor.
760 Status mutable_input_list(StringPiece name, OpMutableInputList* list);
761
762 // Replace the corresponding Ref Input to use the storage buffer
763 // used by tensor. If !lock_held the input mutex will be acquired
764 // before returning the Tensor.
765 // REQUIRES: IsRefType(input_dtype(index)).
766 void replace_ref_input(int index, const Tensor& tensor, bool lock_held);
767
768 // Replace the corresponding named Ref Input to use the storage
769 // buffer used by tensor. If !lock_held the input mutex will be
770 // acquired before returning the Tensor.
771 // REQUIRES: IsRefType(input_dtype(index)).
772 Status replace_ref_input(StringPiece name, const Tensor& tensor,
773 bool lock_held);
774
775 // Deletes the Tensor object used as the Ref Input at
776 // input_index. This is not usually necessary and should be used
777 // with caution. If !lock_held the input mutex will be acquired
778 // before returning the Tensor.
779 // REQUIRES: IsRefType(input_dtype(input_index)).
780 void delete_ref_input(int input_index, bool lock_held);
781
782 // Return true if there is input at the given index. An operator has no
783 // input at index if its tensor is null. This is primarily used by the
784 // merge operator.
785 // TODO(mrry): Convert this to return Status.
786 bool has_input(int index) const;
787
788 // Returns true if all inputs are the same shape, otherwise sets the
789 // status to a non-OK value and returns false.
790 // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
791 bool ValidateInputsAreSameShape(OpKernel* op);
792
793 // If non-null, kernels should populate with any partition subgraphs created.
graph_collector()794 GraphCollector* graph_collector() { return params_->graph_collector; }
795
796 // If True, hint that all kernels in functions called by this kernel, should
797 // be treated as "inexpensive", and hence executed on the scheduling thread.
run_all_kernels_inline()798 bool run_all_kernels_inline() const {
799 return params_->run_all_kernels_inline;
800 }
801
802 // Returns the registered name for the executor type that is executing the
803 // current kernel. If empty, the default executor is used.
804 const std::string& executor_type() const;
805
806 // Input to output forwarding.
807
808 // Set the output Ref Tensor at output_index to be an alias of the
809 // input Ref Tensor at input_index.
810 // REQUIRES: IsRefType(input_dtype(input_index)).
811 // REQUIRES: IsRefType(output_dtype(output_index)).
812 void forward_ref_input_to_ref_output(int input_index, int output_index);
813
814 // Returns true when an alias to input[input_index], reshaped to output_shape,
815 // which is safe to use for in-place computation was written to *output.
816 // Returns false if input[input_index] has a refcount greater than one, or if
817 // its type does not match the expected output type of output[output_index],
818 // or the number of elements in input[input_index] does not equal the number
819 // of elements in output_shape.
820 bool forward_input_to_output_with_shape(int input_index, int output_index,
821 const TensorShape& output_shape,
822 Tensor** output) TF_MUST_USE_RESULT;
823 Status forward_input_to_output_with_shape(StringPiece input_name,
824 StringPiece output_name,
825 const TensorShape& output_shape,
826 Tensor** output) TF_MUST_USE_RESULT;
827
828 // Returns a pointer to a Tensor aliasing the underlying buffer backing
829 // input[input_index] iff
830 // * input[input_index] is not a ref,
831 // * the data type, shape, memory type, and allocator attributes of
832 // input[input_index] are compatible with those given in dtype, shape,
833 // memory_type, and attr,
834 // * refcount on the underlying buffer is one.
835 // * Either there is no forwarding reservation for either input_index
836 // or output_index or the specified input is reserved for the specified
837 // output. More precisely:
838 //
839 // These cases mean neither input nor output has a reservation:
840 // forward_from_array = nullptr
841 // OR (input_index is not in forward_from_array AND
842 // (output_index == kNoReservation OR
843 // forward_from_array[output_index] == kNoReservation))
844 //
845 // This case means that input_index is reserved for output_index:
846 // forward_from_array[output_index] == input_index
847 //
848 // This case means the output is reserved to always be allocated,
849 // never assigned a forwarded input:
850 // forward_from_array[output_index] == kNeverForward
851 //
852 // Otherwise returns nullptr.
853 // NOTE: For Cuda kernels that read inputs using the __ldg() intrinsic,
854 // forwarding is only safe if there are no reads via __ldg() after writes
855 // to the same address.
856 std::unique_ptr<Tensor> forward_input(
857 int input_index, int output_index, DataType output_dtype,
858 const TensorShape& output_shape, MemoryType output_memory_type,
859 const AllocatorAttributes& output_attr) TF_MUST_USE_RESULT;
860
861 // Tries to forward one of the inputs given in input_indices to
862 // output[output_index]. If none of the given inputs can be forwarded, calls
863 // allocate_output() to allocate a new output buffer. The index of the
864 // forwarded input will be assign to output argument forwarded_input (if it's
865 // not nullptr). If no inputs are forwarded, forwarded_input will be assigned
866 // -1.
867 Status forward_input_or_allocate_output(
868 gtl::ArraySlice<int> candidate_input_indices, int output_index,
869 const TensorShape& output_shape, Tensor** output,
870 int* forwarded_input = nullptr) TF_MUST_USE_RESULT;
871 Status forward_input_or_allocate_output(
872 gtl::ArraySlice<StringPiece> candidate_input_names,
873 StringPiece output_name, const TensorShape& output_shape,
874 Tensor** output) TF_MUST_USE_RESULT;
875
876 // Tries to reuse one of the inputs given in input_indices as a temporary.
877 // If none of the given inputs can be forwarded, calls
878 // allocate_temp() to allocate a new temporary buffer.
879 Status forward_input_or_allocate_temp(
880 gtl::ArraySlice<int> candidate_input_indices, DataType type,
881 const TensorShape& shape, const AllocatorAttributes& allocator_attr,
882 Tensor* out_temp) TF_MUST_USE_RESULT;
883
forward_input_or_allocate_temp(gtl::ArraySlice<int> candidate_input_indices,DataType type,const TensorShape & shape,Tensor * out_temp)884 Status forward_input_or_allocate_temp(
885 gtl::ArraySlice<int> candidate_input_indices, DataType type,
886 const TensorShape& shape, Tensor* out_temp) TF_MUST_USE_RESULT {
887 return forward_input_or_allocate_temp(candidate_input_indices, type, shape,
888 AllocatorAttributes(), out_temp);
889 }
890
891 // Output
892
893 // Returns the named list-valued output in "list", as defined in the OpDef.
894 // If the named output is not list-valued, returns a one-element list.
895 Status output_list(StringPiece name, OpOutputList* list);
896
897 // If output_required(index) returns true, the OpKernel's Compute() method
898 // should call allocate_output(index, ...), set_output(index, ...),
899 // set_output_ref(index, ...), or set the status to a non-ok value.
900 // If it returns false, it may output, but is not required to do so.
output_required(int index)901 bool output_required(int index) const {
902 return !params_->outputs_required_array ||
903 params_->outputs_required_array[index];
904 }
905
906 // If output_expects_forwarding returns true, the OpKernel's Compute() method
907 // should not allocate the output with allocate_output but instead needs to
908 // use forward_input.
output_expects_forwarding(int index)909 bool output_expects_forwarding(int index) const {
910 return params_->forward_from_array != nullptr &&
911 params_->forward_from_array[index] >= 0;
912 }
913
914 // Allocation of tensors during kernel execution inside the Compute
915 // method:
916 //
917 // There are two methods to allocate Tensors when an Op kernel
918 // executes.
919 //
920 // 1) allocate_output. This should be used to allocate any tensor
921 // that is going to be used as an output from the Op at the end of
922 // the current execution. The caller indicates which output the
923 // Tensor will be assigned to, and the call returns the
924 // newly-allocated Tensor. The Tensor can subsequently be assigned
925 // to during kernel execution, and will be used as the designated
926 // output when the kernel execution completes.
927 //
928 // 2) allocate_temp. This should be used to allocate any scratch
929 // storage that is needed while the kernel is executing, and will
930 // not be retained by the Op.
931 //
932 // In some cases a Tensor needs to be used as an output even though
933 // it was previously allocated elsewhere. The Tensor may have been
934 // passed as an input, or stored in a Tensor during a
935 // previous kernel execution, or allocated earlier in the kernel
936 // execution at a time when it was not known which output it would
937 // be assigned to. In this case the kernel can use set_output or
938 // set_output_ref to indicate that the tensor should be used as the
939 // designated output. It is legal to use any previously-allocated
940 // Tensor as an argument to set_output or set_output_ref, including
941 // Tensors allocated via allocate_temp. There may be a performance
942 // penalty to using a Tensor that was not allocated using
943 // allocate_output. This is because allocate_output uses the
944 // AllocatorAttributes stored in output_attr_array for the
945 // designated output. In some cases, using the wrong attributes may
946 // cause an extra copy of the Tensor's buffer.
947
948 // Allocates output for the specified output index with shape.
949 // OpKernelContext retains ownership of the returned pointer. See
950 // comment above.
951 //
952 // If memory allocation fails, returns an error status.
953 //
954 // REQUIRES: !IsRefType(expected_output_dtype(index))
955 Status allocate_output(int index, const TensorShape& shape,
956 Tensor** tensor) TF_MUST_USE_RESULT;
957 Status allocate_output(StringPiece name, const TensorShape& shape,
958 Tensor** tensor) TF_MUST_USE_RESULT;
959 // The following methods use the supplied attributes instead of
960 // those in output_attr_array. The caller is responsible for
961 // ensuring that the attributes are "compatible" with the
962 // output_attr_array, e.g. the tensor is allocated on the correct
963 // device. See comment above.
964 Status allocate_output(int index, const TensorShape& shape, Tensor** tensor,
965 AllocatorAttributes attr) TF_MUST_USE_RESULT;
966 Status allocate_output(StringPiece name, const TensorShape& shape,
967 Tensor** tensor,
968 AllocatorAttributes attr) TF_MUST_USE_RESULT;
969
970 // Allocates a temporary Tensor of the specified type and
971 // shape. Devices such as GPUs that enqueue Ops for lazy execution
972 // may retain references to the temporary tensors after the Op's
973 // Compute method has run. See comment above.
974 Status allocate_temp(DataType type, const TensorShape& shape,
975 Tensor* out_temp, AllocatorAttributes allocator_attr,
976 const AllocationAttributes& allocation_attr);
allocate_temp(DataType type,const TensorShape & shape,Tensor * out_temp,AllocatorAttributes allocator_attr)977 Status allocate_temp(DataType type, const TensorShape& shape,
978 Tensor* out_temp, AllocatorAttributes allocator_attr) {
979 return allocate_temp(type, shape, out_temp, allocator_attr,
980 AllocationAttributes());
981 }
allocate_temp(DataType type,const TensorShape & shape,Tensor * out_temp)982 Status allocate_temp(DataType type, const TensorShape& shape,
983 Tensor* out_temp) {
984 return allocate_temp(type, shape, out_temp, AllocatorAttributes());
985 }
986
987 // Copies a tensor (allocated by the caller) to the specified output
988 // index. REQUIRES: !IsRefType(expected_output_dtype(index))
989 // REQUIRES: 'tensor' must have the same MemoryType as
990 // output_memory_types[index]. See comment above.
991 Status set_output(StringPiece name, const Tensor& tensor);
992 Status set_output(StringPiece name, Tensor&& tensor);
993 void set_output(int index, const Tensor& tensor);
994 void set_output(int index, Tensor&& tensor);
995
996 // To output a reference. Caller retains ownership of mu and tensor_for_ref,
997 // and they must outlive all uses within the step. See comment above.
998 // REQUIRES: IsRefType(expected_output_dtype(index))
999 Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref);
1000
1001 // Returns nullptr if allocate_output() or set_output() have not been called.
1002 Status mutable_output(StringPiece name, Tensor** tensor);
1003
1004 // Return the DeviceContext that should be used for this Op.
1005 //
1006 // If using the templated function, the type must be a subclass
1007 // of DeviceContext.
1008 //
1009 // Returns nullptr if the device did not provide one.
1010 template <typename T>
1011 T* op_device_context();
op_device_context()1012 DeviceContext* op_device_context() {
1013 DeviceContext* ret = params_->op_device_context;
1014 if (ret == nullptr) {
1015 auto* dev_info = device()->tensorflow_gpu_device_info();
1016 if (dev_info) ret = dev_info->default_context;
1017 }
1018 return ret;
1019 }
1020
input_alloc_attr(int index)1021 AllocatorAttributes input_alloc_attr(int index) const {
1022 if (params_->input_alloc_attrs == nullptr) {
1023 return AllocatorAttributes();
1024 } else {
1025 DCHECK_GE(index, 0);
1026 DCHECK_LT(index, params_->input_alloc_attrs->size());
1027 return (*params_->input_alloc_attrs)[index];
1028 }
1029 }
1030
output_alloc_attr(int index)1031 AllocatorAttributes output_alloc_attr(int index) const {
1032 return params_->output_attr_array[index];
1033 }
1034
ConsumeWrappedAllocators()1035 gtl::InlinedVector<WrappedAllocator, 4> ConsumeWrappedAllocators() {
1036 gtl::InlinedVector<WrappedAllocator, 4> retrieved;
1037 if (tracking_state_) {
1038 mutex_lock lock(tracking_state_->mu);
1039 retrieved.swap(tracking_state_->wrapped_allocators);
1040 }
1041 return retrieved;
1042 }
1043
1044 // Communication.
1045 //
1046 // An op kernel communicates with outside environment through
1047 // Rendezvous Send() and Recv().
rendezvous()1048 RendezvousInterface* rendezvous() const { return params_->rendezvous; }
1049
collective_executor()1050 CollectiveExecutor* collective_executor() const {
1051 return params_->collective_executor;
1052 }
1053
1054 // An op kernel can access the session state it belongs to.
session_state()1055 SessionState* session_state() const { return params_->session_state; }
1056
1057 // Unique identifier of the session it belongs to. Can be empty.
session_handle()1058 std::string session_handle() const { return params_->session_handle; }
1059
1060 // Metadata about the session. Can be nullptr.
session_metadata()1061 const SessionMetadata* session_metadata() const {
1062 return params_->session_metadata;
1063 }
1064
1065 // An op kernel can access the tensor store of the run it belongs to.
tensor_store()1066 TensorStore* tensor_store() const { return params_->tensor_store; }
1067
1068 // Function call support.
1069 //
1070 // If this kernel invocation is within a function execution,
1071 // call_frame() returns the call frame for the function call.
call_frame()1072 CallFrameInterface* call_frame() const { return params_->call_frame; }
1073
1074 // If not nullptr, the kernel invoke functions defined in the
1075 // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...).
function_library()1076 FunctionLibraryRuntime* function_library() const {
1077 return params_->function_library;
1078 }
1079
runner()1080 std::function<void(std::function<void()>)>* runner() const {
1081 return params_->runner;
1082 }
stats_collector()1083 StepStatsCollectorInterface* stats_collector() const {
1084 return params_->stats_collector;
1085 }
1086
1087 // Shared resources accessible to this kernel.
resource_manager()1088 ResourceMgr* resource_manager() const { return params_->resource_manager; }
1089
slice_reader_cache()1090 checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const {
1091 return params_->slice_reader_cache;
1092 }
1093
1094 // Execution.
1095 //
1096 // OpKernels can use these eigen devices to carry out their
1097 // numerical computation.
eigen_cpu_device()1098 const Eigen::ThreadPoolDevice& eigen_cpu_device() const {
1099 return *device()->eigen_cpu_device();
1100 }
eigen_gpu_device()1101 const Eigen::GpuDevice& eigen_gpu_device() const {
1102 return params_->eigen_gpu_device->device();
1103 }
1104 template <typename EigenDeviceType>
1105 const EigenDeviceType& eigen_device() const;
1106
1107 // Error handling.
1108
1109 // If expected_inputs == inputs() and expected_outputs == output_types(),
1110 // returns OK, else returns INVALID_ARGUMENT with an error message.
1111 // Recommended for Ops with dynamic signatures, where validation can only
1112 // be performed at runtime.
1113 Status MatchSignature(const DataTypeSlice expected_inputs,
1114 const DataTypeSlice expected_outputs);
1115
1116 // An OpKernel should call SetStatus() if Compute() encounters an
1117 // error.
1118 void SetStatus(const Status& status);
status()1119 const Status& status() const { return status_; }
1120
1121 // Cancellation.
1122 //
1123 // EXPERIMENTAL. See the implementation in tensorflow::FIFOQueue for an
1124 // example of how to use this API.
cancellation_manager()1125 CancellationManager* cancellation_manager() const {
1126 return params_->cancellation_manager;
1127 }
1128
1129 // Other accessors.
1130
1131 // For control flow.
frame_iter()1132 FrameAndIter frame_iter() const { return params_->frame_iter; }
is_input_dead()1133 bool is_input_dead() const { return params_->is_input_dead; }
1134
1135 // May be used, e.g., to get GPU handles, etc.
1136 // TODO(tucker): Add example usage.
device()1137 DeviceBase* device() const { return params_->device; }
1138
1139 // Per-step container for use by white-listed internal ops.
step_container()1140 ScopedStepContainer* step_container() const {
1141 return params_->step_container;
1142 }
1143
1144 // Access to distributed coordination service.
coordination_service_agent()1145 CoordinationServiceAgent* coordination_service_agent() const {
1146 return params_->coordination_service_agent;
1147 }
1148
1149 // Helper routines for the OP_REQUIRES macros
1150 void CtxFailure(const Status& s);
1151 void CtxFailureWithWarning(const Status& s);
1152 void CtxFailure(const char* file, int line, const Status& s);
1153 void CtxFailureWithWarning(const char* file, int line, const Status& s);
1154
1155 // Unrecommended functions: these are functions that have some
1156 // current uses but are not recommended for use, and may go away at
1157 // some future major version release.
1158 //
1159 // The following functions all have versions that return Status
1160 // to capture error conditions, and are strongly preferred.
1161 Tensor* mutable_output(int index);
1162 mutex* input_ref_mutex(int index);
1163 void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref);
1164 TensorValue release_output(int index);
1165
track_allocations()1166 bool track_allocations() const { return params_->track_allocations; }
1167
1168 // Records temp memory allocation. Tensor object is recorded to identify the
1169 // case where temp memory is used as output memory.
1170 void record_temp_memory_allocation(int64_t size, const Tensor& t)
1171 TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
1172
1173 // Returns recorded size of temporary memory;
1174 int64 temp_memory_allocated() const
1175 TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
1176
1177 // Records persistent memory allocation, size can be negative indicating
1178 // deallocation.
1179 void record_persistent_memory_allocation(int64_t size, int64_t alloc_id = -1)
1180 TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
1181
1182 // Returns recorded size and ids of persistent memory.
1183 int64 persistent_memory_allocated() const
1184 TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
1185
1186 std::vector<int64> persistent_alloc_ids() const
1187 TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
1188
1189 // Resets counters for temp and persistent memory and recorded ids.
1190 void clear_recorded_memory() TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
1191
1192 bool input_is_ref(int index) const;
1193
1194 void set_record_memory_consumption(bool v);
1195
1196 // Used by OpKernel implementations to track actively running deferred ops.
1197 //
1198 // A deferred op is one whose Compute method returns (or whose ComputeAsync
1199 // method invokes the callback) when work is scheduled onto a device. At that
1200 // point, we don't know when the work will actually complete (or if it has
1201 // already completed) on the device. These functions allow the executor to
1202 // track the status of deferred ops and act accordingly.
1203 //
1204 // Deferred OpKernel implementations must use these methods to get two
1205 // functions. It then must call these two functions in pairs, before and after
1206 // device execution, respectively.
inc_num_deferred_ops_function()1207 TF_MUST_USE_RESULT std::function<void()> inc_num_deferred_ops_function() {
1208 DCHECK(params_->op_kernel->is_deferred());
1209 return params_->inc_num_deferred_ops_function
1210 ? params_->inc_num_deferred_ops_function
1211 : []() {};
1212 }
dec_num_deferred_ops_function()1213 TF_MUST_USE_RESULT std::function<void()> dec_num_deferred_ops_function() {
1214 DCHECK(params_->op_kernel->is_deferred());
1215 return params_->dec_num_deferred_ops_function
1216 ? params_->dec_num_deferred_ops_function
1217 : []() {};
1218 }
1219
1220 Allocator* get_allocator(AllocatorAttributes attr);
1221
1222 private:
1223 bool record_memory_consumption_ = false;
1224
1225 // Internal common method used when allocating tensor memory
allocate_tensor(DataType type,const TensorShape & shape,Tensor * out_tensor,AllocatorAttributes allocator_attr)1226 Status allocate_tensor(DataType type, const TensorShape& shape,
1227 Tensor* out_tensor,
1228 AllocatorAttributes allocator_attr) {
1229 return allocate_tensor(type, shape, out_tensor, allocator_attr,
1230 AllocationAttributes());
1231 }
1232
1233 Status allocate_tensor(DataType type, const TensorShape& shape,
1234 Tensor* out_tensor, AllocatorAttributes allocator_attr,
1235 const AllocationAttributes& allocation_attr);
1236
1237 // Helpers for `set_output()`.
1238
1239 // Returns `true` if the tensor was copied into an allocated output.
1240 bool maybe_set_output_by_allocate_and_copy(int index, const Tensor& tensor);
1241
1242 void maybe_track_allocations_for_set_output(const Tensor& tensor);
1243
1244 Status get_input_index(StringPiece name, int* out_index) const;
1245 Status get_output_index(StringPiece name, int* out_index) const;
1246
1247 // Initialize the allocated_scope_ids_ set the first time this method is
1248 // called.
1249 void maybe_initialize_scope_id_set();
1250
1251 Status status_;
1252 friend class CollectiveExecutor; // for access to params_
1253 Params* params_; // not owned
1254 gtl::InlinedVector<TensorValue, 4> outputs_;
1255
1256 // Keep track of calls to ScopedAllocator.
1257 // TODO(ayushd): change to absl::flat_hash_set.
1258 std::unique_ptr<std::unordered_set<int32>> allocated_scope_ids_;
1259
1260 // The following data members are only used when allocation tracking is
1261 // enabled, memory consumption is being recorded, or tensor access is being
1262 // recorded.
1263 struct TrackingState {
1264 mutable mutex mu;
1265 gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators
1266 TF_GUARDED_BY(mu);
1267
1268 mutable mutex stats_mu;
1269 int64 temp_memory_allocated TF_GUARDED_BY(stats_mu) = 0;
1270
1271 int64 persistent_memory_allocated TF_GUARDED_BY(stats_mu) = 0;
1272 gtl::InlinedVector<std::pair<const void*, int64>, 2>
1273 temp_tensor_buffer_and_size TF_GUARDED_BY(stats_mu);
1274 gtl::InlinedVector<int64, 2> persistent_alloc_ids TF_GUARDED_BY(stats_mu);
1275 };
1276 std::unique_ptr<TrackingState> tracking_state_;
1277
1278 // For access to `params_->op_kernel`.
1279 friend void CheckNotInComputeAsync(OpKernelContext* ctx,
1280 const char* correct_macro_name);
1281
1282 TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext);
1283 };
1284
1285 template <>
1286 const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const;
1287
1288 template <>
1289 const Eigen::GpuDevice& OpKernelContext::eigen_device() const;
1290
1291 // Register your OpKernel by specifying the Op's name, the device the
1292 // kernel runs on, any type attr constraints for this kernel, any
1293 // host-memory args, and the class to instantiate. Examples:
1294 //
1295 // // A kernel that supports all types.
1296 // REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp);
1297 //
1298 // // The following are equivalent ways of specifying that the kernel only
1299 // // works if the "T" type attr is set to DT_FLOAT.
1300 // REGISTER_KERNEL_BUILDER(
1301 // Name("Sub").Device(DEVICE_CPU).TypeConstraint<float>("T"),
1302 // SubOp<float>);
1303 // // (You would then repeat this for every type supported by "Sub".)
1304 //
1305 // // This form allows you to specify a list of types as the constraint.
1306 // REGISTER_KERNEL_BUILDER(Name("Sub")
1307 // .Device(DEVICE_CPU)
1308 // .TypeConstraint("T", {DT_FLOAT}),
1309 // SubOp<float>);
1310 //
1311 // // A kernel that expects one of the input tensors in host memory.
1312 // REGISTER_KERNEL_BUILDER(
1313 // Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp);
1314 //
1315 // See kernel_def_builder for details.
1316
1317 // Instantiate an OpKernel that has been registered.  Returns nullptr if
1318 // no kernel is registered for that op / device / input signature
1319 // combination (and sets a NOT_FOUND *status), or if there is an error
1320 // during construction (and sets an INVALID_ARGUMENT *status).  Otherwise,
1321 // the caller takes ownership of the returned pointer.
1322 // EXPECTED USAGE: unique_ptr<OpKernel> op = CreateOpKernel(...);
1323 // REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
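//
// For example (an illustrative sketch only; `device`, `allocator` and
// `node_def` are assumed to be valid objects owned by the caller, and
// TF_GRAPH_DEF_VERSION comes from public/version.h):
//
//   Status status;
//   std::unique_ptr<OpKernel> op = CreateOpKernel(
//       DEVICE_CPU, device, allocator, node_def, TF_GRAPH_DEF_VERSION,
//       &status);
//   if (!status.ok()) { /* `op` is nullptr; handle the error. */ }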
1324 std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
1325 DeviceBase* device,
1326 Allocator* allocator,
1327 const NodeDef& node_def,
1328 int graph_def_version, Status* status);
1329
1330 std::unique_ptr<OpKernel> CreateOpKernel(
1331 DeviceType device_type, DeviceBase* device, Allocator* allocator,
1332 const std::shared_ptr<const NodeProperties>& props, int graph_def_version,
1333 Status* status);
1334
1335 Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
1336 Allocator* allocator, FunctionLibraryRuntime* flib,
1337 const std::shared_ptr<const NodeProperties>& props,
1338 int graph_def_version, OpKernel** kernel);
1339
1340 Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
1341 Allocator* allocator, FunctionLibraryRuntime* flib,
1342 ResourceMgr* resource_mgr,
1343 const std::shared_ptr<const NodeProperties>& props,
1344 int graph_def_version, OpKernel** kernel);
1345
1346 // Returns into 'device_types' the subset of prioritized_types that this
1347 // binary has registered for the given NodeDef.
1348 //
1349 // REQUIRES: * 'device_types' is not nullptr.
1350 // * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
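//
// Example (illustrative sketch; `node_def` is assumed to be a fully-specified
// NodeDef):
//   PrioritizedDeviceTypeVector device_types;
//   TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
//       {DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)}, node_def,
//       &device_types));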
1351 Status SupportedDeviceTypesForNode(
1352 const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
1353 PrioritizedDeviceTypeVector* device_types,
1354 const DeviceNameUtils::ParsedName* local_address_spec = nullptr);
1355
1356 // Returns a message with a description of the kernels registered for op
1357 // `op_name`.
1358 std::string KernelsRegisteredForOp(StringPiece op_name);
1359
1360 // Call once after Op registration has completed.
1361 Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry);
1362
1363 // -----------------------------------------------------------------------------
1364 // OpKernel registration implementation follows, please ignore.
1365
1366 // Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax.
1367 namespace register_kernel {
1368
1369 class Name : public KernelDefBuilder {
1370 public:
1371   explicit Name(const char* op) : KernelDefBuilder(op) {}
1372 };
1373
1374 } // namespace register_kernel
1375
1376 // Kernel registration appears as:
1377 // REGISTER_KERNEL_BUILDER(Name("OpName").Device(DEVICE_CPU)..., OpImpl)
1378 // We'd like to have "OpName" as a constant-expression, without requiring that
1379 // of the overall KernelDefBuilder expression (beginning with the
1380 // register_kernel::Name constructor above).
1381 //
1382 // So, we pull the "OpName" part to a separate macro-level argument. This
1383 // involves treating Name("OpName") as a macro call, via token-pasting (e.g.
1384 // M_## => M_Name("OpName")), and having it expand to '"OpName",
1385 // Name("OpName")' which is then usable as two arguments.
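//
// For example (an expansion sketch; "Foo" and FooOp are placeholders):
//   TF_EXTRACT_KERNEL_NAME(M, Name("Foo").Device(DEVICE_CPU), FooOp)
// first token-pastes to
//   TF_EXTRACT_KERNEL_NAME_IMPL(M, TF_EXTRACT_KERNEL_NAME_Name("Foo")
//                                      .Device(DEVICE_CPU), FooOp)
// and then expands to
//   M("Foo", ::tensorflow::register_kernel::Name("Foo").Device(DEVICE_CPU),
//     FooOp)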
1386 #define TF_EXTRACT_KERNEL_NAME_Name(name_str) \
1387 name_str, ::tensorflow::register_kernel::Name(name_str)
1388 #define TF_EXTRACT_KERNEL_NAME_IMPL(m, ...) m(__VA_ARGS__)
1389 #define TF_EXTRACT_KERNEL_NAME(m, kernel_builder, ...) \
1390 TF_EXTRACT_KERNEL_NAME_IMPL(m, TF_EXTRACT_KERNEL_NAME_##kernel_builder, \
1391 __VA_ARGS__)
1392
1393 // REGISTER_KERNEL_BUILDER_IMPL_2, with a unique 'ctr' as the first argument.
1394 // TODO(dodgen): There are some uses of this macro inside functions, where
1395 // kernel_builder refers to (non-const) locals (they should be fixed). To
1396 // accommodate those, kernel_builder.Build() appears as an argument to an
1397 // immediately-called lambda (not in the lambda itself).
1398 #define REGISTER_KERNEL_BUILDER_IMPL_3(ctr, op_name, kernel_builder_expr, \
1399 is_system_kernel, ...) \
1400 static ::tensorflow::InitOnStartupMarker const register_kernel_##ctr \
1401 TF_ATTRIBUTE_UNUSED = \
1402 TF_INIT_ON_STARTUP_IF(is_system_kernel || \
1403 (SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__) && \
1404 SHOULD_REGISTER_OP(op_name))) \
1405 << ([](::tensorflow::KernelDef const* kernel_def) { \
1406 ::tensorflow::kernel_factory::OpKernelRegistrar registrar( \
1407 kernel_def, #__VA_ARGS__, \
1408 [](::tensorflow::OpKernelConstruction* context) \
1409 -> ::tensorflow::OpKernel* { \
1410 return new __VA_ARGS__(context); \
1411 }); \
1412 (void)registrar; \
1413 return ::tensorflow::InitOnStartupMarker{}; \
1414 })(kernel_builder_expr.Build());
1415
1416 // REGISTER_KERNEL_BUILDER_IMPL, but with kernel_builder split into op_name
1417 // and kernel_builder_expr.
1418 #define REGISTER_KERNEL_BUILDER_IMPL_2(op_name, kernel_builder_expr, \
1419 is_system_kernel, ...) \
1420 TF_NEW_ID_FOR_INIT(REGISTER_KERNEL_BUILDER_IMPL_3, op_name, \
1421 kernel_builder_expr, is_system_kernel, __VA_ARGS__)
1422
1423 // REGISTER_KERNEL_BUILDER, but with is_system_kernel bound.
1424 #define REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, is_system_kernel, ...) \
1425 TF_EXTRACT_KERNEL_NAME(REGISTER_KERNEL_BUILDER_IMPL_2, kernel_builder, \
1426 is_system_kernel, __VA_ARGS__)
1427
1428 #define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
1429 TF_ATTRIBUTE_ANNOTATE("tf:kernel") \
1430 REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, false, __VA_ARGS__)
1431
1432 // The `REGISTER_SYSTEM_KERNEL_BUILDER()` macro acts like
1433 // `REGISTER_KERNEL_BUILDER()`, except that the kernel is registered
1434 // unconditionally, even when selective registration is used.
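//
// For example (illustrative only; "MySystemOp" and MySystemOpKernel are
// placeholders, not ops defined in this header):
//   REGISTER_SYSTEM_KERNEL_BUILDER(Name("MySystemOp").Device(DEVICE_CPU),
//                                  MySystemOpKernel);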
1435 #define REGISTER_SYSTEM_KERNEL_BUILDER(kernel_builder, ...) \
1436 TF_ATTRIBUTE_ANNOTATE("tf:kernel") \
1437 TF_ATTRIBUTE_ANNOTATE("tf:kernel:system") \
1438 REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, true, __VA_ARGS__)
1439
1440 // Checks whether a given kernel is registered on device_type.
1441 bool KernelDefAvailable(const DeviceType& device_type, const NodeDef& node_def);
1442
1443 // If a node with the given node_name, experimental_debug_info, node_op,
1444 // node_device and node_attrs has a corresponding kernel registered on
1445 // device_type, returns OK and fills in the kernel def and
1446 // kernel_class_name. <def> and <kernel_class_name> may be null.
1447 Status FindKernelDef(
1448 const DeviceType& device_type, StringPiece node_name,
1449 bool has_experimental_debug_info,
1450 const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
1451 StringPiece node_op, StringPiece node_device, AttrSlice node_attrs,
1452 const KernelDef** def, std::string* kernel_class_name);
1453
1454 // If node_def has a corresponding kernel registered on device_type,
1455 // returns OK and fills in the kernel def and kernel_class_name. <def> and
1456 // <kernel_class_name> may be null.
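//
// Example (illustrative sketch; `node_def` is assumed to be fully specified):
//   const KernelDef* kernel_def = nullptr;
//   std::string kernel_class_name;
//   TF_RETURN_IF_ERROR(FindKernelDef(DeviceType(DEVICE_CPU), node_def,
//                                    &kernel_def, &kernel_class_name));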
1457 Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
1458 const KernelDef** def, std::string* kernel_class_name);
1459
1460 // Writes a list of all registered kernels to LOG(INFO), to help users debug
1461 // missing kernel errors.
1462 void LogAllRegisteredKernels();
1463
1464 // Gets a list of all registered kernels.
1465 KernelList GetAllRegisteredKernels();
1466
1467 // Gets a list of all registered kernels for which `predicate` returns true.
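// For example, to list only the kernels registered for GPU (sketch):
//   KernelList gpu_kernels = GetFilteredRegisteredKernels(
//       [](const KernelDef& k) { return k.device_type() == "GPU"; });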
1468 KernelList GetFilteredRegisteredKernels(
1469 const std::function<bool(const KernelDef&)>& predicate);
1470
1471 // Gets a list of all registered kernels for a given op.
1472 KernelList GetRegisteredKernelsForOp(StringPiece op_name);
1473
1474 namespace kernel_factory {
1475
1476 // OpKernelFactory is responsible for creating OpKernels when TensorFlow needs
1477 // them. You register factories with the TensorFlow core by constructing an
1478 // OpKernelRegistrar and passing the factory as a constructor parameter.
1479 class OpKernelFactory {
1480 public:
1481 virtual OpKernel* Create(OpKernelConstruction* context) = 0;
1482 virtual ~OpKernelFactory() = default;
1483 };
1484
1485 class OpKernelRegistrar {
1486 public:
1487 // Registers the given kernel factory with TensorFlow. TF will call the
1488 // factory Create() method when it determines that a kernel matching the given
1489 // KernelDef is required.
1490   OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
1491 std::unique_ptr<OpKernelFactory> factory) {
1492 InitInternal(kernel_def, kernel_class_name, std::move(factory));
1493 }
1494
1495 // Registers the given factory function with TensorFlow. This is equivalent
1496 // to registering a factory whose Create function invokes `create_fn`.
1497   OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
1498 OpKernel* (*create_fn)(OpKernelConstruction*)) {
1499 InitInternal(kernel_def, kernel_class_name,
1500 absl::make_unique<PtrOpKernelFactory>(create_fn));
1501 }
1502
1503 private:
1504 struct PtrOpKernelFactory : public OpKernelFactory {
1505     explicit PtrOpKernelFactory(OpKernel* (*create_func)(OpKernelConstruction*))
1506 : create_func_(create_func) {}
1507
1508 OpKernel* Create(OpKernelConstruction* context) override;
1509
1510 OpKernel* (*create_func_)(OpKernelConstruction*);
1511 };
1512
1513 void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
1514 std::unique_ptr<OpKernelFactory> factory);
1515 };
1516
1517 } // namespace kernel_factory
1518
1519 // -----------------------------------------------------------------------------
1520 // Template and inline method implementations, please ignore
1521
1522 template <class T>
1523 Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const {
1524 return GetNodeAttr(def(), attr_name, value);
1525 }
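// A typical use from an OpKernel constructor (illustrative sketch; "axis" is a
// placeholder attr name that the op being implemented would have to declare):
//   int32 axis;
//   OP_REQUIRES_OK(context, context->GetAttr("axis", &axis));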
1526
1527 inline DataType OpKernelContext::input_dtype(int index) const {
1528 DCHECK_GE(index, 0);
1529 DCHECK_LT(index, num_inputs());
1530 const TensorValue& value((*params_->inputs)[index]);
1531 return value.dtype();
1532 }
1533
1534 inline MemoryType OpKernelContext::input_memory_type(int index) const {
1535 DCHECK_GE(index, 0);
1536 DCHECK_LT(index, num_inputs());
1537 return op_kernel().input_memory_types()[index];
1538 }
1539
1540 inline DataType OpKernelContext::expected_output_dtype(int index) const {
1541 DCHECK_GE(index, 0);
1542 DCHECK_LT(index, num_outputs());
1543 return params_->op_kernel->output_type(index);
1544 }
1545
1546 inline MemoryType OpKernelContext::output_memory_type(int index) const {
1547 DCHECK_GE(index, 0);
1548 DCHECK_LT(index, num_outputs());
1549 return op_kernel().output_memory_types()[index];
1550 }
1551
1552 inline bool OpKernelContext::input_is_ref(int index) const {
1553 const TensorValue& value((*params_->inputs)[index]);
1554 return value.is_ref();
1555 }
1556
1557 // no input if tensor == nullptr.
1558 inline bool OpKernelContext::has_input(int index) const {
1559 DCHECK_GE(index, 0);
1560 DCHECK_LT(index, num_inputs());
1561 return (*params_->inputs)[index].tensor != nullptr;
1562 }
1563
1564 inline mutex* OpKernelContext::input_ref_mutex(int index) {
1565 DCHECK_GE(index, 0);
1566 DCHECK_LT(index, num_inputs());
1567 DCHECK(input_is_ref(index));
1568 return (*params_->inputs)[index].mutex_if_ref;
1569 }
1570
1571 inline Tensor* OpKernelContext::mutable_output(int index) {
1572 DCHECK_GE(index, 0);
1573 DCHECK_LT(index, num_outputs());
1574 return outputs_[index].tensor;
1575 }
1576
1577 inline TensorValue OpKernelContext::release_output(int index) {
1578 DCHECK_GE(index, 0);
1579 DCHECK_LT(index, num_outputs());
1580 TensorValue value = outputs_[index];
1581 outputs_[index] = TensorValue();
1582 return value;
1583 }
1584
1585 inline Status OpKernelContext::forward_input_or_allocate_output(
1586 gtl::ArraySlice<int> candidate_input_indices, int output_index,
1587 const TensorShape& output_shape, Tensor** output, int* forwarded_input) {
1588 for (int input_index : candidate_input_indices) {
1589 if (forward_input_to_output_with_shape(input_index, output_index,
1590 output_shape, output)) {
1591 if (forwarded_input != nullptr) {
1592 *forwarded_input = input_index;
1593 }
1594 return Status::OK();
1595 }
1596 }
1597 if (forwarded_input != nullptr) {
1598 *forwarded_input = -1;
1599 }
1600 return allocate_output(output_index, output_shape, output);
1601 }
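// Typical use from Compute() (illustrative sketch): try to reuse input 0's
// buffer for output 0, falling back to a fresh allocation otherwise:
//   Tensor* output = nullptr;
//   OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
//                           {0}, 0, ctx->input(0).shape(), &output));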
1602
1603 inline Status OpKernelContext::forward_input_or_allocate_output(
1604 gtl::ArraySlice<StringPiece> candidate_input_names, StringPiece output_name,
1605 const TensorShape& output_shape, Tensor** output) {
1606 for (const StringPiece& input_name : candidate_input_names) {
1607 if (forward_input_to_output_with_shape(input_name, output_name,
1608 output_shape, output)
1609 .ok()) {
1610 return Status::OK();
1611 }
1612 }
1613 return allocate_output(output_name, output_shape, output);
1614 }
1615
1616 template <typename T>
1617 T* OpKernelContext::op_device_context() {
1618 static_assert(std::is_base_of<DeviceContext, T>::value,
1619 "T is not a subclass of DeviceContext");
1620 return static_cast<T*>(op_device_context());
1621 }
1622
1623 inline const Tensor& OpInputList::operator[](int i) const {
1624 DCHECK_GE(i, 0);
1625 DCHECK_LT(i, stop_ - start_);
1626 return ctx_->input(start_ + i);
1627 }
1628
1629 inline mutex* OpMutableInputList::ref_mutex(int i) {
1630 DCHECK_GE(i, 0);
1631 DCHECK_LT(i, stop_ - start_);
1632 return ctx_->input_ref_mutex(start_ + i);
1633 }
1634
1635 inline Tensor OpMutableInputList::at(int i, bool lock_held) {
1636 DCHECK_GE(i, 0);
1637 DCHECK_LT(i, stop_ - start_);
1638 return ctx_->mutable_input(start_ + i, lock_held);
1639 }
1640
1641 inline Tensor* OpOutputList::operator[](int i) {
1642 DCHECK_GE(i, 0);
1643 DCHECK_LT(i, stop_ - start_);
1644 return ctx_->mutable_output(start_ + i);
1645 }
1646
1647 inline bool OpOutputList::required(int i) const {
1648 DCHECK_GE(i, 0);
1649 DCHECK_LT(i, stop_ - start_);
1650 return ctx_->output_required(start_ + i);
1651 }
1652
1653 inline DataType OpOutputList::expected_output_dtype(int i) const {
1654 DCHECK_GE(i, 0);
1655 DCHECK_LT(i, stop_ - start_);
1656 return ctx_->expected_output_dtype(start_ + i);
1657 }
1658
1659 inline Status OpOutputList::allocate(int i, const TensorShape& shape,
1660 Tensor** output) {
1661 DCHECK_GE(i, 0);
1662 DCHECK_LT(i, stop_ - start_);
1663 return ctx_->allocate_output(start_ + i, shape, output);
1664 }
1665
1666 inline void OpOutputList::set(int i, const Tensor& tensor) {
1667 DCHECK_GE(i, 0);
1668 DCHECK_LT(i, stop_ - start_);
1669 ctx_->set_output(start_ + i, tensor);
1670 }
1671
1672 inline void OpOutputList::set(int i, Tensor&& tensor) {
1673 DCHECK_GE(i, 0);
1674 DCHECK_LT(i, stop_ - start_);
1675 ctx_->set_output(start_ + i, std::move(tensor));
1676 }
1677
1678 inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
1679 DCHECK_GE(i, 0);
1680 DCHECK_LT(i, stop_ - start_);
1681 ctx_->set_output_ref(i, mu, tensor_for_ref);
1682 }
1683
1684 // Generate a fatal error if OP_REQUIRES or OP_REQUIRES_OK are used in
1685 // AsyncOpKernel implementations. If these macros are used and the condition
1686 // does not hold, the `done` callback will never be called and the system will
1687 // deadlock, so a crash failure is preferable. Since the OP_REQUIRES[_OK] macros
1688 // are legal to use in AsyncOpKernel constructors, we use overload resolution
1689 // to distinguish between OpKernelConstruction* and OpKernelContext* context
1690 // types.
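//
// Inside AsyncOpKernel::ComputeAsync(), use the *_ASYNC macro variants
// instead, e.g. (sketch; SomeStatusReturningCall is a placeholder):
//   OP_REQUIRES_OK_ASYNC(ctx, SomeStatusReturningCall(), done);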
1691 class XlaOpKernelContext;
1692 inline void CheckNotInComputeAsync(XlaOpKernelContext*, const char*) {}
1693 inline void CheckNotInComputeAsync(OpKernelConstruction*, const char*) {}
1694 void CheckNotInComputeAsync(OpKernelContext* ctx,
1695 const char* correct_macro_name);
1696
1697 } // namespace tensorflow
1698
1699 #endif // TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
1700