/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_
#define TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {
namespace data {

// Creates a resource handle with a unique name for the given resource.
template <typename T>
Status CreateHandle(OpKernelContext* ctx, T* resource,
                    const string& container_name, ResourceHandle* handle) {
  static std::atomic<int64> resource_id_counter(0);
  string unique_name =
      strings::StrCat(container_name, resource_id_counter.fetch_add(1));
  ResourceMgr* mgr = ctx->resource_manager();
  TF_RETURN_IF_ERROR(mgr->Create<T>(container_name, unique_name, resource));

  *handle = MakeResourceHandle(container_name, unique_name, *ctx->device(),
                               TypeIndex::Make<T>());
  return Status::OK();
}
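
// A minimal usage sketch. `MyResource`, `my_resource`, the container name, and
// the `OpKernelContext* ctx` are illustrative assumptions, not part of this
// header:
//
//   MyResource* my_resource = new MyResource(...);
//   ResourceHandle handle;
//   TF_RETURN_IF_ERROR(
//       CreateHandle<MyResource>(ctx, my_resource, "my_container", &handle));
//   // `handle` now refers to `my_resource`, registered under a unique name
//   // in ctx->resource_manager().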

template <typename T>
class AnonymousResourceOp : public OpKernel {
 public:
  explicit AnonymousResourceOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    FunctionLibraryRuntime* lib;
    std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
    OP_REQUIRES_OK(
        ctx, ctx->function_library()->Clone(&flib_def, &pflr, &lib, true));
    T* resource;
    OP_REQUIRES_OK(ctx, CreateResource(ctx, std::move(flib_def),
                                       std::move(pflr), lib, &resource));

    ResourceHandle handle;
    OP_REQUIRES_OK(ctx, CreateHandle(ctx, resource, name(), &handle));
    Tensor* handle_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle_t));
    handle_t->scalar<ResourceHandle>()() = handle;

    if (create_deleter_) {
      Tensor* deleter_t;
      AllocatorAttributes attr;
      attr.set_on_host(true);
      OP_REQUIRES_OK(
          ctx, ctx->allocate_output(1, TensorShape({}), &deleter_t, attr));
      deleter_t->scalar<Variant>()() =
          ResourceDeleter(handle, ctx->resource_manager());
    }
  }

 protected:
  virtual string name() = 0;

  virtual Status CreateResource(
      OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
      std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
      FunctionLibraryRuntime* lib, T** resource) = 0;

  bool create_deleter_ = true;
};
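
// A hedged sketch of a possible subclass. `MyResource` and the constructor it
// is given are illustrative assumptions, not declared in this header:
//
//   class AnonymousMyResourceOp : public AnonymousResourceOp<MyResource> {
//    public:
//     explicit AnonymousMyResourceOp(OpKernelConstruction* ctx)
//         : AnonymousResourceOp<MyResource>(ctx) {}
//
//    protected:
//     string name() override { return "my_resource"; }
//
//     Status CreateResource(
//         OpKernelContext* ctx,
//         std::unique_ptr<FunctionLibraryDefinition> flib_def,
//         std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
//         FunctionLibraryRuntime* lib, MyResource** resource) override {
//       *resource = new MyResource(std::move(flib_def), std::move(pflr), lib);
//       return Status::OK();
//     }
//   };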

// Returns Status::OK() if `expected` and `received` types match,
// errors::InvalidArgument otherwise.
Status VerifyTypesMatch(const DataTypeVector& expected,
                        const DataTypeVector& received);

Status VerifyTypesMatch(const DataTypeVector& expected,
                        const std::vector<Tensor>& received);

// Returns Status::OK() if `expected` and `received` shapes are compatible,
// errors::InvalidArgument otherwise.
Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<PartialTensorShape>& received);

Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<Tensor>& received);

// Writes dataset elements to the checkpoint writer using the given key prefix.
// The elements can be read back by passing the same key prefix to
// ReadElementsFromCheckpoint. Only one list of elements can be written under
// the same key_prefix.
Status WriteElementsToCheckpoint(
    IteratorStateWriter* writer, StringPiece key_prefix,
    const std::vector<std::vector<Tensor>>& elements);

// Reads dataset elements from the checkpoint reader using the given key
// prefix.
Status ReadElementsFromCheckpoint(IteratorStateReader* reader,
                                  StringPiece key_prefix,
                                  std::vector<std::vector<Tensor>>* elements);
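
// A round-trip sketch. The "buffer" key prefix and the surrounding `writer`
// and `reader` variables are illustrative assumptions:
//
//   std::vector<std::vector<Tensor>> elements = ...;
//   TF_RETURN_IF_ERROR(WriteElementsToCheckpoint(writer, "buffer", elements));
//   ...
//   std::vector<std::vector<Tensor>> restored;
//   TF_RETURN_IF_ERROR(
//       ReadElementsFromCheckpoint(reader, "buffer", &restored));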

// Dataset op level determinism policy.
class DeterminismPolicy {
 public:
  enum class Type : int {
    // The op must produce elements deterministically.
    kDeterministic,
    // The op may relax determinism to improve performance.
    kNondeterministic,
    // The determinism policy is not specified at the op level. In this case we
    // use the experimental_deterministic dataset option to determine the
    // determinism policy.
    kDefault,
  };
  static constexpr const char* const kDeterministic = "true";
  static constexpr const char* const kNondeterministic = "false";
  static constexpr const char* const kDefault = "default";

  DeterminismPolicy() : determinism_(Type::kDefault) {}
  explicit DeterminismPolicy(Type determinism) : determinism_(determinism) {}
  // Creates a DeterminismPolicy with Type kDeterministic or
  // kNondeterministic, depending on the value of `is_deterministic`.
  explicit DeterminismPolicy(bool is_deterministic);

  static Status FromString(const std::string& s, DeterminismPolicy* out);

  // Returns the string representing the determinism policy. This will be one
  // of the string constants defined above.
  std::string String() const;

  /// Convenience methods for checking the DeterminismPolicy::Type.
  bool IsDeterministic() const { return determinism_ == Type::kDeterministic; }
  bool IsNondeterministic() const {
    return determinism_ == Type::kNondeterministic;
  }
  bool IsDefault() const { return determinism_ == Type::kDefault; }

 private:
  Type determinism_;
};
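
// A minimal usage sketch. The "default" string matches the kDefault constant
// above; the surrounding control flow is illustrative:
//
//   DeterminismPolicy policy;
//   TF_RETURN_IF_ERROR(DeterminismPolicy::FromString("default", &policy));
//   if (policy.IsDefault()) {
//     // Fall back to the experimental_deterministic dataset option.
//   }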

// Resolves non-deterministic seeds if necessary, returning either the original
// seeds or the resolved seeds.
//
// By TensorFlow convention, if both seeds are 0, they should be replaced with
// non-deterministically chosen seeds.
std::pair<int64, int64> MaybeOverrideSeeds(std::pair<int64, int64> seeds);
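
// For example, MaybeOverrideSeeds({0, 0}) returns a pair of
// non-deterministically chosen seeds, while MaybeOverrideSeeds({7, 13})
// returns {7, 13} unchanged.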

// Helper class for reading data from a vector of VariantTensorData objects.
class VariantTensorDataReader : public IteratorStateReader {
 public:
  explicit VariantTensorDataReader(
      const std::vector<const VariantTensorData*>& data);

  Status ReadScalar(StringPiece key, int64* val) const override;
  Status ReadScalar(StringPiece key, tstring* val) const override;
  Status ReadTensor(StringPiece key, Tensor* val) const override;
  bool Contains(StringPiece key) const override;

  Status ReadScalar(StringPiece name, StringPiece key,
                    int64* val) const override;
  Status ReadScalar(StringPiece name, StringPiece key,
                    tstring* val) const override;
  Status ReadTensor(StringPiece name, StringPiece key,
                    Tensor* val) const override;
  bool Contains(StringPiece name, StringPiece key) const override;

 private:
  template <typename T>
  Status ReadScalarInternal(StringPiece key, T* val) const;
  Status ReadTensorInternal(StringPiece key, Tensor* val) const;

  template <typename T>
  Status ReadScalarInternal(StringPiece name, StringPiece key, T* val) const;
  Status ReadTensorInternal(StringPiece name, StringPiece key,
                            Tensor* val) const;

  std::map<string, std::map<string, size_t>> map_;
  std::map<string, const VariantTensorData*> data_;  // Not owned.
};
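
// Sample usage, mirroring the writer sample below. The "buffer_size" key and
// the `full_name` helper are illustrative, as in that sample:
//
//   VariantTensorDataReader reader(data);
//   int64 buffer_size;
//   TF_RETURN_IF_ERROR(
//       reader.ReadScalar(full_name("buffer_size"), &buffer_size));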

// Helper class used to build a list of VariantTensorData objects, one for each
// iterator, where the iterator is determined from the keys supplied to the
// Write* calls.
// Sample usage:
// VariantTensorDataWriter writer;
// writer.WriteScalar(full_name("buffer_size"), buffer_.size());
// writer.WriteScalar(full_name("num_threads"), threadpool_.size());
// ....
// std::vector<std::unique_ptr<VariantTensorData>> variants;
// writer.ReleaseData(&variants);
// Now the VariantTensorData objects can be used to serialize.
class VariantTensorDataWriter : public IteratorStateWriter {
 public:
  Status WriteScalar(StringPiece key, const int64 val) override;
  Status WriteScalar(StringPiece key, const tstring& val) override;
  Status WriteTensor(StringPiece key, const Tensor& val) override;

  Status WriteScalar(StringPiece name, StringPiece key,
                     const int64 val) override;
  Status WriteScalar(StringPiece name, StringPiece key,
                     const tstring& val) override;
  Status WriteTensor(StringPiece name, StringPiece key,
                     const Tensor& val) override;

  // Releases the built VariantTensorData objects to `variants`. Clears out all
  // class state.
  void ReleaseData(std::vector<std::unique_ptr<VariantTensorData>>* variants);

  // Obtains a read-only view of the built VariantTensorData objects.
  void GetData(std::vector<const VariantTensorData*>* variants);

 private:
  void MaybeFlush();
  void Reset();

  template <typename T>
  Status WriteScalarInternal(StringPiece key, const T& val);
  Status WriteTensorInternal(StringPiece key, const Tensor& val);

  template <typename T>
  Status WriteScalarInternal(StringPiece name, StringPiece key, const T& val);
  Status WriteTensorInternal(StringPiece name, StringPiece key,
                             const Tensor& val);

  bool is_flushed_ = false;
  std::map<string, std::unique_ptr<VariantTensorData>> data_;
  std::map<string, std::vector<string>> keys_;
};

// Adds the functions in `to_add` to `base`. If a function with a matching
// signature already exists in `base`, replaces it with the function from
// `to_add`.
Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionLibraryDefinition& to_add);
Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionDefLibrary& to_add);

// Creates a runner that runs functions with limited parallelism.
std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
    std::function<void(std::function<void()>)> runner, int max_parallelism);
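
// A minimal usage sketch, assuming an `IteratorContext* ctx` supplied by the
// caller; the parallelism limit of 4 is illustrative:
//
//   std::function<void(std::function<void()>)> runner =
//       RunnerWithMaxParallelism(*ctx->runner(), /*max_parallelism=*/4);
//   runner([]() { /* work scheduled with at most 4-way parallelism */ });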

// Op for creating a typed dummy resource.
//
// This op is used to provide a resource "placeholder" for ops such as
// `CacheDatasetV2` or `ShuffleDatasetV2` that expect a resource input.
// Originally, the lifetime of the resources passed into these ops was managed
// externally. After the implementation changed to manage the lifetime of the
// resources (including creation) by the ops themselves, the resource input is
// only needed to pass a resource handle through graph rewrites. When these ops
// are invoked from user code, the implementation passes in a dummy resource.
template <typename ResourceType>
class DummyResourceOp : public OpKernel {
 public:
  explicit DummyResourceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    Tensor* tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &tensor));
    tensor->scalar<ResourceHandle>()() = MakeResourceHandle<ResourceType>(
        ctx, /*container=*/"", /*name=*/"dummy_resource");
  }
};
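
// A hedged registration sketch; the op name "DummyMyResource" and the resource
// type `MyResource` are illustrative assumptions, not defined in this header:
//
//   REGISTER_KERNEL_BUILDER(Name("DummyMyResource").Device(DEVICE_CPU),
//                           DummyResourceOp<MyResource>);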

// Given an op prefix and an op to match, returns whether the op to match
// is a match for any version of the op prefix. For example,
// MatchesAnyVersion("BatchDataset", "BatchDataset") == true
// MatchesAnyVersion("BatchDataset", "BatchDatasetV2") == true
// MatchesAnyVersion("BatchDataset", "BatchDatasetV3") == true
// MatchesAnyVersion("PaddedBatchDataset", "BatchDataset") == false
bool MatchesAnyVersion(StringPiece op_prefix, StringPiece op_to_match);

// Based on `job_name`, `optimizations_enabled`, `optimizations_disabled` and
// `optimizations_default`, returns the list of optimizations that will be
// applied.
std::vector<tstring> SelectOptimizations(
    const string& job_name,
    const absl::flat_hash_map<string, uint64>& live_experiments,
    const std::vector<tstring>& optimizations_enabled,
    const std::vector<tstring>& optimizations_disabled,
    const std::vector<tstring>& optimizations_default,
    std::function<uint64(const string&)> hash_func);

// Removes device placements from the ops of all functions in `library`.
void StripDevicePlacement(FunctionDefLibrary* library);

// Copies a partial batch from `value` to `output`.
Status CopyPartialBatch(int64 num_elements, const Tensor& value,
                        Tensor* output);

// Reads a batch when restoring the iterator.
Status ReadBatch(int64 batch_size, const string& iterator_prefix,
                 const string& batch_prefix, IteratorContext* ctx,
                 IteratorStateReader* reader, std::vector<Tensor>* batch);

// Writes a batch when saving the iterator.
Status WriteBatch(int64 batch_size, int64 num_elements,
                  const string& iterator_prefix, const string& batch_prefix,
                  IteratorStateWriter* writer, std::vector<Tensor>* batch);

// Reads a status when restoring the iterator.
Status ReadStatus(const string& iterator_prefix, const string& prefix,
                  IteratorStateReader* reader, Status* status);

// Writes a status when saving the iterator.
Status WriteStatus(const string& iterator_prefix, const string& prefix,
                   const Status& status, IteratorStateWriter* writer);

// Processes a batch for output. If a partial batch is encountered, copies only
// the partial batch to the output.
Status ProcessBatch(int64 batch_size, int64 num_elements, bool drop_remainder,
                    const Status& status, IteratorContext* ctx,
                    std::vector<Tensor>* output, bool* end_of_sequence,
                    std::vector<Tensor>* batch);

// Copies the input elements to a batch.
Status CopyBatch(bool parallel_copy, IteratorContext* ctx,
                 std::vector<Tensor>* out_tensors,
                 std::vector<std::vector<Tensor>>* batch_elements);

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_