1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
18
19 #include <functional>
20 #include <unordered_set>
21 #include <utility>
22 #include <vector>
23
24 #include "absl/time/time.h"
25 #include "absl/types/optional.h"
26 #include "absl/types/span.h"
27 #include "tensorflow/core/framework/allocator.h"
28 #include "tensorflow/core/framework/cancellation.h"
29 #include "tensorflow/core/framework/control_flow.h"
30 #include "tensorflow/core/framework/device_base.h"
31 #include "tensorflow/core/framework/graph.pb.h"
32 #include "tensorflow/core/framework/kernel_def.pb.h"
33 #include "tensorflow/core/framework/kernel_def_builder.h"
34 #include "tensorflow/core/framework/node_def.pb.h"
35 #include "tensorflow/core/framework/node_def_util.h"
36 #include "tensorflow/core/framework/node_properties.h"
37 #include "tensorflow/core/framework/op.h" // TODO(b/62899350): Remove
38 #include "tensorflow/core/framework/op_requires.h"
39 #include "tensorflow/core/framework/registration/registration.h"
40 #include "tensorflow/core/framework/rendezvous.h"
41 #include "tensorflow/core/framework/session_state.h"
42 #include "tensorflow/core/framework/tensor.h"
43 #include "tensorflow/core/framework/tensor_shape.h"
44 #include "tensorflow/core/framework/tensor_shape.pb.h" // TODO(b/62899350): Remove
45 #include "tensorflow/core/framework/tracking_allocator.h"
46 #include "tensorflow/core/framework/types.h"
47 #include "tensorflow/core/framework/types.pb.h"
48 #include "tensorflow/core/lib/core/errors.h"
49 #include "tensorflow/core/lib/core/status.h"
50 #include "tensorflow/core/lib/gtl/array_slice.h"
51 #include "tensorflow/core/lib/gtl/manual_constructor.h"
52 #include "tensorflow/core/platform/env.h"
53 #include "tensorflow/core/platform/logging.h"
54 #include "tensorflow/core/platform/macros.h"
55 #include "tensorflow/core/platform/mutex.h"
56 #include "tensorflow/core/platform/profile_utils/cpu_utils.h"
57 #include "tensorflow/core/platform/thread_annotations.h"
58 #include "tensorflow/core/platform/types.h"
59 #include "tensorflow/core/protobuf/config.pb.h"
60 #include "tensorflow/core/util/managed_stack_trace.h"
61
// Forward declarations of Eigen device types; only the names are needed here,
// the full definitions come from Eigen itself.
namespace Eigen {
struct GpuDevice;
struct ThreadPoolDevice;
}  // namespace Eigen
66
67 namespace tensorflow {
68
namespace checkpoint {
class TensorSliceReaderCacheWrapper;
}  // namespace checkpoint

// Forward declarations. OpKernelConstruction, AsyncOpKernel and
// OpKernelContext are defined later in this header; the remaining types are
// defined elsewhere.
class AsyncOpKernel;
class CallFrameInterface;
class CollectiveExecutor;
class CoordinationServiceAgent;
class DeviceMgr;
class FunctionLibraryRuntime;
class OpKernelConstruction;
class OpKernelContext;
class OpRegistryInterface;
class ResourceMgr;
class ScopedStepContainer;
class StepStatsCollectorInterface;

// Label attached to kernels that are JIT compiled. These labels are removed
// before kernels are looked up, so JIT kernels can be used without specifying
// the label. The label is a temporary measure that allows JIT kernels to be
// disabled if needed.
extern const char* kJitKernelLabel;
// Name of the environment variable associated with disabling JIT kernels.
// NOTE(review): exact semantics live in the implementation file — confirm there.
extern const char* kDisableJitKernelsEnvVar;
92
93 class OpKernel {
94 public:
95 // OpKernel won't be instantiated by the scheduler, so you may perform
96 // expensive initialization in the descendant's constructor.
97 explicit OpKernel(OpKernelConstruction* context);
98
99 // Specialized constructor that allows a kernel implementation to mark itself
100 // as a "deferred" op. If true, the executor will provide access to the
101 // `OpKernelContext::inc_num_deferred_ops_function()` and
102 // `OpKernelContext::dec_num_deferred_ops_function()` methods at run-time.
103 OpKernel(OpKernelConstruction* context, bool is_deferred);
104
105 // Specialized constructor that enables the descendant to provide a custom
106 // `NodeDef` value. For example, this constructor can be used to provide a
107 // stripped-down `NodeDef` that does not contain the full set of attrs (such
108 // as tensor values) if the descendant stores them in a different form.
109 OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
110 bool is_deferred);
111
112 virtual ~OpKernel();
113
114 // An OpKernel's computation can be either synchronous or
115 // asynchronous. All OpKernel Compute() methods must be thread-safe as they
116 // may be called concurrently (e.g. by multiple executions of the same graph
117 // concurrently).
118 //
119 // Most OpKernels should compute synchronously. They should
120 // subclass OpKernel and override the Compute() method and have it
121 // return after completing the supplied work.
122 //
123 // A synchronous OpKernel *MUST NOT* block the calling thread on a
124 // synchronization mechanism (condition variable, Notification, etc.) that
125 // will be unblocked by the execution of another OpKernel. Execution may
126 // deadlock in that case, because the executor may use a bounded number of
127 // threads.
128 //
129 // If an OpKernel must block on the execution of another OpKernel (e.g. a
130 // RecvOp, or a DequeueOp), the implementation *MUST* subclass AsyncOpKernel,
131 // and override `AsyncOpKernel::ComputeAsync()`. In addition, because the
132 // unblocking kernel may never run (due to an error or cancellation), in most
133 // cases the AsyncOpKernel should implement cancellation support via
134 // `ctx->cancellation_manager()`.
135 //
136 // In both cases, implementations of Compute() and ComputeAsync()
137 // get inputs and write outputs through the given OpKernelContext
138 // and returns a status via context->SetStatus(). They must be
139 // thread-safe.
140
141 // Synchronous compute.
142 //
143 // "context" is guaranteed to be alive until Compute() returns.
144 virtual void Compute(OpKernelContext* context) = 0;
145
146 // Returns nullptr iff this op kernel is synchronous.
AsAsync()147 virtual AsyncOpKernel* AsAsync() { return nullptr; }
148
149 // Returns true iff this op kernel is considered "expensive". The
150 // runtime may use this flag to optimize graph execution for example
151 // to "inline" inexpensive kernels.
IsExpensive()152 virtual bool IsExpensive() { return expensive_; }
153
154 // Returns a pointer to the tensor stored inside constant ops.
const_tensor()155 virtual const Tensor* const_tensor() const { return nullptr; }
156
157 // Accessors.
def()158 const NodeDef& def() const { return props_->node_def; }
name()159 const std::string& name() const { return props_->node_def.name(); }
name_view()160 absl::string_view name_view() const { return name_view_; }
type_string()161 const std::string& type_string() const { return props_->node_def.op(); }
type_string_view()162 absl::string_view type_string_view() const { return type_string_view_; }
requested_input(int i)163 const std::string& requested_input(int i) const {
164 return props_->node_def.input(i);
165 }
requested_device()166 const std::string& requested_device() const {
167 return props_->node_def.device();
168 }
169
num_inputs()170 int num_inputs() const { return props_->input_types.size(); }
input_type(int i)171 DataType input_type(int i) const { return props_->input_types[i]; }
input_types()172 const DataTypeVector& input_types() const { return props_->input_types; }
input_memory_types()173 const MemoryTypeVector& input_memory_types() const {
174 return input_memory_types_;
175 }
176
num_outputs()177 int num_outputs() const { return props_->output_types.size(); }
output_type(int o)178 DataType output_type(int o) const { return props_->output_types[o]; }
output_types()179 const DataTypeVector& output_types() const { return props_->output_types; }
output_memory_types()180 const MemoryTypeVector& output_memory_types() const {
181 return output_memory_types_;
182 }
183
184 Status InputRange(StringPiece input_name, int* start, int* stop) const;
185 Status OutputRange(StringPiece output_name, int* start, int* stop) const;
186
187 // Returns `true` if and only if this kernel uses deferred execution.
is_deferred()188 bool is_deferred() const { return is_deferred_; }
189
190 // Returns a trace string for current computation, op name/type and input
191 // tensor shape/dtype are encoded for profiler cost analysis. Most OpKernel
192 // should use the default implementation.
193 virtual std::string TraceString(const OpKernelContext& ctx,
194 bool verbose) const;
195
196 protected:
197 std::string ShapeTraceString(const OpKernelContext& ctx) const;
198
199 private:
200 const std::shared_ptr<const NodeProperties> props_;
201 const MemoryTypeVector input_memory_types_;
202 const MemoryTypeVector output_memory_types_;
203 NameRangeMap input_name_map_;
204 NameRangeMap output_name_map_;
205 const absl::string_view name_view_;
206 const absl::string_view type_string_view_;
207 const int graph_def_version_;
208 const bool is_deferred_;
209 bool expensive_;
210
211 TF_DISALLOW_COPY_AND_ASSIGN(OpKernel);
212 };
213
214 class AsyncOpKernel : public OpKernel {
215 public:
216 using OpKernel::OpKernel; // Lift OpKernel constructors.
217
218 // Asynchronous compute.
219 //
220 // Implementations of ComputeAsync() must ensure that `done` is (eventually)
221 // called exactly once to signal the completion of the computation. The
222 // implementation of ComputeAsync() must not block on the execution of another
223 // OpKernel. `done` may be called by the current thread, or by another thread.
224 // `context` is guaranteed to stay alive until the `done` callback starts.
225 //
226 // Since it is possible that the unblocking kernel may never run (due to an
227 // error or cancellation), in most cases the AsyncOpKernel should implement
228 // cancellation support via `context->cancellation_manager()`.
229 //
230 // WARNING: As soon as the `done` callback starts, `context` and `this` may be
231 // deleted. No code depending on these objects should execute after the call
232 // to `done`.
233 typedef std::function<void()> DoneCallback;
234 virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0;
235
AsAsync()236 AsyncOpKernel* AsAsync() override { return this; }
237
238 void Compute(OpKernelContext* context) override;
239 };
240
241 class OpKernelConstruction {
242 public:
243 OpKernelConstruction(DeviceType device_type, DeviceBase* device,
244 Allocator* allocator, FunctionLibraryRuntime* flib,
245 ResourceMgr* resource_mgr,
246 const std::shared_ptr<const NodeProperties>& props,
247 const MemoryTypeSlice& input_memory_types,
248 const MemoryTypeSlice& output_memory_types,
249 int graph_def_version, Status* status);
250
env()251 Env* env() const { return device_->env(); }
252
253 // Allocation of tensors during kernel construction:
254 //
255 // It is legal to temporarily allocate scratch tensor storage during
256 // Op kernel construction. Scratch tensors should be allocated using
257 // allocate_temp below. Some kernels need to keep tensors in between
258 // invocations. If such a Tensor is allocated during kernel
259 // construction this also must be done using allocate_temp, and the
260 // Op may only store the returned Tensor object.
261
262 // Allocates a temporary Tensor of the specified type and shape. The
263 // Tensor must not be used after kernel construction is
264 // complete. See comment above.
265 Status allocate_temp(DataType type, const TensorShape& shape,
266 Tensor* out_temp);
267 Status allocate_temp(DataType type, const TensorShape& shape,
268 Tensor* out_temp, AllocatorAttributes allocator_attr);
269
270 // User-supplied configuration of this operation.
def()271 const NodeDef& def() const { return props_->node_def; }
272
273 // For inspecting the inputs to this operation.
num_inputs()274 int num_inputs() const { return props_->input_types.size(); }
input_type(int i)275 DataType input_type(int i) const { return props_->input_types[i]; }
input_types()276 const DataTypeSlice& input_types() const { return props_->input_types_slice; }
input_memory_types()277 const MemoryTypeSlice& input_memory_types() const {
278 return input_memory_types_;
279 }
280
281 // For inspecting the outputs expected from this operation.
num_outputs()282 int num_outputs() const { return props_->output_types.size(); }
output_type(int i)283 DataType output_type(int i) const { return props_->output_types[i]; }
output_types()284 const DataTypeSlice& output_types() const {
285 return props_->output_types_slice;
286 }
output_memory_types()287 const MemoryTypeSlice& output_memory_types() const {
288 return output_memory_types_;
289 }
290
291 // If expected_inputs == inputs() and expected_outputs == output_types(),
292 // returns OK, else returns INVALID_ARGUMENT with an error message.
293 // Recommended for Ops with dynamic signatures.
294 Status MatchSignature(const DataTypeSlice expected_inputs,
295 const DataTypeSlice expected_outputs);
296
297 // For recording configuration errors during construction.
298 void SetStatus(const Status& status);
status()299 const Status& status() const { return *status_; }
300
301 // Look up the attr with name attr_name and set *value to its value. If no
302 // attr with attr_name is found in def(), or the attr does not have
303 // a matching type, a non-ok status will be returned.
304 template <class T>
305 Status GetAttr(StringPiece attr_name, T* value) const TF_ATTRIBUTE_NOINLINE;
306
307 // Return true if the attr_name is defined in def().
308 bool HasAttr(StringPiece attr_name) const;
309
310 // Return the device type.
device_type()311 const DeviceType& device_type() const { return device_type_; }
312
313 // If not nullptr, the kernel can instantiate functions defined in
314 // the library. E.g.,
315 // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...).
function_library()316 FunctionLibraryRuntime* function_library() const { return flib_; }
317
318 // Shared resources accessible to this kernel.
resource_manager()319 ResourceMgr* resource_manager() const { return resource_mgr_; }
320
321 // The GraphDef version whose behavior we should follow.
graph_def_version()322 int graph_def_version() const { return graph_def_version_; }
323
324 // Helper routines for the OP_REQUIRES macros
325 void CtxFailure(const Status& s);
326 void CtxFailureWithWarning(const Status& s);
327 void CtxFailure(const char* file, int line, const Status& s);
328 void CtxFailureWithWarning(const char* file, int line, const Status& s);
329
330 // Unrecommended functions: these are functions that have some
331 // current uses but are not recommended for use, and may go away at
332 // some future major version release.
333
334 // May be used, e.g., to get GPU handles, etc.
335 //
336 // Currently only used to call MakeTensorFromProto() for
337 // implementing ConstantOp for every device. See comments
338 // on Device::MakeTensorFromProto for longer-term replacement
339 // ideas.
device()340 DeviceBase* device() const { return device_; }
341
342 private:
343 const DeviceType device_type_;
344 DeviceBase* const device_;
345 Allocator* allocator_;
346 FunctionLibraryRuntime* flib_;
347 ResourceMgr* const resource_mgr_;
348 std::shared_ptr<const NodeProperties> props_;
349 MemoryTypeSlice input_memory_types_;
350 MemoryTypeSlice output_memory_types_;
351 const int graph_def_version_;
352 Status* status_;
353
354 // Allow access from OpKernel ctor.
355 friend class OpKernel;
356
357 TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction);
358 };
359
// TODO(mrry): Consider converting to a random_access_iterator, and upgrading
// tensorflow::gtl::iterator_range to make the below container classes
// unnecessary.
//
// Position-based forward iterator over an Op*List-style container:
// dereferencing yields (*list)[i]. Iterators may only be compared when they
// refer to the same list (checked with DCHECK).
template <typename ListType, typename ElementType>
class OpArgIterator {
 public:
  using iterator_category = std::forward_iterator_tag;
  using value_type = ElementType;
  using pointer = ElementType*;
  using const_pointer = const ElementType*;
  using reference = ElementType&;
  using const_reference = const ElementType&;
  using difference_type = ptrdiff_t;

  OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {}

  // Comparisons are const so that const iterators (and algorithms that take
  // iterators by const reference) can compare them.
  bool operator==(const OpArgIterator& rhs) const {
    DCHECK(list_ == rhs.list_);
    return i_ == rhs.i_;
  }

  bool operator!=(const OpArgIterator& rhs) const {
    DCHECK(list_ == rhs.list_);
    return i_ != rhs.i_;
  }

  // Prefix ++it: advances this iterator and returns it by reference, as the
  // iterator requirements demand, so chained increments such as `++(++it)`
  // advance `it` itself rather than a temporary copy.
  OpArgIterator& operator++() {
    ++i_;
    return *this;
  }

  // Postfix it++: advances this iterator and returns its previous position.
  OpArgIterator operator++(int) {
    OpArgIterator old_value = *this;
    ++i_;
    return old_value;
  }

  reference operator*() { return (*list_)[i_]; }
  pointer operator->() { return &(*list_)[i_]; }

  const_reference operator*() const { return (*list_)[i_]; }
  const_pointer operator->() const { return &(*list_)[i_]; }

 private:
  const ListType* const list_;  // Not owned.
  int i_;                       // Current position within *list_.
};
407
408 // Utility class for representing a list of immutable input tensors
409 // that are passed to the op as a single named argument.
410 class OpInputList {
411 public:
412 typedef OpArgIterator<OpInputList, const Tensor> Iterator;
OpInputList()413 OpInputList() : ctx_(nullptr), start_(0), stop_(0) {}
OpInputList(OpKernelContext * ctx,int start,int stop)414 OpInputList(OpKernelContext* ctx, int start, int stop)
415 : ctx_(ctx), start_(start), stop_(stop) {}
416 OpInputList& operator=(const OpInputList& other) = default;
417 const Tensor& operator[](int i) const;
size()418 int size() const { return stop_ - start_; }
begin()419 Iterator begin() const { return Iterator(this, 0); }
end()420 Iterator end() const { return Iterator(this, size()); }
421
422 private:
423 OpKernelContext* ctx_; // not owned
424 int start_;
425 int stop_;
426 };
427
428 // Utility class for representing a list of mutable ("ref") input tensors
429 // that are passed to the op as a single named argument.
430 class OpMutableInputList {
431 public:
432 typedef OpArgIterator<OpMutableInputList, Tensor*> Iterator;
OpMutableInputList(OpKernelContext * ctx,int start,int stop)433 OpMutableInputList(OpKernelContext* ctx, int start, int stop)
434 : ctx_(ctx), start_(start), stop_(stop) {}
OpMutableInputList()435 OpMutableInputList() : ctx_(nullptr), start_(0), stop_(0) {}
436 OpMutableInputList& operator=(const OpMutableInputList& other) = default;
437 Tensor at(int i, bool lock_held);
438 mutex* ref_mutex(int i);
size()439 int size() const { return stop_ - start_; }
begin()440 Iterator begin() const { return Iterator(this, 0); }
end()441 Iterator end() const { return Iterator(this, size()); }
442
443 private:
444 OpKernelContext* ctx_; // not owned
445 int start_;
446 int stop_;
447 };
448
449 // Utility class for representing a list of output tensors that are
450 // grouped as a single named output.
451 class OpOutputList {
452 public:
453 typedef OpArgIterator<OpOutputList, const Tensor*> Iterator;
OpOutputList()454 OpOutputList() : ctx_(nullptr), start_(0), stop_(0) {}
OpOutputList(OpKernelContext * ctx,int start,int stop)455 OpOutputList(OpKernelContext* ctx, int start, int stop)
456 : ctx_(ctx), start_(start), stop_(stop) {}
457 OpOutputList& operator=(const OpOutputList& other) = default;
458 Tensor* operator[](int i);
459 bool required(int i) const;
460 DataType expected_output_dtype(int i) const;
461 Status allocate(int i, const TensorShape& shape, Tensor** output);
462 void set(int i, const Tensor& tensor);
463 void set(int i, Tensor&& tensor);
464 void set_ref(int i, mutex* mu, Tensor* tensor_for_ref);
size()465 int size() const { return stop_ - start_; }
begin()466 Iterator begin() const { return Iterator(this, 0); }
end()467 Iterator end() const { return Iterator(this, size()); }
468
469 private:
470 OpKernelContext* ctx_; // not owned
471 int start_;
472 int stop_;
473 };
474
475 // Holds a tensor or tensor reference. For tensor references, we need
476 // a mutex to prevent concurrent access to the tensor.
477 struct TensorValue {
TensorValueTensorValue478 TensorValue() : mutex_if_ref(nullptr), tensor(nullptr) {}
TensorValueTensorValue479 explicit TensorValue(Tensor* t) : mutex_if_ref(nullptr), tensor(t) {}
TensorValueTensorValue480 TensorValue(mutex* mu, Tensor* t) : mutex_if_ref(mu), tensor(t) {}
481 Tensor* operator->() const { return tensor; }
is_refTensorValue482 bool is_ref() const { return mutex_if_ref != nullptr; }
483
484 // Return the dtype of the Tensor. For references, return the underlying type.
dtypeTensorValue485 DataType dtype() const {
486 if (is_ref()) {
487 return MakeRefType(tensor->dtype());
488 } else {
489 return tensor->dtype();
490 }
491 }
492
493 // Return the dtype of the Tensor. For references, return the underlying type.
494 // This variation on the dtype() acquires the lock for references.
495 //
496 // TODO(b/133843385): Disallow dtype modifications
dtype_safeTensorValue497 DataType dtype_safe() const {
498 if (is_ref()) {
499 tf_shared_lock ml(*mutex_if_ref);
500 return MakeRefType(tensor->dtype());
501 } else {
502 return tensor->dtype();
503 }
504 }
505
506 mutex* mutex_if_ref; // nullptr if not a ref, != nullptr if a ref
507 Tensor* tensor;
508 };
509
510 // Used to store partitioned graphs from function-calling ops.
511 struct GraphCollector {
512 mutex mu;
513 std::vector<GraphDef> partitioned_graphs TF_GUARDED_BY(mu);
514 GraphDef raw_graph TF_GUARDED_BY(mu);
515 GraphDef optimized_graph TF_GUARDED_BY(mu);
516
517 bool dirty TF_GUARDED_BY(mu);
518
GraphCollectorGraphCollector519 GraphCollector() : dirty(false) {}
520
CollectRawGraphGraphCollector521 void CollectRawGraph(const GraphDef& graph) {
522 mutex_lock ml(mu);
523 raw_graph.MergeFrom(graph);
524 dirty = true;
525 }
526
CollectOptimizedGraphGraphCollector527 void CollectOptimizedGraph(const GraphDef& graph) {
528 mutex_lock ml(mu);
529 optimized_graph.MergeFrom(graph);
530 dirty = true;
531 }
532
CollectPartitionedGraphGraphCollector533 void CollectPartitionedGraph(const GraphDef& graph) {
534 mutex_lock ml(mu);
535 partitioned_graphs.push_back(graph);
536 dirty = true;
537 }
538
ClearGraphsGraphCollector539 void ClearGraphs() TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
540 raw_graph.Clear();
541 optimized_graph.Clear();
542 partitioned_graphs.clear();
543 dirty = false;
544 }
545
HasUpdatedGraphsGraphCollector546 bool HasUpdatedGraphs() {
547 mutex_lock ml(mu);
548 return dirty;
549 }
550 };
551
552 class OpKernelContext {
553 public:
554 // The first element of a WrappedAllocator is a "base" Allocator and
555 // the second element is that Allocator wrapped by a
556 // TrackingAllocator
557 typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator;
558
559 // TODO(zhifengc): Do some cleanup of Params.
560 // The Params struct is passed in to initialize an OpKernelContext,
561 // and must outlive the OpKernelContext.
562 struct Params {
~ParamsParams563 ~Params() { delete eigen_gpu_device; }
564
565 // The step being executed.
566 int64_t step_id = 0;
567
568 // Timestamp for the start of graph execution. Used for latency metrics.
569 int64_t start_time_usecs = 0;
570
571 // The deadline for the session to complete by. Empty if unspecified.
572 absl::optional<absl::Time> deadline;
573
574 // The op kernel being computed.
575 OpKernel* op_kernel = nullptr;
576
577 // The device on which the kernel is running.
578 DeviceBase* device = nullptr;
579
580 // The Eigen GPU device wrapper, which may include a per-op
581 // wrapped allocator. The concrete type of this object depends on
582 // the type of this->device, so eigen_gpu_device can't be an
583 // inline member and must be heap allocated. However, we don't
584 // want to allocate a new eigen_gpu_device for every Op that is
585 // executed. Instead this member is allocated on first use using
586 // ensure_eigen_gpu_device, and then if the Params structure is
587 // re-used for subsequent Ops, the eigen_gpu_device is
588 // ReInitialized in the OpKernelContext constructor. Unlike the
589 // other pointers in Params, this one is owned by Params.
590 PerOpGpuDevice* eigen_gpu_device = nullptr;
591
ensure_eigen_gpu_deviceParams592 inline void ensure_eigen_gpu_device() {
593 DCHECK(device);
594 if (nullptr == eigen_gpu_device) {
595 // Surprisingly, MakeGpuDevice will return nullptr if the
596 // device is not a GPU device. This is ok, since those devices
597 // will never use eigen_gpu_device. It seems better to have
598 // ensure_eigen_gpu_device fall through and regenerate the
599 // nullptr every time an OpKernelContext is instantiated, than
600 // to do an unnecessary allocation of a dummy eigen GPU
601 // device for CPU device Ops.
602 eigen_gpu_device = device->MakeGpuDevice();
603 }
604 }
605
606 bool track_allocations = false;
607 bool log_memory = false;
608
609 // Array indexed by output number for this node
610 const AllocatorAttributes* output_attr_array = nullptr;
611
612 // Shared resources accessible by this op kernel invocation.
613 ResourceMgr* resource_manager = nullptr;
614
615 // Per-step resources accessible by this op kernel invocation should be
616 // stored in this container..
617 ScopedStepContainer* step_container = nullptr;
618
619 // Mechanism used by this op kernel invocation to communicate with
620 // computations running on other devices.
621 RendezvousInterface* rendezvous = nullptr;
622
623 // Mechanism for executing a collective op that needs to coordinate
624 // with parallel instances running on other devices.
625 CollectiveExecutor* collective_executor = nullptr;
626
627 // The session state for this op.
628 SessionState* session_state = nullptr;
629
630 // Unique session identifier. Can be empty.
631 std::string session_handle;
632
633 // Metadata about the session. Can be nullptr.
634 const SessionMetadata* session_metadata = nullptr;
635
636 // The tensor store for this op.
637 TensorStore* tensor_store = nullptr;
638
639 // Mechanism used by this op kernel invocation to register a callback
640 // for its cancellation.
641 CancellationManager* cancellation_manager = nullptr;
642
643 // Inputs to this op kernel.
644 absl::Span<const TensorValue> inputs;
645 bool is_input_dead = false;
646
647 absl::Span<const AllocatorAttributes> input_alloc_attrs;
648
649 // Device context.
650 DeviceContext* op_device_context = nullptr;
651
652 // Control-flow op supports.
653 FrameAndIter frame_iter;
654
655 // Function call supports.
656 CallFrameInterface* call_frame = nullptr;
657 FunctionLibraryRuntime* function_library = nullptr;
658 std::function<void(std::function<void()>)>* runner = nullptr;
659 StepStatsCollectorInterface* stats_collector = nullptr;
660 GraphCollector* graph_collector = nullptr;
661 bool run_all_kernels_inline = false;
662 const std::string* executor_type = nullptr;
663
664 // TensorSliceReaderCache support.
665 checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
666
667 // Support for forwarding reservations (used by ScopedAllocator).
668 static constexpr int kNeverForward = -2;
669 static constexpr int kNoReservation = -1;
670 // Values in [0,...) represent reservations for the indexed output.
671 const int* forward_from_array = nullptr;
672
673 // For tracking actively running deferred ops.
674 std::function<void()> inc_num_deferred_ops_function;
675 std::function<void()> dec_num_deferred_ops_function;
676
677 absl::optional<ManagedStackTrace> stack_trace = {};
678
679 // For implementing `OpKernelContext::output_required()`. If null, all
680 // outputs are required.
681 bool* outputs_required_array = nullptr;
682
683 // For access to distributed coordination service.
684 CoordinationServiceAgent* coordination_service_agent = nullptr;
685 };
686
687 // params must outlive the OpKernelContext.
688 explicit OpKernelContext(Params* params);
689 OpKernelContext(Params* params, int num_outputs);
690 ~OpKernelContext();
691
env()692 Env* env() const { return params_->device->env(); }
693
step_id()694 int64_t step_id() const { return params_->step_id; }
695
start_time_usecs()696 int64_t start_time_usecs() const { return params_->start_time_usecs; }
697
698 // The deadline for the session to complete by. Empty if unspecified in
699 // RunOptions.
deadline()700 absl::optional<absl::Time> deadline() const { return params_->deadline; }
701
op_kernel()702 const OpKernel& op_kernel() const { return *params_->op_kernel; }
703
704 // Stack trace of where the op was defined (if defined in eager mode).
stack_trace()705 const absl::optional<ManagedStackTrace>& stack_trace() const {
706 return params_->stack_trace;
707 }
708
709 // Input/output signature.
710
num_inputs()711 int num_inputs() const { return params_->inputs.size(); }
712 DataType input_dtype(int index) const;
713 Status input_dtype(StringPiece name, DataType* dtype) const;
714 MemoryType input_memory_type(int index) const;
715
num_outputs()716 int num_outputs() const { return outputs_.size(); }
717 DataType expected_output_dtype(int index) const;
718 MemoryType output_memory_type(int index) const;
719
720 // Input
721
722 // Returns an immutable input tensor. May only be used for non-Ref
723 // inputs. For Ref inputs use mutable_input below.
724 // REQUIRES: !IsRefType(input_dtype(index))
725 // TODO(mrry): Convert this to return Status.
726 const Tensor& input(int index) const;
727
  // Returns the named immutable input tensor in "tensor", as defined
  // in the OpDef. May only be used for non-Ref inputs. For Ref inputs
  // use mutable_input below.
  // REQUIRES: the named input must not be a ref tensor.
  // REQUIRES: the named input must not be a list.
733 Status input(StringPiece name, const Tensor** tensor);
734
  // Returns the named list-valued immutable input in "list", as
  // defined in the OpDef. If the named input is not list-valued,
  // returns a one-element list. May only be used for non-Ref
  // inputs. For Ref inputs use mutable_input_list below.
  // REQUIRES: the named input must not be a ref tensor.
740 Status input_list(StringPiece name, OpInputList* list);
741
742 // For mutable inputs, use the following together to make sure there
743 // is no concurrent access to mutable_input(), e.g.:
744 // {
745 // Tensor& t = context->mutable_input(index);
746 // mutex_lock lock(*context->input_ref_mutex(index));
747 // // modify the values in t
748 // }
749 // REQUIRES: IsRefType(input_dtype(index))
750 Status input_ref_mutex(StringPiece name, mutex** out_mutex);
751
752 // Returns a mutable input tensor. Must be used to access Ref
753 // inputs. REQUIRES: IsRefType(input_dtype(index)). The caller may
754 // modify the values stored in the Tensor buffer, and modifications
755 // will be visible to other Ops reading the same ref tensor. If
756 // !lock_held the input mutex will be acquired before returning the
757 // Tensor.
758 // TODO(mrry): Convert this to return Status.
759 Tensor mutable_input(int index, bool lock_held);
760
761 // Returns the named mutable input tensor in "tensor", as defined in
762 // the OpDef. Must be used to access Ref inputs. The values stored
763 // in the Tensor buffer may be modified, and modifications will be
764 // visible to other Ops reading the same ref tensor. If !lock_held
765 // the input mutex will be acquired before returning the Tensor.
766 // REQUIRES: the named input must not be a list.
767 // REQUIRES: the named input must be a ref tensor.
768 Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held);
769
770 // Returns the named list-valued mutable input in "list", as defined
771 // in the OpDef. If the named input is not list-valued, returns a
772 // one-element list. Must be used to access Ref inputs. The values
773 // stored in the Tensor buffer may be modified, and modifications
774 // will be visible to other Ops reading the same ref tensor.
775 // REQUIRES: the named input must be a ref tensor.
776 Status mutable_input_list(StringPiece name, OpMutableInputList* list);
777
  // Replaces the Ref input at `index` so that it uses the storage buffer
  // of `tensor`. If !lock_held the input mutex will be acquired before
  // the input is replaced; pass lock_held=true when the caller already
  // holds it.
  // REQUIRES: IsRefType(input_dtype(index)).
  void replace_ref_input(int index, const Tensor& tensor, bool lock_held);

  // Replaces the named Ref input so that it uses the storage buffer of
  // `tensor`. If !lock_held the input mutex will be acquired before the
  // input is replaced.
  // REQUIRES: the named input is a ref type
  //           (IsRefType(input_dtype(index)) for its index).
  Status replace_ref_input(StringPiece name, const Tensor& tensor,
                           bool lock_held);

  // Deletes the Tensor object used as the Ref Input at
  // input_index. This is not usually necessary and should be used
  // with caution. If !lock_held the input mutex will be acquired
  // before the Tensor is deleted.
  // REQUIRES: IsRefType(input_dtype(input_index)).
  void delete_ref_input(int input_index, bool lock_held);

  // Return true if there is input at the given index. An operator has no
  // input at index if its tensor is null. This is primarily used by the
  // merge operator.
  // TODO(mrry): Convert this to return Status.
  bool has_input(int index) const;

  // Returns true if all inputs are the same shape, otherwise sets the
  // status to a non-OK value and returns false.
  // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
  bool ValidateInputsAreSameShape(OpKernel* op);
808
809 // If non-null, kernels should populate with any partition subgraphs created.
graph_collector()810 GraphCollector* graph_collector() { return params_->graph_collector; }
811
812 // If True, hint that all kernels in functions called by this kernel, should
813 // be treated as "inexpensive", and hence executed on the scheduling thread.
run_all_kernels_inline()814 bool run_all_kernels_inline() const {
815 return params_->run_all_kernels_inline;
816 }
817
  // Returns the registered name for the executor type that is executing the
  // current kernel. If empty, the default executor is used.
  const std::string& executor_type() const;

  // Input to output forwarding.

  // Set the output Ref Tensor at output_index to be an alias of the
  // input Ref Tensor at input_index.
  // REQUIRES: IsRefType(input_dtype(input_index)).
  // REQUIRES: IsRefType(output_dtype(output_index)).
  void forward_ref_input_to_ref_output(int input_index, int output_index);

  // Returns true when an alias to input[input_index], reshaped to output_shape,
  // which is safe to use for in-place computation was written to *output.
  // Returns false if input[input_index] has a refcount greater than one, or if
  // its type does not match the expected output type of output[output_index],
  // or the number of elements in input[input_index] does not equal the number
  // of elements in output_shape.
  bool forward_input_to_output_with_shape(int input_index, int output_index,
                                          const TensorShape& output_shape,
                                          Tensor** output) TF_MUST_USE_RESULT;
  // Named-argument variant of the above. It reports through a Status rather
  // than a bool (NOTE(review): confirm which Status is returned when the
  // input cannot be forwarded).
  Status forward_input_to_output_with_shape(StringPiece input_name,
                                            StringPiece output_name,
                                            const TensorShape& output_shape,
                                            Tensor** output) TF_MUST_USE_RESULT;

  // Returns a pointer to a Tensor aliasing the underlying buffer backing
  // input[input_index] iff
  //   * input[input_index] is not a ref,
  //   * the data type, shape, memory type, and allocator attributes of
  //     input[input_index] are compatible with those given in dtype, shape,
  //     memory_type, and attr,
  //   * refcount on the underlying buffer is one.
  //   * Either there is no forwarding reservation for either input_index
  //     or output_index or the specified input is reserved for the specified
  //     output. More precisely:
  //
  //     These cases mean neither input nor output has a reservation:
  //        forward_from_array = nullptr
  //     OR (input_index is not in forward_from_array AND
  //         (output_index == kNoReservation OR
  //          forward_from_array[output_index] == kNoReservation))
  //
  //     This case means that input_index is reserved for output_index:
  //        forward_from_array[output_index] == input_index
  //
  //     This case means the output is reserved to always be allocated,
  //     never assigned a forwarded input:
  //        forward_from_array[output_index] == kNeverForward
  //
  // Otherwise returns nullptr.
  // NOTE: For Cuda kernels that read inputs using the __ldg() intrinsic,
  // forwarding is only safe if there are no reads via __ldg() after writes
  // to the same address.
  std::unique_ptr<Tensor> forward_input(
      int input_index, int output_index, DataType output_dtype,
      const TensorShape& output_shape, MemoryType output_memory_type,
      const AllocatorAttributes& output_attr) TF_MUST_USE_RESULT;

  // Tries to forward one of the inputs given in input_indices to
  // output[output_index]. If none of the given inputs can be forwarded, calls
  // allocate_output() to allocate a new output buffer. The index of the
  // forwarded input will be assigned to the output argument forwarded_input
  // (if it's not nullptr). If no inputs are forwarded, forwarded_input will
  // be assigned -1.
  Status forward_input_or_allocate_output(
      gtl::ArraySlice<int> candidate_input_indices, int output_index,
      const TensorShape& output_shape, Tensor** output,
      int* forwarded_input = nullptr) TF_MUST_USE_RESULT;
  // Named-argument variant of the above; it has no parameter reporting which
  // input (if any) was forwarded.
  Status forward_input_or_allocate_output(
      gtl::ArraySlice<StringPiece> candidate_input_names,
      StringPiece output_name, const TensorShape& output_shape,
      Tensor** output) TF_MUST_USE_RESULT;

  // Tries to reuse one of the inputs given in input_indices as a temporary.
  // If none of the given inputs can be forwarded, calls
  // allocate_temp() to allocate a new temporary buffer.
  Status forward_input_or_allocate_temp(
      gtl::ArraySlice<int> candidate_input_indices, DataType type,
      const TensorShape& shape, const AllocatorAttributes& allocator_attr,
      Tensor* out_temp) TF_MUST_USE_RESULT;
899
forward_input_or_allocate_temp(gtl::ArraySlice<int> candidate_input_indices,DataType type,const TensorShape & shape,Tensor * out_temp)900 Status forward_input_or_allocate_temp(
901 gtl::ArraySlice<int> candidate_input_indices, DataType type,
902 const TensorShape& shape, Tensor* out_temp) TF_MUST_USE_RESULT {
903 return forward_input_or_allocate_temp(candidate_input_indices, type, shape,
904 AllocatorAttributes(), out_temp);
905 }
906
907 // Output
908
909 // Returns the named list-valued output in "list", as defined in the OpDef.
910 // If the named output is not list-valued, returns a one-element list.
911 Status output_list(StringPiece name, OpOutputList* list);
912
913 // If output_required(index) returns true, the OpKernel's Compute() method
914 // should call allocate_output(index, ...), set_output(index, ...),
915 // set_output_ref(index, ...), or set the status to a non-ok value.
916 // If it returns false, it may output, but is not required to do so.
output_required(int index)917 bool output_required(int index) const {
918 return !params_->outputs_required_array ||
919 params_->outputs_required_array[index];
920 }
921
922 // If output_expects_forwarding returns true, the OpKernel's Compute() method
923 // should not allocate the output with allocate_output but instead needs to
924 // use forward_input.
output_expects_forwarding(int index)925 bool output_expects_forwarding(int index) const {
926 return params_->forward_from_array != nullptr &&
927 params_->forward_from_array[index] >= 0;
928 }
929
  // Allocation of tensors during kernel execution inside the Compute
  // method:
  //
  // There are two methods to allocate Tensors when an Op kernel
  // executes.
  //
  // 1) allocate_output. This should be used to allocate any tensor
  // that is going to be used as an output from the Op at the end of
  // the current execution. The caller indicates which output the
  // Tensor will be assigned to, and the call returns the
  // newly-allocated Tensor. The Tensor can subsequently be assigned
  // to during kernel execution, and will be used as the designated
  // output when the kernel execution completes.
  //
  // 2) allocate_temp. This should be used to allocate any scratch
  // storage that is needed while the kernel is executing, and will
  // not be retained by the Op.
  //
  // In some cases a Tensor needs to be used as an output even though
  // it was previously allocated elsewhere. The Tensor may have been
  // passed as an input, or stored in a Tensor during a
  // previous kernel execution, or allocated earlier in the kernel
  // execution at a time when it was not known which output it would
  // be assigned to. In this case the kernel can use set_output or
  // set_output_ref to indicate that the tensor should be used as the
  // designated output. It is legal to use any previously-allocated
  // Tensor as an argument to set_output or set_output_ref, including
  // Tensors allocated via allocate_temp. There may be a performance
  // penalty to using a Tensor that was not allocated using
  // allocate_output. This is because allocate_output uses the
  // AllocatorAttributes stored in output_attr_array for the
  // designated output. In some cases, using the wrong attributes may
  // cause an extra copy of the Tensor's buffer.

  // Allocates output for the specified output index with shape.
  // OpKernelContext retains ownership of the returned pointer. See
  // comment above.
  //
  // If memory allocation fails, returns an error status.
  //
  // REQUIRES: !IsRefType(expected_output_dtype(index))
  Status allocate_output(int index, const TensorShape& shape,
                         Tensor** tensor) TF_MUST_USE_RESULT;
  Status allocate_output(StringPiece name, const TensorShape& shape,
                         Tensor** tensor) TF_MUST_USE_RESULT;
  // The following methods use the supplied attributes instead of
  // those in output_attr_array. The caller is responsible for
  // ensuring that the attributes are "compatible" with the
  // output_attr_array, e.g. the tensor is allocated on the correct
  // device. See comment above.
  Status allocate_output(int index, const TensorShape& shape, Tensor** tensor,
                         AllocatorAttributes attr) TF_MUST_USE_RESULT;
  Status allocate_output(StringPiece name, const TensorShape& shape,
                         Tensor** tensor,
                         AllocatorAttributes attr) TF_MUST_USE_RESULT;

  // Allocates a temporary Tensor of the specified type and
  // shape. Devices such as GPUs that enqueue Ops for lazy execution
  // may retain references to the temporary tensors after the Op's
  // Compute method has run. See comment above.
  // The overloads without allocator_attr / allocation_attr leave those
  // arguments to the implementation (NOTE(review): presumably defaults).
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp, AllocatorAttributes allocator_attr,
                       const AllocationAttributes& allocation_attr);
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp, AllocatorAttributes allocator_attr);
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp);

  // Copies a tensor (allocated by the caller) to the specified output
  // index. The Tensor&& overloads accept an rvalue that may be moved from.
  // REQUIRES: !IsRefType(expected_output_dtype(index))
  // REQUIRES: 'tensor' must have the same MemoryType as
  // output_memory_types[index]. See comment above.
  Status set_output(StringPiece name, const Tensor& tensor);
  Status set_output(StringPiece name, Tensor&& tensor);
  void set_output(int index, const Tensor& tensor);
  void set_output(int index, Tensor&& tensor);

  // To output a reference. Caller retains ownership of mu and tensor_for_ref,
  // and they must outlive all uses within the step. See comment above.
  // REQUIRES: IsRefType(expected_output_dtype(index))
  Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref);

  // Stores the named output in *tensor; *tensor is nullptr if
  // allocate_output() or set_output() have not been called for it.
  Status mutable_output(StringPiece name, Tensor** tensor);
1014
1015 // Return the DeviceContext that should be used for this Op.
1016 //
1017 // If using the templated function, the type must be a subclass
1018 // of DeviceContext.
1019 //
1020 // Returns nullptr if the device did not provide one.
1021 template <typename T>
1022 T* op_device_context();
op_device_context()1023 DeviceContext* op_device_context() {
1024 DeviceContext* ret = params_->op_device_context;
1025 if (ret == nullptr) {
1026 auto* dev_info = device()->tensorflow_accelerator_device_info();
1027 if (dev_info) ret = dev_info->default_context;
1028 }
1029 return ret;
1030 }
1031
input_alloc_attr(int index)1032 AllocatorAttributes input_alloc_attr(int index) const {
1033 if (params_->input_alloc_attrs.empty()) {
1034 return AllocatorAttributes();
1035 } else {
1036 DCHECK_GE(index, 0);
1037 DCHECK_LT(index, params_->input_alloc_attrs.size());
1038 return params_->input_alloc_attrs[index];
1039 }
1040 }
1041
output_alloc_attr(int index)1042 AllocatorAttributes output_alloc_attr(int index) const {
1043 return params_->output_attr_array[index];
1044 }
1045
ConsumeWrappedAllocators()1046 gtl::InlinedVector<WrappedAllocator, 4> ConsumeWrappedAllocators() {
1047 gtl::InlinedVector<WrappedAllocator, 4> retrieved;
1048 if (tracking_state_) {
1049 mutex_lock lock(tracking_state_->mu);
1050 retrieved.swap(tracking_state_->wrapped_allocators);
1051 }
1052 return retrieved;
1053 }
1054
1055 // Communication.
1056 //
1057 // An op kernel communicates with outside environment through
1058 // Rendezvous Send() and Recv().
rendezvous()1059 RendezvousInterface* rendezvous() const { return params_->rendezvous; }
1060
collective_executor()1061 CollectiveExecutor* collective_executor() const {
1062 return params_->collective_executor;
1063 }
1064
1065 // An op kernel can access the session state it belongs to.
session_state()1066 SessionState* session_state() const { return params_->session_state; }
1067
1068 // Unique identifier of the session it belongs to. Can be empty.
session_handle()1069 std::string session_handle() const { return params_->session_handle; }
1070
1071 // Metadata about the session. Can be nullptr.
session_metadata()1072 const SessionMetadata* session_metadata() const {
1073 return params_->session_metadata;
1074 }
1075
1076 // An op kernel can access the tensor store of the run it belongs to.
tensor_store()1077 TensorStore* tensor_store() const { return params_->tensor_store; }
1078
1079 // Function call support.
1080 //
1081 // If this kernel invocation is within a function execution,
1082 // call_frame() returns the call frame for the function call.
call_frame()1083 CallFrameInterface* call_frame() const { return params_->call_frame; }
1084
1085 // If not nullptr, the kernel invoke functions defined in the
1086 // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...).
function_library()1087 FunctionLibraryRuntime* function_library() const {
1088 return params_->function_library;
1089 }
1090
runner()1091 std::function<void(std::function<void()>)>* runner() const {
1092 return params_->runner;
1093 }
stats_collector()1094 StepStatsCollectorInterface* stats_collector() const {
1095 return params_->stats_collector;
1096 }
1097
1098 // Shared resources accessible to this kernel.
resource_manager()1099 ResourceMgr* resource_manager() const { return params_->resource_manager; }
1100
slice_reader_cache()1101 checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const {
1102 return params_->slice_reader_cache;
1103 }
1104
1105 // Execution.
1106 //
1107 // OpKernels can use these eigen devices to carry out their
1108 // numerical computation.
eigen_cpu_device()1109 const Eigen::ThreadPoolDevice& eigen_cpu_device() const {
1110 return *device()->eigen_cpu_device();
1111 }
eigen_gpu_device()1112 const Eigen::GpuDevice& eigen_gpu_device() const {
1113 return params_->eigen_gpu_device->device();
1114 }
1115 template <typename EigenDeviceType>
1116 const EigenDeviceType& eigen_device() const;
1117
1118 // Error handling.
1119
1120 // If expected_inputs == inputs() and expected_outputs == output_types(),
1121 // returns OK, else returns INVALID_ARGUMENT with an error message.
1122 // Recommended for Ops with dynamic signatures, where validation can only
1123 // be performed at runtime.
1124 Status MatchSignature(const DataTypeSlice expected_inputs,
1125 const DataTypeSlice expected_outputs);
1126
1127 // An OpKernel should call SetStatus() if Compute() encounters an
1128 // error.
1129 void SetStatus(const Status& status);
status()1130 const Status& status() const { return status_; }
1131
1132 // Cancellation.
1133 //
1134 // EXPERIMENTAL. See the implementation in tensorflow::FIFOQueue for an
1135 // example of how to use this API.
cancellation_manager()1136 CancellationManager* cancellation_manager() const {
1137 return params_->cancellation_manager;
1138 }
1139
1140 // Other accessors.
1141
1142 // For control flow.
frame_iter()1143 FrameAndIter frame_iter() const { return params_->frame_iter; }
is_input_dead()1144 bool is_input_dead() const { return params_->is_input_dead; }
1145
1146 // May be used, e.g., to get GPU handles, etc.
1147 // TODO(tucker): Add example usage.
device()1148 DeviceBase* device() const { return params_->device; }
1149
1150 // Per-step container for use by white-listed internal ops.
step_container()1151 ScopedStepContainer* step_container() const {
1152 return params_->step_container;
1153 }
1154
1155 // Access to distributed coordination service.
coordination_service_agent()1156 CoordinationServiceAgent* coordination_service_agent() const {
1157 return params_->coordination_service_agent;
1158 }
1159
1160 // Helper routines for the OP_REQUIRES macros
1161 void CtxFailure(const Status& s);
1162 void CtxFailureWithWarning(const Status& s);
1163 void CtxFailure(const char* file, int line, const Status& s);
1164 void CtxFailureWithWarning(const char* file, int line, const Status& s);
1165
1166 // Unrecommended functions: these are functions that have some
1167 // current uses but are not recommended for use, and may go away at
1168 // some future major version release.
1169 //
1170 // The following functions all have versions that return Status
1171 // to capture error conditions, and are strongly preferred.
1172 Tensor* mutable_output(int index);
1173 mutex* input_ref_mutex(int index);
1174 void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref);
1175 TensorValue release_output(int index);
1176
track_allocations()1177 bool track_allocations() const { return params_->track_allocations; }
1178
1179 // Records temp memory allocation. Tensor object is recorded to identify the
1180 // case where temp memory is used as output memory.
1181 void record_temp_memory_allocation(int64_t size, const Tensor& t)
1182 TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
1183
1184 // Returns recorded size of temporary memory;
1185 int64_t temp_memory_allocated() const
1186 TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
1187
1188 // Records persistent memory allocation, size can be negative indicating
1189 // deallocation.
1190 void record_persistent_memory_allocation(int64_t size, int64_t alloc_id = -1)
1191 TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
1192
1193 // Returns recorded size and ids of persistent memory.
1194 int64_t persistent_memory_allocated() const
1195 TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
1196
1197 std::vector<int64_t> persistent_alloc_ids() const
1198 TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
1199
1200 // Resets counters for temp and persistent memory and recorded ids.
1201 void clear_recorded_memory() TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);
1202
1203 bool input_is_ref(int index) const;
1204
1205 void set_record_memory_consumption(bool v);
1206
1207 // Used by OpKernel implementations to track actively running deferred ops.
1208 //
1209 // A deferred op is one whose Compute method returns (or whose ComputeAsync
1210 // method invokes the callback) when work is scheduled onto a device. At that
1211 // point, we don't know when the work will actually complete (or if it has
1212 // already completed) on the device. These functions allow the executor to
1213 // track the status of deferred ops and act accordingly.
1214 //
1215 // Deferred OpKernel implementations must use these methods to get two
1216 // functions. It then must call these two functions in pairs, before and after
1217 // device execution, respectively.
inc_num_deferred_ops_function()1218 TF_MUST_USE_RESULT std::function<void()> inc_num_deferred_ops_function() {
1219 DCHECK(params_->op_kernel->is_deferred());
1220 return params_->inc_num_deferred_ops_function
1221 ? params_->inc_num_deferred_ops_function
1222 : []() {};
1223 }
dec_num_deferred_ops_function()1224 TF_MUST_USE_RESULT std::function<void()> dec_num_deferred_ops_function() {
1225 DCHECK(params_->op_kernel->is_deferred());
1226 return params_->dec_num_deferred_ops_function
1227 ? params_->dec_num_deferred_ops_function
1228 : []() {};
1229 }
1230
1231 Allocator* get_allocator(AllocatorAttributes attr);
1232
1233 private:
1234 bool record_memory_consumption_ = false;
1235
1236 // Internal common method used when allocating tensor memory
allocate_tensor(DataType type,const TensorShape & shape,Tensor * out_tensor,AllocatorAttributes allocator_attr)1237 Status allocate_tensor(DataType type, const TensorShape& shape,
1238 Tensor* out_tensor,
1239 AllocatorAttributes allocator_attr) {
1240 return allocate_tensor(type, shape, out_tensor, allocator_attr,
1241 AllocationAttributes());
1242 }
1243
  // Full form of the internal allocation helper, taking explicit
  // AllocationAttributes.
  Status allocate_tensor(DataType type, const TensorShape& shape,
                         Tensor* out_tensor, AllocatorAttributes allocator_attr,
                         const AllocationAttributes& allocation_attr);

  // Helpers for `set_output()`.

  // Returns `true` if the tensor was copied into an allocated output.
  bool maybe_set_output_by_allocate_and_copy(int index, const Tensor& tensor);

  void maybe_track_allocations_for_set_output(const Tensor& tensor);

  // NOTE(review): presumably resolve an OpDef argument name to its
  // input/output index, returning non-OK for unknown names — confirm.
  Status get_input_index(StringPiece name, int* out_index) const;
  Status get_output_index(StringPiece name, int* out_index) const;

  // Initialize the allocated_scope_ids_ set the first time this method is
  // called.
  void maybe_initialize_scope_id_set();

  // The status reported by status().
  Status status_;
  friend class CollectiveExecutor;  // for access to params_
  Params* params_;                  // not owned
  gtl::InlinedVector<TensorValue, 4> outputs_;

  // Keep track of calls to ScopedAllocator.
  // TODO(ayushd): change to absl::flat_hash_set.
  std::unique_ptr<std::unordered_set<int32>> allocated_scope_ids_;
1270
  // The following data members are only used when allocation tracking is
  // enabled, memory consumption is being recorded, or tensor access is being
  // recorded.
  struct TrackingState {
    // Guards wrapped_allocators.
    mutable mutex mu;
    gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators
        TF_GUARDED_BY(mu);

    // Guards all of the memory-accounting fields below.
    mutable mutex stats_mu;
    int64_t temp_memory_allocated TF_GUARDED_BY(stats_mu) = 0;

    int64_t persistent_memory_allocated TF_GUARDED_BY(stats_mu) = 0;
    // NOTE(review): presumably (buffer address, size) per recorded temp
    // allocation, used to spot temp memory reused as output — confirm.
    gtl::InlinedVector<std::pair<const void*, int64_t>, 2>
        temp_tensor_buffer_and_size TF_GUARDED_BY(stats_mu);
    gtl::InlinedVector<int64_t, 2> persistent_alloc_ids TF_GUARDED_BY(stats_mu);
  };
  // May be null; ConsumeWrappedAllocators() checks it before use.
  std::unique_ptr<TrackingState> tracking_state_;

  // For access to `params_->op_kernel`.
  friend void CheckNotInComputeAsync(OpKernelContext* ctx,
                                     const char* correct_macro_name);

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext);
1294 };
1295
// Explicit specializations of OpKernelContext::eigen_device() for the CPU
// thread-pool device and the GPU device (declared here, not defined).
template <>
const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const;

template <>
const Eigen::GpuDevice& OpKernelContext::eigen_device() const;
1301
1302 // Register your OpKernel by specifying the Op's name, the device the
1303 // kernel runs on, any type attr constraints for this kernel, any
1304 // host-memory args, and the class to instantiate. Examples:
1305 //
1306 // // A kernel that supports all types.
1307 // REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp);
1308 //
1309 // // The following are equivalent ways of specifying that the kernel only
1310 // // works if the "T" type attr is set to DT_FLOAT.
1311 // REGISTER_KERNEL_BUILDER(
1312 // Name("Sub").Device(DEVICE_CPU).TypeConstraint<float>("T"),
1313 // SubOp<float>);
1314 // // (You would then repeat this for every type supported by "Sub".)
1315 //
1316 // // This form allows you to specify a list of types as the constraint.
1317 // REGISTER_KERNEL_BUILDER(Name("Sub")
1318 // .Device(DEVICE_CPU)
1319 // .TypeConstraint("T", {DT_FLOAT}),
1320 // SubOp<float>);
1321 //
1322 // // A kernel that expects one of the input tensors in host memory.
1323 // REGISTER_KERNEL_BUILDER(
1324 // Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp);
1325 //
// // A kernel that works on any device. Kernels using DEVICE_DEFAULT
// // must always run on host and all inputs and outputs must use `HostMemory`.
1328 // // Kernels for data management, control-flow primitives or working with
1329 // // tensor shapes for various devices (including `PluggableDevices`) are
1330 // // typical uses.
1331 // REGISTER_KERNEL_BUILDER(
1332 // Name("TensorListLength").Device(DEVICE_DEFAULT).HostMemory("length"),
1333 // TensorListLength);
1334 //
1335 // See kernel_def_builder for details.
1336
// Instantiate an OpKernel that has been registered. Returns nullptr
// if no operation for that type of device / input signature combination
// (and a NOT_FOUND *status), or there is an error in construction (and
// an INVALID_ARGUMENT *status). Otherwise, the caller takes ownership
// of the returned pointer.
// EXPECTED USAGE: unique_ptr<OpKernel> op = CreateOpKernel(...);
// REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
                                         DeviceBase* device,
                                         Allocator* allocator,
                                         const NodeDef& node_def,
                                         int graph_def_version, Status* status);

// As above, but the node is described by shared NodeProperties instead of a
// NodeDef.
std::unique_ptr<OpKernel> CreateOpKernel(
    DeviceType device_type, DeviceBase* device, Allocator* allocator,
    const std::shared_ptr<const NodeProperties>& props, int graph_def_version,
    Status* status);

// Status-returning variants that also take a FunctionLibraryRuntime (and
// optionally a ResourceMgr) and write the new kernel to *kernel.
// NOTE(review): ownership of *kernel presumably passes to the caller —
// confirm against the implementation.
Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      const std::shared_ptr<const NodeProperties>& props,
                      int graph_def_version, OpKernel** kernel);

Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      ResourceMgr* resource_mgr,
                      const std::shared_ptr<const NodeProperties>& props,
                      int graph_def_version, OpKernel** kernel);
1365
// Returns into 'device_types' the subset of prioritized_types that this
// binary has registered for the given NodeDef.
// NOTE(review): local_address_spec is optional (defaults to nullptr); its
// effect on the result is not visible here — confirm in the implementation.
//
// REQUIRES: * 'device_types' is not nullptr.
//           * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
Status SupportedDeviceTypesForNode(
    const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
    PrioritizedDeviceTypeVector* device_types,
    const DeviceNameUtils::ParsedName* local_address_spec = nullptr);

// Returns a message with a description of the kernels registered for op
// `op_name`.
std::string KernelsRegisteredForOp(StringPiece op_name);

// Call once after Op registration has completed.
Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry);
1382
// -----------------------------------------------------------------------------
// OpKernel registration implementation follows, please ignore.

// Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax.
namespace register_kernel {

// A KernelDefBuilder constructed from the op name; the fluent builder calls
// (Device(), TypeConstraint(), ...) are chained onto it.
class Name : public KernelDefBuilder {
 public:
  explicit Name(const char* op);
};

}  // namespace register_kernel
1395
// Kernel registration appears as:
//   REGISTER_KERNEL_BUILDER(Name("OpName").Device(DEVICE_CPU)..., OpImpl)
// We'd like to have "OpName" as a constant-expression, without requiring that
// of the overall KernelDefBuilder expression (beginning with the
// register_kernel::Name constructor above).
//
// So, we pull the "OpName" part to a separate macro-level argument. This
// involves treating Name("OpName") as a macro call, via token-pasting (e.g.
// TF_EXTRACT_KERNEL_NAME_##Name("OpName") =>
// TF_EXTRACT_KERNEL_NAME_Name("OpName")), and having it expand to
// '"OpName", ::tensorflow::register_kernel::Name("OpName")', which is then
// usable as two arguments. Any trailing builder calls (e.g. .Device(...))
// stay attached to the second argument.
#define TF_EXTRACT_KERNEL_NAME_Name(name_str) \
  name_str, ::tensorflow::register_kernel::Name(name_str)
#define TF_EXTRACT_KERNEL_NAME_IMPL(m, ...) m(__VA_ARGS__)
#define TF_EXTRACT_KERNEL_NAME(m, kernel_builder, ...)                    \
  TF_EXTRACT_KERNEL_NAME_IMPL(m, TF_EXTRACT_KERNEL_NAME_##kernel_builder, \
                              __VA_ARGS__)
1412
// REGISTER_KERNEL_BUILDER_IMPL_2, with a unique 'ctr' as the first argument.
// TODO(dodgen): There are some uses of this macro inside functions, where
// kernel_builder refers to (non-const) locals (they should be fixed). To
// accommodate those, kernel_builder.Build() appears as an argument to an
// immediately-called lambda (not in the lambda itself).
//
// Registration runs only if is_system_kernel is true, or if selective
// registration permits both the kernel class (SHOULD_REGISTER_OP_KERNEL) and
// the op (SHOULD_REGISTER_OP). The registered factory constructs the kernel
// with `new __VA_ARGS__(context)`.
#define REGISTER_KERNEL_BUILDER_IMPL_3(ctr, op_name, kernel_builder_expr,   \
                                       is_system_kernel, ...)               \
  static ::tensorflow::InitOnStartupMarker const register_kernel_##ctr      \
      TF_ATTRIBUTE_UNUSED =                                                 \
          TF_INIT_ON_STARTUP_IF(is_system_kernel ||                         \
                                (SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__) && \
                                 SHOULD_REGISTER_OP(op_name)))              \
          << ([](::tensorflow::KernelDef const* kernel_def) {               \
               ::tensorflow::kernel_factory::OpKernelRegistrar registrar(   \
                   kernel_def, #__VA_ARGS__,                                \
                   [](::tensorflow::OpKernelConstruction* context)          \
                       -> ::tensorflow::OpKernel* {                         \
                     return new __VA_ARGS__(context);                       \
                   });                                                      \
               (void)registrar;                                             \
               return ::tensorflow::InitOnStartupMarker{};                  \
             })(kernel_builder_expr.Build());

// REGISTER_KERNEL_BUILDER_IMPL, but with kernel_builder split to op_name,
// kernel_builder_expr. TF_NEW_ID_FOR_INIT supplies the unique 'ctr'.
#define REGISTER_KERNEL_BUILDER_IMPL_2(op_name, kernel_builder_expr, \
                                       is_system_kernel, ...)        \
  TF_NEW_ID_FOR_INIT(REGISTER_KERNEL_BUILDER_IMPL_3, op_name,        \
                     kernel_builder_expr, is_system_kernel, __VA_ARGS__)

// REGISTER_KERNEL_BUILDER, but with is_system_kernel bound.
#define REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, is_system_kernel, ...) \
  TF_EXTRACT_KERNEL_NAME(REGISTER_KERNEL_BUILDER_IMPL_2, kernel_builder,    \
                         is_system_kernel, __VA_ARGS__)

// Public registration macro (is_system_kernel = false), annotated with
// "tf:kernel".
#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
  TF_ATTRIBUTE_ANNOTATE("tf:kernel")                 \
  REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, false, __VA_ARGS__)

// The `REGISTER_SYSTEM_KERNEL_BUILDER()` macro acts as
// `REGISTER_KERNEL_BUILDER()` except that the kernel is registered
// unconditionally even when selective registration is used.
#define REGISTER_SYSTEM_KERNEL_BUILDER(kernel_builder, ...) \
  TF_ATTRIBUTE_ANNOTATE("tf:kernel")                        \
  TF_ATTRIBUTE_ANNOTATE("tf:kernel:system")                 \
  REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, true, __VA_ARGS__)
1459
// Checks whether a kernel matching `node_def` is registered on `device_type`.
bool KernelDefAvailable(const DeviceType& device_type, const NodeDef& node_def);

// Looks up the kernel registered on `device_type` for a node described by its
// individual parts (node_name, experimental_debug_info, node_op, node_device
// and node_attrs) rather than by a NodeDef. If a matching kernel exists,
// returns OK and fills in `*def` and `*kernel_class_name`. Either output
// pointer may be null when the caller does not need that result. The status
// returned when nothing matches is not visible here; presumably an error
// status (TODO confirm in op_kernel.cc). `has_experimental_debug_info`
// presumably says whether `experimental_debug_info` is populated (TODO
// confirm).
Status FindKernelDef(
    const DeviceType& device_type, StringPiece node_name,
    bool has_experimental_debug_info,
    const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
    StringPiece node_op, StringPiece node_device, AttrSlice node_attrs,
    const KernelDef** def, std::string* kernel_class_name);

// Same lookup as above, taking the node as a NodeDef. If `node_def` has a
// corresponding kernel registered on `device_type`, returns OK and fills in
// `*def` and `*kernel_class_name`. Either output pointer may be null.
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
                     const KernelDef** def, std::string* kernel_class_name);
1479
// Writes a list of all registered kernels to LOG(INFO), to help users debug
// missing-kernel errors.
void LogAllRegisteredKernels();

// Returns a list of all registered kernels.
KernelList GetAllRegisteredKernels();

// Returns a list of all registered kernels for which `predicate` returns true.
KernelList GetFilteredRegisteredKernels(
    const std::function<bool(const KernelDef&)>& predicate);

// Returns a list of all registered kernels for the op named `op_name`.
KernelList GetRegisteredKernelsForOp(StringPiece op_name);
1493
namespace kernel_factory {

// OpKernelFactory is responsible for creating OpKernels when TensorFlow needs
// them. You register factories with the TensorFlow core by constructing an
// OpKernelRegistrar and passing the factory as a constructor parameter.
class OpKernelFactory {
 public:
  // Creates a kernel for the node described by `context`. The factory lambda
  // generated by REGISTER_KERNEL_BUILDER returns a freshly `new`-ed kernel;
  // presumably the caller takes ownership of the result (TODO confirm).
  virtual OpKernel* Create(OpKernelConstruction* context) = 0;
  virtual ~OpKernelFactory() = default;
};

// Registration happens as a side effect of construction: both constructors
// forward to InitInternal(). REGISTER_KERNEL_BUILDER builds one of these inside
// a TF_INIT_ON_STARTUP_IF initializer.
class OpKernelRegistrar {
 public:
  // Registers the given kernel factory with TensorFlow. TF will call the
  // factory Create() method when it determines that a kernel matching the given
  // KernelDef is required.
  OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    std::unique_ptr<OpKernelFactory> factory)
      TF_ATTRIBUTE_NOINLINE {
    InitInternal(kernel_def, kernel_class_name, std::move(factory));
  }

  // Registers the given factory function with TensorFlow. This is equivalent
  // to registering a factory whose Create function invokes `create_fn`. The
  // captureless lambda generated by REGISTER_KERNEL_BUILDER converts implicitly
  // to this function-pointer type.
  OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    OpKernel* (*create_fn)(OpKernelConstruction*))
      TF_ATTRIBUTE_NOINLINE {
    InitInternal(kernel_def, kernel_class_name,
                 absl::make_unique<PtrOpKernelFactory>(create_fn));
  }

 private:
  // Adapts a plain function pointer to the OpKernelFactory interface.
  // Create() is only declared here; its definition lives out of line
  // (presumably in op_kernel.cc).
  struct PtrOpKernelFactory : public OpKernelFactory {
    explicit PtrOpKernelFactory(OpKernel* (*create_func)(OpKernelConstruction*))
        : create_func_(create_func) {}

    OpKernel* Create(OpKernelConstruction* context) override;

    // The function pointer supplied at construction.
    OpKernel* (*create_func_)(OpKernelConstruction*);
  };

  // Shared implementation of both constructors. It is defined out of line and
  // presumably inserts the factory into the global kernel registry (TODO
  // confirm).
  void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    std::unique_ptr<OpKernelFactory> factory);
};

}  // namespace kernel_factory
1540
1541 // -----------------------------------------------------------------------------
1542 // Template and inline method implementations, please ignore
1543
1544 template <class T>
GetAttr(StringPiece attr_name,T * value)1545 Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const {
1546 return GetNodeAttr(def(), attr_name, value);
1547 }
1548
input_dtype(int index)1549 inline DataType OpKernelContext::input_dtype(int index) const {
1550 DCHECK_GE(index, 0);
1551 DCHECK_LT(index, num_inputs());
1552 const TensorValue& value(params_->inputs[index]);
1553 return value.dtype();
1554 }
1555
input_memory_type(int index)1556 inline MemoryType OpKernelContext::input_memory_type(int index) const {
1557 DCHECK_GE(index, 0);
1558 DCHECK_LT(index, num_inputs());
1559 return op_kernel().input_memory_types()[index];
1560 }
1561
expected_output_dtype(int index)1562 inline DataType OpKernelContext::expected_output_dtype(int index) const {
1563 DCHECK_GE(index, 0);
1564 DCHECK_LT(index, num_outputs());
1565 return params_->op_kernel->output_type(index);
1566 }
1567
output_memory_type(int index)1568 inline MemoryType OpKernelContext::output_memory_type(int index) const {
1569 DCHECK_GE(index, 0);
1570 DCHECK_LT(index, num_outputs());
1571 return op_kernel().output_memory_types()[index];
1572 }
1573
input_is_ref(int index)1574 inline bool OpKernelContext::input_is_ref(int index) const {
1575 const TensorValue& value(params_->inputs[index]);
1576 return value.is_ref();
1577 }
1578
1579 // no input if tensor == nullptr.
has_input(int index)1580 inline bool OpKernelContext::has_input(int index) const {
1581 DCHECK_GE(index, 0);
1582 DCHECK_LT(index, num_inputs());
1583 return params_->inputs[index].tensor != nullptr;
1584 }
1585
input_ref_mutex(int index)1586 inline mutex* OpKernelContext::input_ref_mutex(int index) {
1587 DCHECK_GE(index, 0);
1588 DCHECK_LT(index, num_inputs());
1589 DCHECK(input_is_ref(index));
1590 return params_->inputs[index].mutex_if_ref;
1591 }
1592
mutable_output(int index)1593 inline Tensor* OpKernelContext::mutable_output(int index) {
1594 DCHECK_GE(index, 0);
1595 DCHECK_LT(index, num_outputs());
1596 return outputs_[index].tensor;
1597 }
1598
release_output(int index)1599 inline TensorValue OpKernelContext::release_output(int index) {
1600 DCHECK_GE(index, 0);
1601 DCHECK_LT(index, num_outputs());
1602 TensorValue value = outputs_[index];
1603 outputs_[index] = TensorValue();
1604 return value;
1605 }
1606
1607 template <typename T>
op_device_context()1608 T* OpKernelContext::op_device_context() {
1609 static_assert(std::is_base_of<DeviceContext, T>::value,
1610 "T is not a subclass of DeviceContext");
1611 return static_cast<T*>(op_device_context());
1612 }
1613
1614 inline const Tensor& OpInputList::operator[](int i) const {
1615 DCHECK_GE(i, 0);
1616 DCHECK_LT(i, stop_ - start_);
1617 return ctx_->input(start_ + i);
1618 }
1619
ref_mutex(int i)1620 inline mutex* OpMutableInputList::ref_mutex(int i) {
1621 DCHECK_GE(i, 0);
1622 DCHECK_LT(i, stop_ - start_);
1623 return ctx_->input_ref_mutex(start_ + i);
1624 }
1625
at(int i,bool lock_held)1626 inline Tensor OpMutableInputList::at(int i, bool lock_held) {
1627 DCHECK_GE(i, 0);
1628 DCHECK_LT(i, stop_ - start_);
1629 return ctx_->mutable_input(start_ + i, lock_held);
1630 }
1631
1632 inline Tensor* OpOutputList::operator[](int i) {
1633 DCHECK_GE(i, 0);
1634 DCHECK_LT(i, stop_ - start_);
1635 return ctx_->mutable_output(start_ + i);
1636 }
1637
required(int i)1638 inline bool OpOutputList::required(int i) const {
1639 DCHECK_GE(i, 0);
1640 DCHECK_LT(i, stop_ - start_);
1641 return ctx_->output_required(start_ + i);
1642 }
1643
expected_output_dtype(int i)1644 inline DataType OpOutputList::expected_output_dtype(int i) const {
1645 DCHECK_GE(i, 0);
1646 DCHECK_LT(i, stop_ - start_);
1647 return ctx_->expected_output_dtype(start_ + i);
1648 }
1649
allocate(int i,const TensorShape & shape,Tensor ** output)1650 inline Status OpOutputList::allocate(int i, const TensorShape& shape,
1651 Tensor** output) {
1652 DCHECK_GE(i, 0);
1653 DCHECK_LT(i, stop_ - start_);
1654 return ctx_->allocate_output(start_ + i, shape, output);
1655 }
1656
set(int i,const Tensor & tensor)1657 inline void OpOutputList::set(int i, const Tensor& tensor) {
1658 DCHECK_GE(i, 0);
1659 DCHECK_LT(i, stop_ - start_);
1660 ctx_->set_output(start_ + i, tensor);
1661 }
1662
set(int i,Tensor && tensor)1663 inline void OpOutputList::set(int i, Tensor&& tensor) {
1664 DCHECK_GE(i, 0);
1665 DCHECK_LT(i, stop_ - start_);
1666 ctx_->set_output(start_ + i, std::move(tensor));
1667 }
1668
set_ref(int i,mutex * mu,Tensor * tensor_for_ref)1669 inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
1670 DCHECK_GE(i, 0);
1671 DCHECK_LT(i, stop_ - start_);
1672 ctx_->set_output_ref(i, mu, tensor_for_ref);
1673 }
1674
// Generate a fatal error if OP_REQUIRES or OP_REQUIRES_OK are used in
// AsyncOpKernel implementations. If these macros are used and the condition
// does not hold, the `done` callback will never be called and the system will
// deadlock, so a crash failure is preferable. Since the OP_REQUIRES[_OK] macros
// are legal to use in AsyncOpKernel constructors, we use overload resolution
// to distinguish between OpKernelConstruction* and OpKernelContext* context
// types.
//
// The XlaOpKernelContext* and OpKernelConstruction* overloads have empty
// bodies, so the check is a no-op for those context types. Only the
// OpKernelContext* overload, declared here and defined out of line, can act.
// Its `correct_macro_name` presumably names the async-safe macro to suggest in
// the error message (TODO confirm in op_kernel.cc).
class XlaOpKernelContext;  // Forward declaration: only the pointer type is used.
inline void CheckNotInComputeAsync(XlaOpKernelContext*, const char*) {}
inline void CheckNotInComputeAsync(OpKernelConstruction*, const char*) {}
void CheckNotInComputeAsync(OpKernelContext* ctx,
                            const char* correct_macro_name);
1687
1688 } // namespace tensorflow
1689
1690 #endif // TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
1691