1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_ 17 #define TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_ 18 19 #include "tensorflow/core/common_runtime/function.h" 20 #include "tensorflow/core/framework/dataset.h" 21 #include "tensorflow/core/framework/op_kernel.h" 22 #include "tensorflow/core/framework/tensor_shape.h" 23 #include "tensorflow/core/framework/types.h" 24 #include "tensorflow/core/kernels/ops_util.h" 25 26 namespace tensorflow { 27 namespace data { 28 29 class IteratorResource; 30 31 class IteratorHandleOp : public OpKernel { 32 public: 33 explicit IteratorHandleOp(OpKernelConstruction* ctx); 34 35 // The resource is deleted from the resource manager only when it is private 36 // to kernel. Ideally the resource should be deleted when it is no longer held 37 // by anyone, but it would break backward compatibility. 38 ~IteratorHandleOp() override; 39 40 void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_); 41 42 private: 43 // During the first Compute(), resource is either created or looked up using 44 // shared_name. In the latter case, the resource found should be verified if 45 // it is compatible with this op's configuration. The verification may fail in 46 // cases such as two graphs asking queues of the same shared name to have 47 // inconsistent capacities. 48 Status VerifyResource(IteratorResource* resource); 49 50 template <typename To, typename From> // use like this: down_cast<T*>(foo); down_cast(From * f)51 static inline To down_cast(From* f) { // so we only accept pointers 52 static_assert( 53 (std::is_base_of<From, typename std::remove_pointer<To>::type>::value), 54 "target type not derived from source type"); 55 56 // We skip the assert and hence the dynamic_cast if RTTI is disabled. 57 #if !defined(__GNUC__) || defined(__GXX_RTTI) 58 // Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds. 59 assert(f == nullptr || dynamic_cast<To>(f) != nullptr); 60 #endif // !defined(__GNUC__) || defined(__GXX_RTTI) 61 return static_cast<To>(f); 62 } 63 64 FunctionLibraryRuntime* CreatePrivateFLR( 65 OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr, 66 std::unique_ptr<FunctionLibraryDefinition>* flib_def, 67 std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr); 68 69 mutex mu_; 70 ContainerInfo cinfo_; // Written once under mu_ then constant afterwards. 71 IteratorResource* resource_ GUARDED_BY(mu_) = nullptr; 72 DataTypeVector output_dtypes_; 73 std::vector<PartialTensorShape> output_shapes_; 74 const int graph_def_version_; 75 string name_; 76 }; 77 78 // Like IteratorHandleOp, but creates handles which are never shared, and does 79 // not hold a reference to these handles. The latter is important for eager 80 // execution, since OpKernel instances generally live as long as the program 81 // running them. 82 class AnonymousIteratorHandleOp : public OpKernel { 83 public: 84 explicit AnonymousIteratorHandleOp(OpKernelConstruction* context); 85 86 void Compute(OpKernelContext* context) override; 87 88 private: 89 // Coordinates Iterator unique name creation across AnonymousIteratorHandleOp 90 // instances. 91 static mutex static_resource_lookup_mutex_; 92 // current_id_ is just a hint for creating unique names. If it turns out 93 // there's a collision (e.g. because another AnonymousIteratorHandleOp 94 // instance is generating handles) we'll just skip that id. 95 static int64 current_id_ GUARDED_BY(static_resource_lookup_mutex_); 96 DataTypeVector output_dtypes_; 97 std::vector<PartialTensorShape> output_shapes_; 98 const int graph_def_version_; 99 }; 100 101 class MakeIteratorOp : public OpKernel { 102 public: MakeIteratorOp(OpKernelConstruction * ctx)103 explicit MakeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 104 105 void Compute(OpKernelContext* ctx) override; 106 }; 107 108 class IteratorGetNextOp : public AsyncOpKernel { 109 public: IteratorGetNextOp(OpKernelConstruction * ctx)110 explicit IteratorGetNextOp(OpKernelConstruction* ctx) 111 : AsyncOpKernel(ctx), 112 background_worker_(ctx->env(), "tf_data_iterator_get_next") {} 113 114 void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; 115 116 private: 117 BackgroundWorker background_worker_; 118 }; 119 120 class IteratorGetNextAsOptionalOp : public AsyncOpKernel { 121 public: IteratorGetNextAsOptionalOp(OpKernelConstruction * ctx)122 explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx) 123 : AsyncOpKernel(ctx), 124 background_worker_(ctx->env(), 125 "tf_data_iterator_get_next_as_optional") { 126 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); 127 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); 128 } 129 130 void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; 131 132 private: 133 BackgroundWorker background_worker_; 134 DataTypeVector output_types_; 135 std::vector<PartialTensorShape> output_shapes_; 136 }; 137 138 class IteratorGetNextSyncOp : public OpKernel { 139 public: IteratorGetNextSyncOp(OpKernelConstruction * ctx)140 explicit IteratorGetNextSyncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 141 142 void Compute(OpKernelContext* ctx) override; 143 }; 144 145 class IteratorToStringHandleOp : public OpKernel { 146 public: IteratorToStringHandleOp(OpKernelConstruction * ctx)147 explicit IteratorToStringHandleOp(OpKernelConstruction* ctx) 148 : OpKernel(ctx) {} 149 150 void Compute(OpKernelContext* ctx) override; 151 }; 152 153 class IteratorFromStringHandleOp : public OpKernel { 154 public: 155 explicit IteratorFromStringHandleOp(OpKernelConstruction* ctx); 156 157 void Compute(OpKernelContext* ctx) override; 158 159 private: 160 DataTypeVector output_dtypes_; 161 std::vector<PartialTensorShape> output_shapes_; 162 }; 163 164 } // namespace data 165 } // namespace tensorflow 166 167 #endif // TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_ 168