• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #ifndef TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_
16 #define TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_
17 
#include <atomic>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
23 
24 namespace tensorflow {
25 namespace data {
26 
27 // Creates a resource handle with a unique name for the given resource.
28 template <typename T>
CreateHandle(OpKernelContext * ctx,T * resource,const string & container_name,ResourceHandle * handle)29 Status CreateHandle(OpKernelContext* ctx, T* resource,
30                     const string& container_name, ResourceHandle* handle) {
31   static std::atomic<int64> resource_id_counter(0);
32   string unique_name =
33       strings::StrCat(container_name, resource_id_counter.fetch_add(1));
34   ResourceMgr* mgr = ctx->resource_manager();
35   TF_RETURN_IF_ERROR(mgr->Create<T>(container_name, unique_name, resource));
36 
37   *handle = MakeResourceHandle(container_name, unique_name, *ctx->device(),
38                                TypeIndex::Make<T>());
39   return Status::OK();
40 }
41 
42 template <typename T>
43 class AnonymousResourceOp : public OpKernel {
44  public:
AnonymousResourceOp(OpKernelConstruction * context)45   explicit AnonymousResourceOp(OpKernelConstruction* context)
46       : OpKernel(context) {}
47 
Compute(OpKernelContext * ctx)48   void Compute(OpKernelContext* ctx) override {
49     FunctionLibraryRuntime* lib;
50     std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
51     std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
52     OP_REQUIRES_OK(
53         ctx, ctx->function_library()->Clone(&flib_def, &pflr, &lib, true));
54     T* resource;
55     OP_REQUIRES_OK(ctx, CreateResource(ctx, std::move(flib_def),
56                                        std::move(pflr), lib, &resource));
57 
58     ResourceHandle handle;
59     OP_REQUIRES_OK(ctx, CreateHandle(ctx, resource, name(), &handle));
60     Tensor* handle_t;
61     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle_t));
62     handle_t->scalar<ResourceHandle>()() = handle;
63 
64     if (create_deleter_) {
65       Tensor* deleter_t;
66       AllocatorAttributes attr;
67       attr.set_on_host(true);
68       OP_REQUIRES_OK(
69           ctx, ctx->allocate_output(1, TensorShape({}), &deleter_t, attr));
70       deleter_t->scalar<Variant>()() =
71           ResourceDeleter(handle, ctx->resource_manager());
72     }
73   }
74 
75  protected:
76   virtual string name() = 0;
77 
78   virtual Status CreateResource(
79       OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
80       std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
81       FunctionLibraryRuntime* lib, T** resource) = 0;
82 
83   bool create_deleter_ = true;
84 };
85 
86 // Returns Status::OK() if `expected` and `received` types match,
87 // errors::InvalidArgument otherwise.
88 Status VerifyTypesMatch(const DataTypeVector& expected,
89                         const DataTypeVector& received);
90 
91 Status VerifyTypesMatch(const DataTypeVector& expected,
92                         const std::vector<Tensor>& received);
93 
94 // Returns Status::OK() if `expected` and `received` shapes are compatible,
95 // errors::InvalidArgument otherwise.
96 Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
97                               const std::vector<PartialTensorShape>& received);
98 
99 Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
100                               const std::vector<Tensor>& received);
101 
102 // Writes dataset elements to the checkpoint writer using the given key prefix.
103 // The elements can be read back by passing the same key prefix to
104 // ReadElementsFromCheckpoint. Only one list of elements can be written under
105 // the same key_prefix.
106 Status WriteElementsToCheckpoint(
107     IteratorStateWriter* writer, StringPiece key_prefix,
108     const std::vector<std::vector<Tensor>>& elements);
109 
110 // Reads dataset elements from the checkpoint reader using the given key prefix.
111 Status ReadElementsFromCheckpoint(IteratorStateReader* reader,
112                                   StringPiece key_prefix,
113                                   std::vector<std::vector<Tensor>>* elements);
114 
115 // Dataset op level determinism policy.
116 class DeterminismPolicy {
117  public:
118   enum class Type : int {
119     // The op must produce elements deterministically.
120     kDeterministic,
121     // The op may relax determinism to improve performance.
122     kNondeterministic,
123     // The determinism policy is not specified at the op level. In this case we
124     // use the experimental_deterministic dataset option to determine the
125     // determinism policy.
126     kDefault,
127   };
128   static constexpr const char* const kDeterministic = "true";
129   static constexpr const char* const kNondeterministic = "false";
130   static constexpr const char* const kDefault = "default";
131 
DeterminismPolicy()132   DeterminismPolicy() : determinism_(Type::kDefault) {}
DeterminismPolicy(Type determinism)133   explicit DeterminismPolicy(Type determinism) : determinism_(determinism) {}
134   // Creates a DeterminismPolicy with Type kDeterministic or
135   // kNondeterministic, depending on the values of `is_deterministic`.
136   explicit DeterminismPolicy(bool is_deterministic);
137 
138   static Status FromString(const std::string& s, DeterminismPolicy* out);
139 
140   // Returns the string representing the determinism policy. This will be one of
141   // the string constants defined above.
142   std::string String() const;
143 
144   /// Convenience methods for checking the DeterminismPolicy::Type.
IsDeterministic()145   bool IsDeterministic() const { return determinism_ == Type::kDeterministic; }
IsNondeterministic()146   bool IsNondeterministic() const {
147     return determinism_ == Type::kNondeterministic;
148   }
IsDefault()149   bool IsDefault() const { return determinism_ == Type::kDefault; }
150 
151  private:
152   Type determinism_;
153 };
154 
155 // Resolves non-deterministic seeds if necessary, returning either the original
156 // seeds or the resolved seeds.
157 //
158 // By TensorFlow convention, if both seeds are 0, they should be replaced with
159 // non-deterministically chosen seeds.
160 std::pair<int64, int64> MaybeOverrideSeeds(std::pair<int64, int64> seeds);
161 
162 // Helper class for reading data from a vector of VariantTensorData objects.
163 class VariantTensorDataReader : public IteratorStateReader {
164  public:
165   explicit VariantTensorDataReader(
166       const std::vector<const VariantTensorData*>& data);
167 
168   Status ReadScalar(StringPiece key, int64* val) const override;
169   Status ReadScalar(StringPiece key, tstring* val) const override;
170   Status ReadTensor(StringPiece key, Tensor* val) const override;
171   bool Contains(StringPiece key) const override;
172 
173   Status ReadScalar(StringPiece name, StringPiece key,
174                     int64* val) const override;
175   Status ReadScalar(StringPiece name, StringPiece key,
176                     tstring* val) const override;
177   Status ReadTensor(StringPiece name, StringPiece key,
178                     Tensor* val) const override;
179   bool Contains(StringPiece name, StringPiece key) const override;
180 
181  private:
182   template <typename T>
183   Status ReadScalarInternal(StringPiece key, T* val) const;
184   Status ReadTensorInternal(StringPiece key, Tensor* val) const;
185 
186   template <typename T>
187   Status ReadScalarInternal(StringPiece name, StringPiece key, T* val) const;
188   Status ReadTensorInternal(StringPiece name, StringPiece key,
189                             Tensor* val) const;
190 
191   std::map<string, std::map<string, size_t>> map_;
192   std::map<string, const VariantTensorData*> data_;  // Not owned.
193 };
194 
195 // Helper class used to build a list of VariantTensorData objects, one for each
196 // iterator which is determined from the key supplied from the Write* calls.
197 // Sample usage:
198 // VariantTensorDataWriter writer;
199 // writer.WriteScalar(full_name("buffer_size"), buffer_.size());
200 // writer.WriteScalar(full_name("num_threads"), threadpool_.size());
201 // ....
202 // std::vector<std::unique_ptr<VariantTensorData>> variants;
203 // writer.ReleaseData(&variants);
204 // Now the VariantTensorData objects can be used to serialize.
205 class VariantTensorDataWriter : public IteratorStateWriter {
206  public:
207   Status WriteScalar(StringPiece key, const int64 val) override;
208   Status WriteScalar(StringPiece key, const tstring& val) override;
209   Status WriteTensor(StringPiece key, const Tensor& val) override;
210 
211   Status WriteScalar(StringPiece name, StringPiece key,
212                      const int64 val) override;
213   Status WriteScalar(StringPiece name, StringPiece key,
214                      const tstring& val) override;
215   Status WriteTensor(StringPiece name, StringPiece key,
216                      const Tensor& val) override;
217 
218   // Releases the built VariantTensorData's to `variants`. Clears out all
219   // class state.
220   void ReleaseData(std::vector<std::unique_ptr<VariantTensorData>>* variants);
221 
222   // Obtains a read-only version of the VariantTensorData's built.
223   void GetData(std::vector<const VariantTensorData*>* variants);
224 
225  private:
226   void MaybeFlush();
227   void Reset();
228 
229   template <typename T>
230   Status WriteScalarInternal(StringPiece key, const T& val);
231   Status WriteTensorInternal(StringPiece key, const Tensor& val);
232 
233   template <typename T>
234   Status WriteScalarInternal(StringPiece name, StringPiece key, const T& val);
235   Status WriteTensorInternal(StringPiece name, StringPiece key,
236                              const Tensor& val);
237 
238   bool is_flushed_ = false;
239   std::map<string, std::unique_ptr<VariantTensorData>> data_;
240   std::map<string, std::vector<string>> keys_;
241 };
242 
243 // Adds the functions in `to_add` to `base`. If a function with a matching
244 // signature already exists in `base`, replaces it with the function from
245 // `to_add`.
246 Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
247                             const FunctionLibraryDefinition& to_add);
248 Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
249                             const FunctionDefLibrary& to_add);
250 
// Creates a runner that runs functions with limited parallelism.
//
// The returned runner schedules work through `runner`. It presumably keeps at
// most `max_parallelism` functions running at once; confirm against the
// implementation.
std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
    std::function<void(std::function<void()>)> runner, int max_parallelism);
254 
255 // Op for creating a typed dummy resource.
256 //
257 // This op is used to provide a resource "placeholder" for ops such as
258 // `CacheDatasetV2` or `ShuffleDatasetV2` that expects a resource input.
259 // Originally, the lifetime of the resources passed into these ops was managed
260 // externally. After the implementation changed to manage the lifetime of the
261 // resources (including creation) by the ops themselves, the resource input is
262 // only needed to pass a resource handle through graph rewrites. When they are
263 // invoked from user code, the implementation passes in a dummy resource.
264 template <typename ResourceType>
265 class DummyResourceOp : public OpKernel {
266  public:
DummyResourceOp(OpKernelConstruction * ctx)267   explicit DummyResourceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
268 
Compute(OpKernelContext * ctx)269   void Compute(OpKernelContext* ctx) override {
270     Tensor* tensor;
271     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &tensor));
272     tensor->scalar<ResourceHandle>()() = MakeResourceHandle<ResourceType>(
273         ctx, /*container=*/"", /*name=*/"dummy_resource");
274   }
275 };
276 
277 // Given an op prefix and an op to match, returns whether the op to match
278 // is a match for any version of the op prefix. For example,
279 // MatchesAnyVersion("BatchDataset", "BatchDataset") == true
280 // MatchesAnyVersion("BatchDataset", "BatchDatasetV2") == true
281 // MatchesAnyVersion("BatchDataset", "BatchDatasetV3") == true
282 // MatchesAnyVersion("PaddedBatchDataset", "BatchDataset") == false
283 bool MatchesAnyVersion(StringPiece op_prefix, StringPiece op_to_match);
284 
285 // Based on `job_name`, `optimizations_enabled`, `optimizations_disabled` and
286 // `optimizations_default`, returns the list of optimizations that will be
287 // applied.
288 std::vector<tstring> SelectOptimizations(
289     const string& job_name,
290     const absl::flat_hash_map<string, uint64>& live_experiments,
291     const std::vector<tstring>& optimizations_enabled,
292     const std::vector<tstring>& optimizations_disabled,
293     const std::vector<tstring>& optimizations_default,
294     std::function<uint64(const string&)> hash_func);
295 
296 // Removes device placements from the ops of all functions in `library`.
297 void StripDevicePlacement(FunctionDefLibrary* library);
298 
299 // Copies partial of the batch output.
300 Status CopyPartialBatch(int64 num_elements, const Tensor& value,
301                         Tensor* output);
302 
303 // Reads a batch when restoring the iterator.
304 Status ReadBatch(int64 batch_size, const string& iterator_prefix,
305                  const string& batch_prefix, IteratorContext* ctx,
306                  IteratorStateReader* reader, std::vector<Tensor>* batch);
307 
308 // Writes a batch when saving the iterator.
309 Status WriteBatch(int64 batch_size, int64 num_elements,
310                   const string& iterator_prefix, const string& batch_prefix,
311                   IteratorStateWriter* writer, std::vector<Tensor>* batch);
312 
313 // Reads a status when restoring the iterator.
314 Status ReadStatus(const string& iterator_prefix, const string& prefix,
315                   IteratorStateReader* reader, Status* status);
316 
317 // Writes a status when saving the iterator.
318 Status WriteStatus(const string& iterator_prefix, const string& prefix,
319                    const Status& status, IteratorStateWriter* writer);
320 
321 // Processes a batch to output. In the case a partial batch is encountered, copy
322 // only partial of the batch.
323 Status ProcessBatch(int64 batch_size, int64 num_elements, bool drop_remainder,
324                     const Status& status, IteratorContext* ctx,
325                     std::vector<Tensor>* output, bool* end_of_sequence,
326                     std::vector<Tensor>* batch);
327 
328 // Copies the input elements to a batch.
329 Status CopyBatch(bool parallel_copy, IteratorContext* ctx,
330                  std::vector<Tensor>* out_tensors,
331                  std::vector<std::vector<Tensor>>* batch_elements);
332 
333 }  // namespace data
334 }  // namespace tensorflow
335 
336 #endif  // TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_
337