/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
#define TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_

#include <functional>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/node_properties.h"
#include "tensorflow/core/framework/op.h"  // TODO(b/62899350): Remove
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"  // TODO(b/62899350): Remove
#include "tensorflow/core/framework/tracking_allocator.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/manual_constructor.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/managed_stack_trace.h"

namespace Eigen {
struct ThreadPoolDevice;
struct GpuDevice;
}  // end namespace Eigen

namespace tensorflow {

namespace checkpoint {
class TensorSliceReaderCacheWrapper;
}  // namespace checkpoint

class AsyncOpKernel;
class CallFrameInterface;
class DeviceMgr;
class FunctionLibraryRuntime;
class OpKernelConstruction;  // declared below
class OpKernelContext;       // declared below
class OpRegistryInterface;
class ResourceMgr;
class ScopedStepContainer;
class CollectiveExecutor;
class StepStatsCollectorInterface;
class OpKernel {
 public:
  // OpKernel won't be instantiated by the scheduler, so you may perform
  // expensive initialization in the descendant's constructor.
  explicit OpKernel(OpKernelConstruction* context);

  // Specialized constructor that allows a kernel implementation to mark itself
  // as a "deferred" op. If true, the executor will provide access to the
  // `OpKernelContext::inc_num_deferred_ops_function()` and
  // `OpKernelContext::dec_num_deferred_ops_function()` methods at run-time.
  OpKernel(OpKernelConstruction* context, bool is_deferred);

  // Specialized constructor that enables the descendant to provide a custom
  // `NodeDef` value. For example, this constructor can be used to provide a
  // stripped-down `NodeDef` that does not contain the full set of attrs (such
  // as tensor values) if the descendant stores them in a different form.
  OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
           bool is_deferred);

  virtual ~OpKernel();

  // An OpKernel's computation can be either synchronous or
  // asynchronous. All OpKernel Compute() methods must be thread-safe as they
  // may be called concurrently (e.g. by multiple concurrent executions of the
  // same graph).
  //
  // Most OpKernels should compute synchronously. They should
  // subclass OpKernel and override the Compute() method and have it
  // return after completing the supplied work.
  //
  // A synchronous OpKernel *MUST NOT* block the calling thread on a
  // synchronization mechanism (condition variable, Notification, etc.) that
  // will be unblocked by the execution of another OpKernel. Execution may
  // deadlock in that case, because the executor may use a bounded number of
  // threads.
  //
  // If an OpKernel must block on the execution of another OpKernel (e.g. a
  // RecvOp, or a DequeueOp), the implementation *MUST* subclass AsyncOpKernel,
  // and override `AsyncOpKernel::ComputeAsync()`. In addition, because the
  // unblocking kernel may never run (due to an error or cancellation), in most
  // cases the AsyncOpKernel should implement cancellation support via
  // `ctx->cancellation_manager()`.
  //
  // In both cases, implementations of Compute() and ComputeAsync()
  // get inputs and write outputs through the given OpKernelContext
  // and return a status via context->SetStatus(). They must be
  // thread-safe.
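  //
  // For illustration, a minimal synchronous kernel might look like the
  // following sketch. The class name and element type are hypothetical;
  // error handling uses the OP_REQUIRES_OK macro from op_requires.h.
  //
  //   class MyAddOneOp : public OpKernel {
  //    public:
  //     explicit MyAddOneOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  //
  //     void Compute(OpKernelContext* ctx) override {
  //       const Tensor& input = ctx->input(0);
  //       Tensor* output = nullptr;
  //       OP_REQUIRES_OK(ctx,
  //                      ctx->allocate_output(0, input.shape(), &output));
  //       auto in = input.flat<float>();
  //       auto out = output->flat<float>();
  //       for (int64 i = 0; i < in.size(); ++i) out(i) = in(i) + 1.0f;
  //     }
  //   };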

  // Synchronous compute.
  //
  // "context" is guaranteed to be alive until Compute() returns.
  virtual void Compute(OpKernelContext* context) = 0;

  // Returns nullptr iff this op kernel is synchronous.
  virtual AsyncOpKernel* AsAsync() { return nullptr; }

  // Returns true iff this op kernel is considered "expensive". The
  // runtime may use this flag to optimize graph execution, for example
  // to "inline" inexpensive kernels.
  virtual bool IsExpensive() { return expensive_; }

  // Returns a pointer to the tensor stored inside constant ops.
  virtual const Tensor* const_tensor() const { return nullptr; }

  // Accessors.
  const NodeDef& def() const { return props_->node_def; }
  const std::string& name() const { return props_->node_def.name(); }
  absl::string_view name_view() const { return name_view_; }
  const std::string& type_string() const { return props_->node_def.op(); }
  absl::string_view type_string_view() const { return type_string_view_; }
  const std::string& requested_input(int i) const {
    return props_->node_def.input(i);
  }
  const std::string& requested_device() const {
    return props_->node_def.device();
  }

  int num_inputs() const { return props_->input_types.size(); }
  DataType input_type(int i) const { return props_->input_types[i]; }
  const DataTypeVector& input_types() const { return props_->input_types; }
  const MemoryTypeVector& input_memory_types() const {
    return input_memory_types_;
  }

  int num_outputs() const { return props_->output_types.size(); }
  DataType output_type(int o) const { return props_->output_types[o]; }
  const DataTypeVector& output_types() const { return props_->output_types; }
  const MemoryTypeVector& output_memory_types() const {
    return output_memory_types_;
  }

  Status InputRange(StringPiece input_name, int* start, int* stop) const;
  Status OutputRange(StringPiece output_name, int* start, int* stop) const;

  // Returns `true` if and only if this kernel uses deferred execution.
  bool is_deferred() const { return is_deferred_; }

  // Returns a trace string for the current computation; the op name/type and
  // input tensor shape/dtype are encoded for profiler cost analysis. Most
  // OpKernels should use the default implementation.
  virtual std::string TraceString(const OpKernelContext& ctx,
                                  bool verbose) const;

 protected:
  std::string ShapeTraceString(const OpKernelContext& ctx) const;

 private:
  const std::shared_ptr<const NodeProperties> props_;
  const MemoryTypeVector input_memory_types_;
  const MemoryTypeVector output_memory_types_;
  NameRangeMap input_name_map_;
  NameRangeMap output_name_map_;
  const absl::string_view name_view_;
  const absl::string_view type_string_view_;
  const int graph_def_version_;
  const bool is_deferred_;
  bool expensive_;

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernel);
};

class AsyncOpKernel : public OpKernel {
 public:
  using OpKernel::OpKernel;  // Lift OpKernel constructors.

  // Asynchronous compute.
  //
  // Implementations of ComputeAsync() must ensure that `done` is (eventually)
  // called exactly once to signal the completion of the computation. The
  // implementation of ComputeAsync() must not block on the execution of
  // another OpKernel. `done` may be called by the current thread, or by
  // another thread. `context` is guaranteed to stay alive until the `done`
  // callback starts.
  //
  // Since it is possible that the unblocking kernel may never run (due to an
  // error or cancellation), in most cases the AsyncOpKernel should implement
  // cancellation support via `context->cancellation_manager()`.
  //
  // WARNING: As soon as the `done` callback starts, `context` and `this` may
  // be deleted. No code depending on these objects should execute after the
  // call to `done`.
  typedef std::function<void()> DoneCallback;
  virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0;
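
  // For illustration, a typical ComputeAsync() override might look like the
  // sketch below. The scheduling mechanism is hypothetical (here the work is
  // handed to the context's `runner()`); the key points are that the calling
  // thread is never blocked and that `done` is invoked exactly once.
  //
  //   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
  //     (*ctx->runner())([ctx, done]() {
  //       Tensor* out = nullptr;
  //       OP_REQUIRES_OK_ASYNC(
  //           ctx, ctx->allocate_output(0, TensorShape({}), &out), done);
  //       out->scalar<int64>()() = 42;
  //       done();  // Must be called exactly once.
  //     });
  //   }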

  AsyncOpKernel* AsAsync() override { return this; }

  void Compute(OpKernelContext* context) override;
};

// Wraps a tensor that is held by an Op across calls to Compute(). For memory
// safety when using asynchronous devices like GPUs, the system must be
// notified when a Tensor is used inside an Op execution. The wrapper ensures
// that all uses of the Tensor are tracked, because in order to retrieve the
// Tensor the caller must use AccessTensor which notifies the context.
class PersistentTensor {
 public:
  PersistentTensor() {}
  explicit PersistentTensor(const Tensor& tensor) : tensor_(tensor) {}

  // Caller does not own the returned Tensor*.
  Tensor* AccessTensor(OpKernelConstruction* context);
  // Caller does not own the returned Tensor*.
  Tensor* AccessTensor(OpKernelContext* context);

  // The check for initialization does not need to access the
  // underlying tensor buffer.
  bool IsInitialized() const { return tensor_.IsInitialized(); }

  int64 NumElements() const { return tensor_.NumElements(); }

  int64 AllocatedBytes() const { return tensor_.AllocatedBytes(); }

 private:
  Tensor tensor_;
};

class OpKernelConstruction {
 public:
  OpKernelConstruction(DeviceType device_type, DeviceBase* device,
                       Allocator* allocator, FunctionLibraryRuntime* flib,
                       ResourceMgr* resource_mgr,
                       const std::shared_ptr<const NodeProperties>& props,
                       const MemoryTypeSlice& input_memory_types,
                       const MemoryTypeSlice& output_memory_types,
                       int graph_def_version, Status* status);

  Env* env() const { return device_->env(); }

  // Allocation of tensors during kernel construction:
  //
  // It is legal to temporarily allocate scratch tensor storage during
  // Op kernel construction. Scratch tensors should be allocated using
  // allocate_temp below. Some kernels need to keep tensors in between
  // invocations. If such a Tensor is allocated during kernel
  // construction this must be done using allocate_persistent, and the
  // Op may only store the returned PersistentTensor object. When the
  // Tensor is needed in a subsequent invocation, it can be retrieved
  // from the PersistentTensor using the AccessTensor method. This
  // ensures that the system is made aware of any use of the tensor's
  // allocated memory, which is needed for correctness on asynchronous
  // devices such as GPUs.

  // Allocates a temporary Tensor of the specified type and shape. The
  // Tensor must not be used after kernel construction is
  // complete. See comment above.
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp);
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp, AllocatorAttributes allocator_attr);

  // Allocates a Tensor of the specified type and shape which the Op
  // plans to maintain as persistent state. out_persistent holds the
  // PersistentTensor which is the object the caller should store. For
  // convenience, if out_tensor is non-null then it will be filled in
  // with a Tensor* pointing to the newly-allocated tensor which the
  // caller can use instead of calling
  // out_persistent->AccessTensor. The caller does not own out_tensor
  // and should not keep a copy of it. See comment above.
  Status allocate_persistent(DataType type, const TensorShape& shape,
                             PersistentTensor* out_persistent,
                             Tensor** out_tensor);
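
  // For example, a kernel that builds a lookup table once at construction
  // time and reuses it across invocations might do the following (sketch;
  // the member names and the size `kTableSize` are hypothetical):
  //
  //   Tensor* table = nullptr;
  //   OP_REQUIRES_OK(context,
  //                  context->allocate_persistent(
  //                      DT_INT64, TensorShape({kTableSize}),
  //                      &table_persistent_, &table));
  //   // ... fill `table`, keep only `table_persistent_` as a member ...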

  // User-supplied configuration of this operation.
  const NodeDef& def() const { return props_->node_def; }

  // For inspecting the inputs to this operation.
  int num_inputs() const { return props_->input_types.size(); }
  DataType input_type(int i) const { return props_->input_types[i]; }
  const DataTypeSlice& input_types() const { return props_->input_types_slice; }
  const MemoryTypeSlice& input_memory_types() const {
    return input_memory_types_;
  }

  // For inspecting the outputs expected from this operation.
  int num_outputs() const { return props_->output_types.size(); }
  DataType output_type(int i) const { return props_->output_types[i]; }
  const DataTypeSlice& output_types() const {
    return props_->output_types_slice;
  }
  const MemoryTypeSlice& output_memory_types() const {
    return output_memory_types_;
  }

  // If expected_inputs == inputs() and expected_outputs == output_types(),
  // returns OK, else returns INVALID_ARGUMENT with an error message.
  // Recommended for Ops with dynamic signatures.
  Status MatchSignature(const DataTypeSlice expected_inputs,
                        const DataTypeSlice expected_outputs);

  // For recording configuration errors during construction.
  void SetStatus(const Status& status);
  const Status& status() const { return *status_; }

  // Look up the attr with name attr_name and set *value to its value. If no
  // attr with attr_name is found in def(), or the attr does not have
  // a matching type, a non-ok status will be returned.
  template <class T>
  Status GetAttr(StringPiece attr_name, T* value) const;
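
  // For example, in a kernel constructor (the attr name "axis" and the
  // member `axis_` are hypothetical):
  //
  //   explicit MyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  //     OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
  //   }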

  // Return true if the attr_name is defined in def().
  bool HasAttr(StringPiece attr_name) const;

  // Return the device type.
  const DeviceType& device_type() const { return device_type_; }

  // If not nullptr, the kernel can instantiate functions defined in
  // the library. E.g.,
  // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...).
  FunctionLibraryRuntime* function_library() const { return flib_; }

  // Shared resources accessible to this kernel.
  ResourceMgr* resource_manager() const { return resource_mgr_; }

  // The GraphDef version whose behavior we should follow.
  int graph_def_version() const { return graph_def_version_; }

  // Helper routines for the OP_REQUIRES macros
  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  // Unrecommended functions: these are functions that have some
  // current uses but are not recommended for use, and may go away at
  // some future major version release.

  // May be used, e.g., to get GPU handles, etc.
  //
  // Currently only used to call MakeTensorFromProto() for
  // implementing ConstantOp for every device. See comments
  // on Device::MakeTensorFromProto for longer-term replacement
  // ideas.
  DeviceBase* device() const { return device_; }

 private:
  const DeviceType device_type_;
  DeviceBase* const device_;
  Allocator* allocator_;
  FunctionLibraryRuntime* flib_;
  ResourceMgr* const resource_mgr_;
  std::shared_ptr<const NodeProperties> props_;
  MemoryTypeSlice input_memory_types_;
  MemoryTypeSlice output_memory_types_;
  const int graph_def_version_;
  Status* status_;

  // Allow access from OpKernel ctor.
  friend class OpKernel;

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction);
};

// TODO(mrry): Consider converting to a random_access_iterator, and upgrading
// tensorflow::gtl::iterator_range to make the below container classes
// unnecessary.
template <typename ListType, typename ElementType>
class OpArgIterator {
 public:
  using iterator_category = std::forward_iterator_tag;
  using value_type = ElementType;
  using pointer = ElementType*;
  using const_pointer = const ElementType*;
  using reference = ElementType&;
  using const_reference = const ElementType&;
  using difference_type = ptrdiff_t;

  OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {}

  bool operator==(const OpArgIterator& rhs) {
    DCHECK(list_ == rhs.list_);
    return i_ == rhs.i_;
  }

  bool operator!=(const OpArgIterator& rhs) {
    DCHECK(list_ == rhs.list_);
    return i_ != rhs.i_;
  }

  OpArgIterator operator++() {  // prefix ++it
    ++i_;
    return *this;
  }

  OpArgIterator operator++(int) {  // postfix it++
    OpArgIterator old_value = *this;
    ++i_;
    return old_value;
  }

  reference operator*() { return (*list_)[i_]; }
  pointer operator->() { return &(*list_)[i_]; }

  const_reference operator*() const { return (*list_)[i_]; }
  const_pointer operator->() const { return &(*list_)[i_]; }

 private:
  const ListType* const list_;
  int i_;
};

// Utility class for representing a list of immutable input tensors
// that are passed to the op as a single named argument.
class OpInputList {
 public:
  typedef OpArgIterator<OpInputList, const Tensor> Iterator;
  OpInputList() : ctx_(nullptr), start_(0), stop_(0) {}
  OpInputList(OpKernelContext* ctx, int start, int stop)
      : ctx_(ctx), start_(start), stop_(stop) {}
  OpInputList& operator=(const OpInputList& other) = default;
  const Tensor& operator[](int i) const;
  int size() const { return stop_ - start_; }
  Iterator begin() const { return Iterator(this, 0); }
  Iterator end() const { return Iterator(this, size()); }

 private:
  OpKernelContext* ctx_;  // not owned
  int start_;
  int stop_;
};

// Utility class for representing a list of mutable ("ref") input tensors
// that are passed to the op as a single named argument.
class OpMutableInputList {
 public:
  typedef OpArgIterator<OpMutableInputList, Tensor*> Iterator;
  OpMutableInputList(OpKernelContext* ctx, int start, int stop)
      : ctx_(ctx), start_(start), stop_(stop) {}
  OpMutableInputList() : ctx_(nullptr), start_(0), stop_(0) {}
  OpMutableInputList& operator=(const OpMutableInputList& other) = default;
  Tensor at(int i, bool lock_held);
  mutex* ref_mutex(int i);
  int size() const { return stop_ - start_; }
  Iterator begin() const { return Iterator(this, 0); }
  Iterator end() const { return Iterator(this, size()); }

 private:
  OpKernelContext* ctx_;  // not owned
  int start_;
  int stop_;
};

// Utility class for representing a list of output tensors that are
// grouped as a single named output.
class OpOutputList {
 public:
  typedef OpArgIterator<OpOutputList, const Tensor*> Iterator;
  OpOutputList() : ctx_(nullptr), start_(0), stop_(0) {}
  OpOutputList(OpKernelContext* ctx, int start, int stop)
      : ctx_(ctx), start_(start), stop_(stop) {}
  OpOutputList& operator=(const OpOutputList& other) = default;
  Tensor* operator[](int i);
  bool required(int i) const;
  DataType expected_output_dtype(int i) const;
  Status allocate(int i, const TensorShape& shape, Tensor** output);
  void set(int i, const Tensor& tensor);
  void set(int i, Tensor&& tensor);
  void set_ref(int i, mutex* mu, Tensor* tensor_for_ref);
  int size() const { return stop_ - start_; }
  Iterator begin() const { return Iterator(this, 0); }
  Iterator end() const { return Iterator(this, size()); }

 private:
  OpKernelContext* ctx_;  // not owned
  int start_;
  int stop_;
};

// Holds a tensor or tensor reference. For tensor references, we need
// a mutex to prevent concurrent access to the tensor.
struct TensorValue {
  TensorValue() : mutex_if_ref(nullptr), tensor(nullptr) {}
  explicit TensorValue(Tensor* t) : mutex_if_ref(nullptr), tensor(t) {}
  TensorValue(mutex* mu, Tensor* t) : mutex_if_ref(mu), tensor(t) {}
  Tensor* operator->() const { return tensor; }
  bool is_ref() const { return mutex_if_ref != nullptr; }

  // Return the dtype of the Tensor. For references, return the underlying
  // type.
  DataType dtype() const {
    if (is_ref()) {
      return MakeRefType(tensor->dtype());
    } else {
      return tensor->dtype();
    }
  }

  // Return the dtype of the Tensor. For references, return the underlying
  // type. This variant of dtype() acquires the lock for references.
  //
  // TODO(b/133843385): Disallow dtype modifications
  DataType dtype_safe() const {
    if (is_ref()) {
      tf_shared_lock ml(*mutex_if_ref);
      return MakeRefType(tensor->dtype());
    } else {
      return tensor->dtype();
    }
  }

  mutex* mutex_if_ref;  // nullptr if not a ref, != nullptr if a ref
  Tensor* tensor;
};

// Used to store partitioned graphs from function-calling ops.
struct GraphCollector {
  mutex mu;
  std::vector<GraphDef> partitioned_graphs TF_GUARDED_BY(mu);
  GraphDef raw_graph TF_GUARDED_BY(mu);
  GraphDef optimized_graph TF_GUARDED_BY(mu);

  bool dirty TF_GUARDED_BY(mu);

  GraphCollector() : dirty(false) {}

  void CollectRawGraph(const GraphDef& graph) {
    mutex_lock ml(mu);
    raw_graph.MergeFrom(graph);
    dirty = true;
  }

  void CollectOptimizedGraph(const GraphDef& graph) {
    mutex_lock ml(mu);
    optimized_graph.MergeFrom(graph);
    dirty = true;
  }

  void CollectPartitionedGraph(const GraphDef& graph) {
    mutex_lock ml(mu);
    partitioned_graphs.push_back(graph);
    dirty = true;
  }

  void ClearGraphs() TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
    raw_graph.Clear();
    optimized_graph.Clear();
    partitioned_graphs.clear();
    dirty = false;
  }

  bool HasUpdatedGraphs() {
    mutex_lock ml(mu);
    return dirty;
  }
};

class OpKernelContext {
 public:
  // The first element of a WrappedAllocator is a "base" Allocator and
  // the second element is that Allocator wrapped by a
  // TrackingAllocator
  typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator;

  // TODO(zhifengc): Do some cleanup of Params.
  // The Params struct is passed in to initialize an OpKernelContext,
  // and must outlive the OpKernelContext.
  struct Params {
    ~Params() { delete eigen_gpu_device; }

    // The step being executed.
    int64 step_id = 0;

    // The op kernel being computed.
    OpKernel* op_kernel = nullptr;

    // The device on which the kernel is running.
    DeviceBase* device = nullptr;

    // The Eigen GPU device wrapper, which may include a per-op
    // wrapped allocator. The concrete type of this object depends on
    // the type of this->device, so eigen_gpu_device can't be an
    // inline member and must be heap allocated. However, we don't
    // want to allocate a new eigen_gpu_device for every Op that is
    // executed. Instead this member is allocated on first use using
    // ensure_eigen_gpu_device, and then if the Params structure is
    // re-used for subsequent Ops, the eigen_gpu_device is
    // ReInitialized in the OpKernelContext constructor. Unlike the
    // other pointers in Params, this one is owned by Params.
    PerOpGpuDevice* eigen_gpu_device = nullptr;

    inline void ensure_eigen_gpu_device() {
      DCHECK(device);
      if (nullptr == eigen_gpu_device) {
        // Surprisingly, MakeGpuDevice will return nullptr if the
        // device is not a GPU device. This is ok, since those devices
        // will never use eigen_gpu_device. It seems better to have
        // ensure_eigen_gpu_device fall through and regenerate the
        // nullptr every time an OpKernelContext is instantiated, than
        // to do an unnecessary allocation of a dummy eigen GPU
        // device for CPU device Ops.
        eigen_gpu_device = device->MakeGpuDevice();
      }
    }

    bool track_allocations = false;
    bool log_memory = false;

    // Array indexed by output number for this node
    const AllocatorAttributes* output_attr_array = nullptr;

    // Shared resources accessible by this op kernel invocation.
    ResourceMgr* resource_manager = nullptr;

    // Per-step resources accessible by this op kernel invocation should be
    // stored in this container.
    ScopedStepContainer* step_container = nullptr;

    // Mechanism used by this op kernel invocation to communicate with
    // computations running on other devices.
    RendezvousInterface* rendezvous = nullptr;

    // Mechanism for executing a collective op that needs to coordinate
    // with parallel instances running on other devices.
    CollectiveExecutor* collective_executor = nullptr;

    // The session state for this op.
    SessionState* session_state = nullptr;

    // Unique session identifier. Can be empty.
    std::string session_handle;

    // Metadata about the session. Can be nullptr.
    const SessionMetadata* session_metadata = nullptr;

    // The tensor store for this op.
    TensorStore* tensor_store = nullptr;

    // Mechanism used by this op kernel invocation to register a callback
    // for its cancellation.
    CancellationManager* cancellation_manager = nullptr;

    // Inputs to this op kernel.
    const gtl::InlinedVector<TensorValue, 4>* inputs = nullptr;
    bool is_input_dead = false;

    const gtl::InlinedVector<AllocatorAttributes, 4>* input_alloc_attrs =
        nullptr;

    // Device context.
    DeviceContext* op_device_context = nullptr;

    // Control-flow op support.
    FrameAndIter frame_iter;

    // Function call support.
    CallFrameInterface* call_frame = nullptr;
    FunctionLibraryRuntime* function_library = nullptr;
    std::function<void(std::function<void()>)>* runner = nullptr;
    StepStatsCollectorInterface* stats_collector = nullptr;
    GraphCollector* graph_collector = nullptr;
    bool run_all_kernels_inline = false;
    const std::string* executor_type = nullptr;

    // TensorSliceReaderCache support.
    checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;

    // Support for forwarding reservations (used by ScopedAllocator).
    static constexpr int kNeverForward = -2;
    static constexpr int kNoReservation = -1;
    // Values in [0,...) represent reservations for the indexed output.
    const int* forward_from_array = nullptr;

    // For tracking actively running deferred ops.
    std::function<void()> inc_num_deferred_ops_function;
    std::function<void()> dec_num_deferred_ops_function;

    absl::optional<ManagedStackTrace> stack_trace = {};

    // For implementing `OpKernelContext::output_required()`. If null, all
    // outputs are required.
    bool* outputs_required_array = nullptr;
  };

  // params must outlive the OpKernelContext.
  explicit OpKernelContext(Params* params);
  OpKernelContext(Params* params, int num_outputs);
  ~OpKernelContext();

  Env* env() const { return params_->device->env(); }

  int64 step_id() const { return params_->step_id; }

  const OpKernel& op_kernel() const { return *params_->op_kernel; }

  // Stack trace of where the op was defined (if defined in eager mode).
  const absl::optional<ManagedStackTrace>& stack_trace() const {
    return params_->stack_trace;
  }

  // Input/output signature.

  int num_inputs() const { return params_->inputs->size(); }
  DataType input_dtype(int index) const;
  Status input_dtype(StringPiece name, DataType* dtype) const;
  MemoryType input_memory_type(int index) const;

  int num_outputs() const { return outputs_.size(); }
  DataType expected_output_dtype(int index) const;
  MemoryType output_memory_type(int index) const;

  // Input

  // Returns an immutable input tensor. May only be used for non-Ref
  // inputs. For Ref inputs use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  // TODO(mrry): Convert this to return Status.
  const Tensor& input(int index) const;

  // Returns the named immutable input tensor in "tensor", as defined
  // in the OpDef. May only be used for non-Ref inputs. For Ref inputs
  // use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  // REQUIRES: the named input must not be a list.
  Status input(StringPiece name, const Tensor** tensor);

  // Returns the named list-valued immutable input in "list", as
  // defined in the OpDef. If the named input is not list-valued,
  // returns a one-element list. May only be used for non-Ref
  // inputs. For Ref inputs use mutable_input below.
  // REQUIRES: !IsRefType(input_dtype(index))
  Status input_list(StringPiece name, OpInputList* list);

  // For mutable inputs, use the following together to make sure there
  // is no concurrent access to mutable_input(), e.g.:
  // {
  //   Tensor& t = context->mutable_input(index);
  //   mutex_lock lock(*context->input_ref_mutex(index));
  //   // modify the values in t
  // }
  // REQUIRES: IsRefType(input_dtype(index))
  Status input_ref_mutex(StringPiece name, mutex** out_mutex);

  // Returns a mutable input tensor. Must be used to access Ref
  // inputs. REQUIRES: IsRefType(input_dtype(index)). The caller may
  // modify the values stored in the Tensor buffer, and modifications
  // will be visible to other Ops reading the same ref tensor. If
  // !lock_held the input mutex will be acquired before returning the
  // Tensor.
  // TODO(mrry): Convert this to return Status.
  Tensor mutable_input(int index, bool lock_held);

  // Returns the named mutable input tensor in "tensor", as defined in
  // the OpDef. Must be used to access Ref inputs. The values stored
  // in the Tensor buffer may be modified, and modifications will be
  // visible to other Ops reading the same ref tensor. If !lock_held
  // the input mutex will be acquired before returning the Tensor.
  // REQUIRES: the named input must not be a list.
  // REQUIRES: the named input must be a ref tensor.
  Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held);

  // Returns the named list-valued mutable input in "list", as defined
  // in the OpDef. If the named input is not list-valued, returns a
  // one-element list. Must be used to access Ref inputs. The values
  // stored in the Tensor buffer may be modified, and modifications
  // will be visible to other Ops reading the same ref tensor.
  // REQUIRES: the named input must be a ref tensor.
  Status mutable_input_list(StringPiece name, OpMutableInputList* list);

  // Replace the corresponding Ref Input to use the storage buffer
  // used by tensor. If !lock_held the input mutex will be acquired
  // before returning the Tensor.
  // REQUIRES: IsRefType(input_dtype(index)).
  void replace_ref_input(int index, const Tensor& tensor, bool lock_held);

  // Replace the corresponding named Ref Input to use the storage
  // buffer used by tensor. If !lock_held the input mutex will be
  // acquired before returning the Tensor.
  // REQUIRES: IsRefType(input_dtype(index)).
  Status replace_ref_input(StringPiece name, const Tensor& tensor,
                           bool lock_held);

  // Deletes the Tensor object used as the Ref Input at
  // input_index. This is not usually necessary and should be used
  // with caution. If !lock_held the input mutex will be acquired
  // before returning the Tensor.
  // REQUIRES: IsRefType(input_dtype(input_index)).
  void delete_ref_input(int input_index, bool lock_held);

  // Return true if there is input at the given index. An operator has no
  // input at index if its tensor is null. This is primarily used by the
  // merge operator.
  // TODO(mrry): Convert this to return Status.
  bool has_input(int index) const;

  // Returns true if all inputs are the same shape, otherwise sets the
  // status to a non-OK value and returns false.
  // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
  bool ValidateInputsAreSameShape(OpKernel* op);

  // If non-null, kernels should populate with any partition subgraphs created.
  GraphCollector* graph_collector() { return params_->graph_collector; }

  // If true, hints that all kernels in functions called by this kernel should
  // be treated as "inexpensive", and hence executed on the scheduling thread.
  bool run_all_kernels_inline() const {
    return params_->run_all_kernels_inline;
  }

  // Returns the registered name for the executor type that is executing the
  // current kernel. If empty, the default executor is used.
  const std::string& executor_type() const;

  // Input to output forwarding.

  // Set the output Ref Tensor at output_index to be an alias of the
  // input Ref Tensor at input_index.
  // REQUIRES: IsRefType(input_dtype(input_index)).
  // REQUIRES: IsRefType(output_dtype(output_index)).
  void forward_ref_input_to_ref_output(int input_index, int output_index);

  // Returns true when an alias to input[input_index] (reshaped to
  // output_shape), which is safe to use for in-place computation, was written
  // to *output. Returns false if input[input_index] has a refcount greater
  // than one, or if its type does not match the expected output type of
  // output[output_index], or the number of elements in input[input_index]
  // does not equal the number of elements in output_shape.
  bool forward_input_to_output_with_shape(int input_index, int output_index,
                                          const TensorShape& output_shape,
                                          Tensor** output) TF_MUST_USE_RESULT;
  Status forward_input_to_output_with_shape(StringPiece input_name,
                                            StringPiece output_name,
                                            const TensorShape& output_shape,
                                            Tensor** output) TF_MUST_USE_RESULT;

  // Returns a pointer to a Tensor aliasing the underlying buffer backing
  // input[input_index] iff
  //   * input[input_index] is not a ref,
  //   * the data type, shape, memory type, and allocator attributes of
  //     input[input_index] are compatible with those given in dtype, shape,
  //     memory_type, and attr,
  //   * refcount on the underlying buffer is one.
  //   * Either there is no forwarding reservation for either input_index
  //     or output_index or the specified input is reserved for the specified
  //     output. More precisely:
  //
  //     These cases mean neither input nor output has a reservation:
  //        forward_from_array = nullptr
  //     OR (input_index is not in forward_from_array AND
  //         (output_index == kNoReservation OR
  //          forward_from_array[output_index] == kNoReservation))
  //
  //     This case means that input_index is reserved for output_index:
  //        forward_from_array[output_index] == input_index
  //
  //     This case means the output is reserved to always be allocated,
  //     never assigned a forwarded input:
  //        forward_from_array[output_index] == kNeverForward
  //
  // Otherwise returns nullptr.
  // NOTE: For Cuda kernels that read inputs using the __ldg() intrinsic,
  // forwarding is only safe if there are no reads via __ldg() after writes
  // to the same address.
  std::unique_ptr<Tensor> forward_input(
      int input_index, int output_index, DataType output_dtype,
      const TensorShape& output_shape, MemoryType output_memory_type,
      const AllocatorAttributes& output_attr) TF_MUST_USE_RESULT;

  // Tries to forward one of the inputs given in input_indices to
  // output[output_index]. If none of the given inputs can be forwarded, calls
  // allocate_output() to allocate a new output buffer. The index of the
  // forwarded input will be assigned to the output argument forwarded_input
  // (if it is not nullptr). If no inputs are forwarded, forwarded_input will
  // be assigned -1.
  Status forward_input_or_allocate_output(
      gtl::ArraySlice<int> candidate_input_indices, int output_index,
      const TensorShape& output_shape, Tensor** output,
      int* forwarded_input = nullptr) TF_MUST_USE_RESULT;
  Status forward_input_or_allocate_output(
      gtl::ArraySlice<StringPiece> candidate_input_names,
      StringPiece output_name, const TensorShape& output_shape,
      Tensor** output) TF_MUST_USE_RESULT;
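
  // For example, an element-wise kernel that may reuse its first input's
  // buffer for output 0 could do the following (sketch):
  //
  //   const Tensor& input = ctx->input(0);
  //   Tensor* output = nullptr;
  //   OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
  //                           {0}, 0, input.shape(), &output));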

  // Tries to reuse one of the inputs given in input_indices as a temporary.
  // If none of the given inputs can be forwarded, calls
  // allocate_temp() to allocate a new temporary buffer.
  Status forward_input_or_allocate_temp(
      gtl::ArraySlice<int> candidate_input_indices, DataType type,
      const TensorShape& shape, const AllocatorAttributes& allocator_attr,
      Tensor* out_temp) TF_MUST_USE_RESULT;

  Status forward_input_or_allocate_temp(
      gtl::ArraySlice<int> candidate_input_indices, DataType type,
      const TensorShape& shape, Tensor* out_temp) TF_MUST_USE_RESULT {
    return forward_input_or_allocate_temp(candidate_input_indices, type, shape,
                                          AllocatorAttributes(), out_temp);
  }

  // Output

  // Returns the named list-valued output in "list", as defined in the OpDef.
  // If the named output is not list-valued, returns a one-element list.
  Status output_list(StringPiece name, OpOutputList* list);
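
  // For example, a kernel producing a list output named "outputs" might
  // allocate each element as follows (sketch; the output name is
  // hypothetical):
  //
  //   OpOutputList outputs;
  //   OP_REQUIRES_OK(ctx, ctx->output_list("outputs", &outputs));
  //   for (int i = 0; i < outputs.size(); ++i) {
  //     Tensor* out = nullptr;
  //     OP_REQUIRES_OK(ctx, outputs.allocate(i, TensorShape({}), &out));
  //     out->scalar<float>()() = static_cast<float>(i);
  //   }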

  // If output_required(index) returns true, the OpKernel's Compute() method
  // should call allocate_output(index, ...), set_output(index, ...),
  // set_output_ref(index, ...), or set the status to a non-ok value.
  // If it returns false, it may output, but is not required to do so.
  bool output_required(int index) const {
    return !params_->outputs_required_array ||
           params_->outputs_required_array[index];
  }

  // Allocation of tensors during kernel execution inside the Compute
  // method:
  //
  // There are three methods to allocate Tensors when an Op kernel
  // executes.
  //
  // 1) allocate_persistent. This is only needed for Tensors that will
  //    be stored by the Op between invocations, and it *must* be used
  //    for those Tensors. The call returns a PersistentTensor, and that
  //    is the only object the Op is allowed to hold on to between
  //    invocations. When the Tensor is needed in a subsequent
  //    invocation, it can be retrieved from the PersistentTensor using
  //    the AccessTensor method. This ensures that the system is made
  //    aware of any use of the tensor's allocated memory, which is
  //    needed for correctness on asynchronous devices such as GPUs.
  //
  // 2) allocate_output. This should be used to allocate any tensor
  //    that is going to be used as an output from the Op at the end of
  //    the current execution. The caller indicates which output the
  //    Tensor will be assigned to, and the call returns the
  //    newly-allocated Tensor. The Tensor can subsequently be assigned
  //    to during kernel execution, and will be used as the designated
  //    output when the kernel execution completes.
  //
  // 3) allocate_temp. This should be used to allocate any scratch
  //    storage that is needed while the kernel is executing, and will
  //    not be retained by the Op.
  //
  // In some cases a Tensor needs to be used as an output even though
  // it was previously allocated elsewhere. The Tensor may have been
  // passed as an input, or stored in a PersistentTensor during a
  // previous kernel execution, or allocated earlier in the kernel
  // execution at a time when it was not known which output it would
  // be assigned to. In this case the kernel can use set_output or
  // set_output_ref to indicate that the tensor should be used as the
  // designated output. It is legal to use any previously-allocated
  // Tensor as an argument to set_output or set_output_ref, including
  // Tensors allocated via allocate_temp. There may be a performance
  // penalty to using a Tensor that was not allocated using
  // allocate_output. This is because allocate_output uses the
  // AllocatorAttributes stored in output_attr_array for the
  // designated output. In some cases, using the wrong attributes may
  // cause an extra copy of the Tensor's buffer.

  // Allocates output for the specified output index with shape.
  // OpKernelContext retains ownership of the returned pointer. See
  // comment above.
  //
  // If memory allocation fails, returns an error status.
  //
  // REQUIRES: !IsRefType(expected_output_dtype(index))
  Status allocate_output(int index, const TensorShape& shape,
                         Tensor** tensor) TF_MUST_USE_RESULT;
  Status allocate_output(StringPiece name, const TensorShape& shape,
                         Tensor** tensor) TF_MUST_USE_RESULT;
  // The following methods use the supplied attributes instead of
  // those in output_attr_array. The caller is responsible for
  // ensuring that the attributes are "compatible" with the
  // output_attr_array, e.g. the tensor is allocated on the correct
  // device. See comment above.
  Status allocate_output(int index, const TensorShape& shape, Tensor** tensor,
                         AllocatorAttributes attr) TF_MUST_USE_RESULT;
  Status allocate_output(StringPiece name, const TensorShape& shape,
                         Tensor** tensor,
                         AllocatorAttributes attr) TF_MUST_USE_RESULT;

  // Allocates a temporary Tensor of the specified type and
  // shape. Devices such as GPUs that enqueue Ops for lazy execution
  // may retain references to the temporary tensors after the Op's
  // Compute method has run. See comment above.
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp, AllocatorAttributes allocator_attr,
                       const AllocationAttributes& allocation_attr);
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp, AllocatorAttributes allocator_attr) {
    return allocate_temp(type, shape, out_temp, allocator_attr,
                         AllocationAttributes());
  }
  Status allocate_temp(DataType type, const TensorShape& shape,
                       Tensor* out_temp) {
    return allocate_temp(type, shape, out_temp, AllocatorAttributes());
  }
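
  // For example, scratch space that is only needed within a single Compute()
  // call could be allocated as follows (sketch; `num_rows` is a hypothetical
  // dimension):
  //
  //   Tensor scratch;
  //   OP_REQUIRES_OK(ctx,
  //                  ctx->allocate_temp(DT_FLOAT, TensorShape({num_rows}),
  //                                     &scratch));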

  // Allocates a Tensor of the specified type and shape which the Op
  // plans to maintain as persistent state. out_persistent holds the
  // PersistentTensor which is the object the caller should store. For
  // convenience, if out_tensor is non-null then it will be filled in
  // with a Tensor* pointing to the newly-allocated tensor which the
  // caller can use instead of calling
  // out_persistent->AccessTensor. The caller does not own out_tensor
  // and should not keep a copy of it. See comment above.
  Status allocate_persistent(DataType type, const TensorShape& shape,
                             PersistentTensor* out_persistent,
                             Tensor** out_tensor, AllocatorAttributes attr);
  Status allocate_persistent(DataType type, const TensorShape& shape,
                             PersistentTensor* out_persistent,
                             Tensor** out_tensor) {
    return allocate_persistent(type, shape, out_persistent, out_tensor,
                               AllocatorAttributes());
  }

  // Copies a tensor (allocated by the caller) to the specified output
  // index. REQUIRES: !IsRefType(expected_output_dtype(index))
  // REQUIRES: 'tensor' must have the same MemoryType as
  // output_memory_types[index]. See comment above.
  Status set_output(StringPiece name, const Tensor& tensor);
  Status set_output(StringPiece name, Tensor&& tensor);
  void set_output(int index, const Tensor& tensor);
  void set_output(int index, Tensor&& tensor);

  // To output a reference. Caller retains ownership of mu and tensor_for_ref,
  // and they must outlive all uses within the step. See comment above.
  // REQUIRES: IsRefType(expected_output_dtype(index))
  Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref);

  // Returns nullptr if allocate_output() or set_output() have not been called.
  Status mutable_output(StringPiece name, Tensor** tensor);

  // Return the DeviceContext that should be used for this Op.
  //
  // If using the templated function, the type must be a subclass
  // of DeviceContext.
  //
  // Returns nullptr if the device did not provide one.
  template <typename T>
  T* op_device_context();
  DeviceContext* op_device_context() {
    DeviceContext* ret = params_->op_device_context;
    if (ret == nullptr) {
      auto* dev_info = device()->tensorflow_gpu_device_info();
      if (dev_info) ret = dev_info->default_context;
    }
    return ret;
  }

  AllocatorAttributes input_alloc_attr(int index) const {
    if (params_->input_alloc_attrs == nullptr) {
      return AllocatorAttributes();
    } else {
      DCHECK_GE(index, 0);
      DCHECK_LT(index, params_->input_alloc_attrs->size());
      return (*params_->input_alloc_attrs)[index];
    }
  }

  AllocatorAttributes output_alloc_attr(int index) const {
    return params_->output_attr_array[index];
  }

  gtl::InlinedVector<WrappedAllocator, 4> ConsumeWrappedAllocators() {
    gtl::InlinedVector<WrappedAllocator, 4> retrieved;
    if (tracking_state_) {
      mutex_lock lock(tracking_state_->mu);
      retrieved.swap(tracking_state_->wrapped_allocators);
    }
    return retrieved;
  }

  // Communication.
  //
  // An op kernel communicates with outside environment through
  // Rendezvous Send() and Recv().
  RendezvousInterface* rendezvous() const { return params_->rendezvous; }

  CollectiveExecutor* collective_executor() const {
    return params_->collective_executor;
  }

  // An op kernel can access the session state it belongs to.
  SessionState* session_state() const { return params_->session_state; }

  // Unique identifier of the session it belongs to. Can be empty.
  std::string session_handle() const { return params_->session_handle; }

  // Metadata about the session. Can be nullptr.
  const SessionMetadata* session_metadata() const {
    return params_->session_metadata;
  }

  // An op kernel can access the tensor store of the run it belongs to.
  TensorStore* tensor_store() const { return params_->tensor_store; }

  // Function call support.
  //
  // If this kernel invocation is within a function execution,
  // call_frame() returns the call frame for the function call.
  CallFrameInterface* call_frame() const { return params_->call_frame; }

  // If not nullptr, the kernel can invoke functions defined in the
  // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...).
  FunctionLibraryRuntime* function_library() const {
    return params_->function_library;
  }

  std::function<void(std::function<void()>)>* runner() const {
    return params_->runner;
  }
  StepStatsCollectorInterface* stats_collector() const {
    return params_->stats_collector;
  }

  // Shared resources accessible to this kernel.
  ResourceMgr* resource_manager() const { return params_->resource_manager; }

  checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const {
    return params_->slice_reader_cache;
  }

  // Execution.
  //
  // OpKernels can use these eigen devices to carry out their
  // numerical computation.
  const Eigen::ThreadPoolDevice& eigen_cpu_device() const {
    return *device()->eigen_cpu_device();
  }
  const Eigen::GpuDevice& eigen_gpu_device() const {
    return params_->eigen_gpu_device->device();
  }
  template <typename EigenDeviceType>
  const EigenDeviceType& eigen_device() const;
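
  // For example, an elementwise computation can be dispatched to the CPU's
  // Eigen device as follows (sketch; `input` and `output` are tensors
  // previously obtained from this context):
  //
  //   output->flat<float>().device(ctx->eigen_cpu_device()) =
  //       input.flat<float>() * 2.0f;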

  // Error handling.

  // If expected_inputs == inputs() and expected_outputs == output_types(),
  // returns OK, else returns INVALID_ARGUMENT with an error message.
  // Recommended for Ops with dynamic signatures, where validation can only
  // be performed at runtime.
  Status MatchSignature(const DataTypeSlice expected_inputs,
                        const DataTypeSlice expected_outputs);

  // An OpKernel should call SetStatus() if Compute() encounters an
  // error.
  void SetStatus(const Status& status);
  const Status& status() const { return status_; }

  // Cancellation.
  //
  // EXPERIMENTAL. See the implementation in tensorflow::FIFOQueue for an
  // example of how to use this API.
  CancellationManager* cancellation_manager() const {
    return params_->cancellation_manager;
  }

  // Other accessors.

  // For control flow.
  FrameAndIter frame_iter() const { return params_->frame_iter; }
  bool is_input_dead() const { return params_->is_input_dead; }

  // May be used, e.g., to get GPU handles, etc.
  // TODO(tucker): Add example usage.
  DeviceBase* device() const { return params_->device; }

  // Per-step container for use by white-listed internal ops.
  ScopedStepContainer* step_container() const {
    return params_->step_container;
  }

  // Helper routines for the OP_REQUIRES macros
  void CtxFailure(const Status& s);
  void CtxFailureWithWarning(const Status& s);
  void CtxFailure(const char* file, int line, const Status& s);
  void CtxFailureWithWarning(const char* file, int line, const Status& s);

  // Unrecommended functions: these are functions that have some
  // current uses but are not recommended for use, and may go away at
  // some future major version release.
  //
  // The following functions all have versions that return Status
  // to capture error conditions, and are strongly preferred.
  Tensor* mutable_output(int index);
  mutex* input_ref_mutex(int index);
  void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref);
  TensorValue release_output(int index);

  bool track_allocations() const { return params_->track_allocations; }

  // Records temp memory allocation. Tensor object is recorded to identify the
  // case where temp memory is used as output memory.
  void record_temp_memory_allocation(int64 size, const Tensor& t)
      TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);

  // Returns recorded size of temporary memory.
  int64 temp_memory_allocated() const
      TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);

  // Records persistent memory allocation; size can be negative, indicating
  // deallocation.
  void record_persistent_memory_allocation(int64 size, int64 alloc_id = -1)
      TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);

  // Returns recorded size and ids of persistent memory.
  int64 persistent_memory_allocated() const
      TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);

  std::vector<int64> persistent_alloc_ids() const
      TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);

  // Resets counters for temp and persistent memory and recorded ids.
  void clear_recorded_memory() TF_LOCKS_EXCLUDED(tracking_state_->stats_mu);

  bool input_is_ref(int index) const;

  void set_record_memory_consumption(bool v);

  // Used by OpKernel implementations to track actively running deferred ops.
  //
  // A deferred op is one whose Compute method returns (or whose ComputeAsync
  // method invokes the callback) when work is scheduled onto a device. At that
  // point, we don't know when the work will actually complete (or if it has
  // already completed) on the device. These functions allow the executor to
  // track the status of deferred ops and act accordingly.
  //
  // Deferred OpKernel implementations must use these methods to get two
  // functions. They must then call these two functions in pairs, before and
  // after device execution, respectively.
  TF_MUST_USE_RESULT std::function<void()> inc_num_deferred_ops_function() {
    DCHECK(params_->op_kernel->is_deferred());
    return params_->inc_num_deferred_ops_function
               ? params_->inc_num_deferred_ops_function
               : []() {};
  }
  TF_MUST_USE_RESULT std::function<void()> dec_num_deferred_ops_function() {
    DCHECK(params_->op_kernel->is_deferred());
    return params_->dec_num_deferred_ops_function
               ? params_->dec_num_deferred_ops_function
               : []() {};
  }
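
  // For example, a deferred kernel might pair the two functions around the
  // enqueued device work as follows (sketch; `device_queue` is a hypothetical
  // scheduling handle):
  //
  //   auto inc = ctx->inc_num_deferred_ops_function();
  //   auto dec = ctx->dec_num_deferred_ops_function();
  //   inc();
  //   device_queue->Schedule([dec, done]() {
  //     // ... the device has finished the work ...
  //     dec();
  //     done();
  //   });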

  Allocator* get_allocator(AllocatorAttributes attr);

 private:
  bool record_memory_consumption_ = false;

  // Internal common method used when allocating tensor memory
  Status allocate_tensor(DataType type, const TensorShape& shape,
                         Tensor* out_tensor,
                         AllocatorAttributes allocator_attr) {
    return allocate_tensor(type, shape, out_tensor, allocator_attr,
                           AllocationAttributes());
  }

  Status allocate_tensor(DataType type, const TensorShape& shape,
                         Tensor* out_tensor, AllocatorAttributes allocator_attr,
                         const AllocationAttributes& allocation_attr);

  // Helpers for `set_output()`.

  // Returns `true` if the tensor was copied into an allocated output.
  bool maybe_set_output_by_allocate_and_copy(int index, const Tensor& tensor);

  void maybe_track_allocations_for_set_output(const Tensor& tensor);

  Status get_input_index(StringPiece name, int* out_index) const;
  Status get_output_index(StringPiece name, int* out_index) const;

  // Initialize the allocated_scope_ids_ set the first time this method is
  // called.
  void maybe_initialize_scope_id_set();

  Status status_;
  friend class CollectiveExecutor;  // for access to params_
  Params* params_;                  // not owned
  gtl::InlinedVector<TensorValue, 4> outputs_;

  // Keep track of calls to ScopedAllocator.
  // TODO(ayushd): change to absl::flat_hash_set.
  std::unique_ptr<std::unordered_set<int32>> allocated_scope_ids_;

  // The following data members are only used when allocation tracking is
  // enabled, memory consumption is being recorded, or tensor access is being
  // recorded.
  struct TrackingState {
    mutable mutex mu;
    gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators
        TF_GUARDED_BY(mu);

    mutable mutex stats_mu;
    int64 temp_memory_allocated TF_GUARDED_BY(stats_mu) = 0;

    int64 persistent_memory_allocated TF_GUARDED_BY(stats_mu) = 0;
    gtl::InlinedVector<std::pair<const void*, int64>, 2>
        temp_tensor_buffer_and_size TF_GUARDED_BY(stats_mu);
    gtl::InlinedVector<int64, 2> persistent_alloc_ids TF_GUARDED_BY(stats_mu);
  };
  std::unique_ptr<TrackingState> tracking_state_;

  // For access to `params_->op_kernel`.
  friend void CheckNotInComputeAsync(OpKernelContext* ctx,
                                     const char* correct_macro_name);

  TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext);
};

template <>
const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const;

template <>
const Eigen::GpuDevice& OpKernelContext::eigen_device() const;
1342 // Register your OpKernel by specifying the Op's name, the device the
1343 // kernel runs on, any type attr constraints for this kernel, any
1344 // host-memory args, and the class to instantiate. Examples:
1345 //
1346 // // A kernel that supports all types.
1347 // REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp);
1348 //
1349 // // The following are equivalent ways of specifying that the kernel only
1350 // // works if the "T" type attr is set to DT_FLOAT.
1351 // REGISTER_KERNEL_BUILDER(
1352 // Name("Sub").Device(DEVICE_CPU).TypeConstraint<float>("T"),
1353 // SubOp<float>);
1354 // // (You would then repeat this for every type supported by "Sub".)
1355 //
1356 // // This form allows you to specify a list of types as the constraint.
1357 // REGISTER_KERNEL_BUILDER(Name("Sub")
1358 // .Device(DEVICE_CPU)
1359 // .TypeConstraint("T", {DT_FLOAT}),
1360 // SubOp<float>);
1361 //
1362 // // A kernel that expects one of the input tensors in host memory.
1363 // REGISTER_KERNEL_BUILDER(
1364 // Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp);
1365 //
1366 // See kernel_def_builder for details.
1367
1368 // Instantiate an OpKernel that has been registered. Returns nullptr
1369 // if no operation for that type of device / input signature combination
1370 // (and a NOT_FOUND *status), or there is an error in construction (and
1371 // an INVALID_ARGUMENT *status). Otherwise, the caller takes ownership
1372 // of the returned pointer.
1373 // EXPECTED USAGE: unique_ptr<OpKernel> op = CreateOpKernel(...);
1374 // REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
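//
// A slightly fuller sketch of the expected usage (the `device` and `node_def`
// variables are assumed to be supplied by the caller; TF_GRAPH_DEF_VERSION
// comes from tensorflow/core/public/version.h):
//
//   Status status;
//   std::unique_ptr<OpKernel> op = CreateOpKernel(
//       DEVICE_CPU, device, device->GetAllocator(AllocatorAttributes()),
//       node_def, TF_GRAPH_DEF_VERSION, &status);
//   if (!status.ok()) { /* handle NOT_FOUND or INVALID_ARGUMENT */ }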
std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
                                         DeviceBase* device,
                                         Allocator* allocator,
                                         const NodeDef& node_def,
                                         int graph_def_version, Status* status);

std::unique_ptr<OpKernel> CreateOpKernel(
    DeviceType device_type, DeviceBase* device, Allocator* allocator,
    const std::shared_ptr<const NodeProperties>& props, int graph_def_version,
    Status* status);

Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      const std::shared_ptr<const NodeProperties>& props,
                      int graph_def_version, OpKernel** kernel);

Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
                      Allocator* allocator, FunctionLibraryRuntime* flib,
                      ResourceMgr* resource_mgr,
                      const std::shared_ptr<const NodeProperties>& props,
                      int graph_def_version, OpKernel** kernel);

// Returns into 'device_types' the subset of prioritized_types that this
// binary has registered for the given NodeDef.
//
// REQUIRES: * 'device_types' is not nullptr.
//           * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
Status SupportedDeviceTypesForNode(
    const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
    PrioritizedDeviceTypeVector* device_types,
    const DeviceNameUtils::ParsedName* local_address_spec = nullptr);

// Returns a message with a description of the kernels registered for op
// `op_name`.
std::string KernelsRegisteredForOp(StringPiece op_name);

// Call once after Op registration has completed.
Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry);

// -----------------------------------------------------------------------------
// OpKernel registration implementation follows, please ignore.

// Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax.
namespace register_kernel {

class Name : public KernelDefBuilder {
 public:
  explicit Name(const char* op) : KernelDefBuilder(op) {}
};

}  // namespace register_kernel

// Kernel registration appears as:
//   REGISTER_KERNEL_BUILDER(Name("OpName").Device(DEVICE_CPU)..., OpImpl)
// We'd like to have "OpName" as a constant-expression, without requiring that
// of the overall KernelDefBuilder expression (beginning with the
// register_kernel::Name constructor above).
//
// So, we pull the "OpName" part to a separate macro-level argument. This
// involves treating Name("OpName") as a macro call, via token-pasting (e.g.
// M_## => M_Name("OpName")), and having it expand to '"OpName",
// Name("OpName")' which is then usable as two arguments.
#define TF_EXTRACT_KERNEL_NAME_Name(name_str) \
  name_str, ::tensorflow::register_kernel::Name(name_str)
#define TF_EXTRACT_KERNEL_NAME_IMPL(m, ...) m(__VA_ARGS__)
#define TF_EXTRACT_KERNEL_NAME(m, kernel_builder, ...)                    \
  TF_EXTRACT_KERNEL_NAME_IMPL(m, TF_EXTRACT_KERNEL_NAME_##kernel_builder, \
                              __VA_ARGS__)
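
// To illustrate, a sketch of the expansion (not itself part of the build):
//
//   TF_EXTRACT_KERNEL_NAME(M, Name("Sub").Device(DEVICE_CPU), SubOp<float>)
//
// token-pastes to TF_EXTRACT_KERNEL_NAME_Name("Sub").Device(DEVICE_CPU), and
// the whole invocation expands roughly to
//
//   M("Sub", ::tensorflow::register_kernel::Name("Sub").Device(DEVICE_CPU),
//     SubOp<float>)
//
// so `M` receives the op name as a separate constant-expression argument.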

// REGISTER_KERNEL_BUILDER_IMPL_2, with a unique 'ctr' as the first argument.
// TODO(dodgen): There are some uses of this macro inside functions, where
// kernel_builder refers to (non-const) locals (they should be fixed). To
// accommodate those, kernel_builder.Build() appears as an argument to an
// immediately-called lambda (not in the lambda itself).
#define REGISTER_KERNEL_BUILDER_IMPL_3(ctr, op_name, kernel_builder_expr,   \
                                       is_system_kernel, ...)               \
  static ::tensorflow::InitOnStartupMarker const register_kernel_##ctr      \
      TF_ATTRIBUTE_UNUSED =                                                 \
          TF_INIT_ON_STARTUP_IF(is_system_kernel ||                         \
                                (SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__) && \
                                 SHOULD_REGISTER_OP(op_name)))              \
          << ([](::tensorflow::KernelDef const* kernel_def) {               \
               ::tensorflow::kernel_factory::OpKernelRegistrar registrar(   \
                   kernel_def, #__VA_ARGS__,                                \
                   [](::tensorflow::OpKernelConstruction* context)          \
                       -> ::tensorflow::OpKernel* {                         \
                     return new __VA_ARGS__(context);                       \
                   });                                                      \
               (void)registrar;                                             \
               return ::tensorflow::InitOnStartupMarker{};                  \
             })(kernel_builder_expr.Build());

// REGISTER_KERNEL_BUILDER_IMPL, but with kernel_builder split to op_name,
// kernel_builder_expr.
#define REGISTER_KERNEL_BUILDER_IMPL_2(op_name, kernel_builder_expr, \
                                       is_system_kernel, ...)        \
  TF_NEW_ID_FOR_INIT(REGISTER_KERNEL_BUILDER_IMPL_3, op_name,        \
                     kernel_builder_expr, is_system_kernel, __VA_ARGS__)

// REGISTER_KERNEL_BUILDER, but with is_system_kernel bound.
#define REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, is_system_kernel, ...) \
  TF_EXTRACT_KERNEL_NAME(REGISTER_KERNEL_BUILDER_IMPL_2, kernel_builder,    \
                         is_system_kernel, __VA_ARGS__)

#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
  TF_ATTRIBUTE_ANNOTATE("tf:kernel")                 \
  REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, false, __VA_ARGS__)

// The `REGISTER_SYSTEM_KERNEL_BUILDER()` macro acts as
// `REGISTER_KERNEL_BUILDER()` except that the kernel is registered
// unconditionally even when selective registration is used.
#define REGISTER_SYSTEM_KERNEL_BUILDER(kernel_builder, ...) \
  TF_ATTRIBUTE_ANNOTATE("tf:kernel")                        \
  TF_ATTRIBUTE_ANNOTATE("tf:kernel:system")                 \
  REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, true, __VA_ARGS__)
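
// For example, a sketch of registering a system kernel that must survive
// selective registration (`HostConstantOp` is only a placeholder class name):
//
//   REGISTER_SYSTEM_KERNEL_BUILDER(
//       Name("HostConst").Device(DEVICE_CPU), HostConstantOp);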

// Checks whether a given kernel is registered on device_type.
bool KernelDefAvailable(const DeviceType& device_type, const NodeDef& node_def);

// If a node with the given node_name, experimental_debug_info, node_op,
// node_device and node_attrs has a corresponding kernel registered on
// device_type, returns OK and fills in the kernel def and kernel_class_name.
// <def> and <kernel_class_name> may be null.
Status FindKernelDef(
    const DeviceType& device_type, StringPiece node_name,
    bool has_experimental_debug_info,
    const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
    StringPiece node_op, StringPiece node_device, AttrSlice node_attrs,
    const KernelDef** def, std::string* kernel_class_name);

// If node_def has a corresponding kernel registered on device_type,
// returns OK and fills in the kernel def and kernel_class_name. <def> and
// <kernel_class_name> may be null.
Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
                     const KernelDef** def, std::string* kernel_class_name);

// Writes a list of all registered kernels to LOG(INFO), to help users debug
// missing kernel errors.
void LogAllRegisteredKernels();

// Gets a list of all registered kernels.
KernelList GetAllRegisteredKernels();

// Gets a list of all registered kernels for which predicate returns true.
KernelList GetFilteredRegisteredKernels(
    const std::function<bool(const KernelDef&)>& predicate);
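
// For example, a sketch of filtering by device type (assumes kernels have
// already been registered in this binary):
//
//   KernelList cpu_kernels = GetFilteredRegisteredKernels(
//       [](const KernelDef& k) { return k.device_type() == "CPU"; });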

// Gets a list of all registered kernels for a given op.
KernelList GetRegisteredKernelsForOp(StringPiece op_name);

namespace kernel_factory {

// OpKernelFactory is responsible for creating OpKernels when TensorFlow needs
// them. You register factories with the TensorFlow core by constructing an
// OpKernelRegistrar and passing the factory as a constructor parameter.
class OpKernelFactory {
 public:
  virtual OpKernel* Create(OpKernelConstruction* context) = 0;
  virtual ~OpKernelFactory() = default;
};
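
// A minimal sketch of a custom factory (the `MyOp` kernel class is
// hypothetical). Most code can simply use REGISTER_KERNEL_BUILDER, which
// registers a function-pointer-based factory on your behalf:
//
//   class MyOpFactory : public OpKernelFactory {
//    public:
//     OpKernel* Create(OpKernelConstruction* context) override {
//       return new MyOp(context);
//     }
//   };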

class OpKernelRegistrar {
 public:
  // Registers the given kernel factory with TensorFlow. TF will call the
  // factory Create() method when it determines that a kernel matching the
  // given KernelDef is required.
  OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    std::unique_ptr<OpKernelFactory> factory) {
    InitInternal(kernel_def, kernel_class_name, std::move(factory));
  }

  // Registers the given factory function with TensorFlow. This is equivalent
  // to registering a factory whose Create function invokes `create_fn`.
  OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    OpKernel* (*create_fn)(OpKernelConstruction*)) {
    InitInternal(kernel_def, kernel_class_name,
                 absl::make_unique<PtrOpKernelFactory>(create_fn));
  }

 private:
  struct PtrOpKernelFactory : public OpKernelFactory {
    explicit PtrOpKernelFactory(OpKernel* (*create_func)(OpKernelConstruction*))
        : create_func_(create_func) {}

    OpKernel* Create(OpKernelConstruction* context) override;

    OpKernel* (*create_func_)(OpKernelConstruction*);
  };

  void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    std::unique_ptr<OpKernelFactory> factory);
};

}  // namespace kernel_factory

// -----------------------------------------------------------------------------
// Template and inline method implementations, please ignore

template <class T>
Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const {
  return GetNodeAttr(def(), attr_name, value);
}

inline DataType OpKernelContext::input_dtype(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  const TensorValue& value((*params_->inputs)[index]);
  return value.dtype();
}

inline MemoryType OpKernelContext::input_memory_type(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  return op_kernel().input_memory_types()[index];
}

inline DataType OpKernelContext::expected_output_dtype(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  return params_->op_kernel->output_type(index);
}

inline MemoryType OpKernelContext::output_memory_type(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  return op_kernel().output_memory_types()[index];
}

inline bool OpKernelContext::input_is_ref(int index) const {
  const TensorValue& value((*params_->inputs)[index]);
  return value.is_ref();
}

// No input if tensor == nullptr.
inline bool OpKernelContext::has_input(int index) const {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  return (*params_->inputs)[index].tensor != nullptr;
}

inline mutex* OpKernelContext::input_ref_mutex(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_inputs());
  DCHECK(input_is_ref(index));
  return (*params_->inputs)[index].mutex_if_ref;
}

inline Tensor* OpKernelContext::mutable_output(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  return outputs_[index].tensor;
}

inline TensorValue OpKernelContext::release_output(int index) {
  DCHECK_GE(index, 0);
  DCHECK_LT(index, num_outputs());
  TensorValue value = outputs_[index];
  outputs_[index] = TensorValue();
  return value;
}

inline Status OpKernelContext::forward_input_or_allocate_output(
    gtl::ArraySlice<int> candidate_input_indices, int output_index,
    const TensorShape& output_shape, Tensor** output, int* forwarded_input) {
  for (int input_index : candidate_input_indices) {
    if (forward_input_to_output_with_shape(input_index, output_index,
                                           output_shape, output)) {
      if (forwarded_input != nullptr) {
        *forwarded_input = input_index;
      }
      return Status::OK();
    }
  }
  if (forwarded_input != nullptr) {
    *forwarded_input = -1;
  }
  return allocate_output(output_index, output_shape, output);
}

inline Status OpKernelContext::forward_input_or_allocate_output(
    gtl::ArraySlice<StringPiece> candidate_input_names, StringPiece output_name,
    const TensorShape& output_shape, Tensor** output) {
  for (const StringPiece& input_name : candidate_input_names) {
    if (forward_input_to_output_with_shape(input_name, output_name,
                                           output_shape, output)
            .ok()) {
      return Status::OK();
    }
  }
  return allocate_output(output_name, output_shape, output);
}

template <typename T>
T* OpKernelContext::op_device_context() {
  static_assert(std::is_base_of<DeviceContext, T>::value,
                "T is not a subclass of DeviceContext");
  return static_cast<T*>(op_device_context());
}

inline const Tensor& OpInputList::operator[](int i) const {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->input(start_ + i);
}

inline mutex* OpMutableInputList::ref_mutex(int i) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->input_ref_mutex(start_ + i);
}

inline Tensor OpMutableInputList::at(int i, bool lock_held) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->mutable_input(start_ + i, lock_held);
}

inline Tensor* OpOutputList::operator[](int i) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->mutable_output(start_ + i);
}

inline bool OpOutputList::required(int i) const {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->output_required(start_ + i);
}

inline DataType OpOutputList::expected_output_dtype(int i) const {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->expected_output_dtype(start_ + i);
}

inline Status OpOutputList::allocate(int i, const TensorShape& shape,
                                     Tensor** output) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  return ctx_->allocate_output(start_ + i, shape, output);
}

inline void OpOutputList::set(int i, const Tensor& tensor) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  ctx_->set_output(start_ + i, tensor);
}

inline void OpOutputList::set(int i, Tensor&& tensor) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  ctx_->set_output(start_ + i, std::move(tensor));
}

inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  ctx_->set_output_ref(i, mu, tensor_for_ref);
}

// Generate a fatal error if OP_REQUIRES or OP_REQUIRES_OK are used in
// AsyncOpKernel implementations. If these macros are used and the condition
// does not hold, the `done` callback will never be called and the system will
// deadlock, so a crash failure is preferable. Since the OP_REQUIRES[_OK] macros
// are legal to use in AsyncOpKernel constructors, we use overload resolution
// to distinguish between OpKernelConstruction* and OpKernelContext* context
// types.
class XlaOpKernelContext;
inline void CheckNotInComputeAsync(XlaOpKernelContext*, const char*) {}
inline void CheckNotInComputeAsync(OpKernelConstruction*, const char*) {}
void CheckNotInComputeAsync(OpKernelContext* ctx,
                            const char* correct_macro_name);

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_