• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DATA_DATASET_UTILS_H_
16 #define TENSORFLOW_CORE_DATA_DATASET_UTILS_H_
17 
18 #include <functional>
19 #include <string>
20 
21 #include "absl/container/flat_hash_set.h"
22 #include "tensorflow/core/common_runtime/function.h"
23 #include "tensorflow/core/framework/dataset.h"
24 #include "tensorflow/core/framework/function.h"
25 #include "tensorflow/core/framework/resource_mgr.h"
26 #include "tensorflow/core/framework/tensor.h"
27 
28 namespace tensorflow {
29 namespace data {
30 
// Constant used for indicating that the argument of tf.data.Dataset.shard
// should be supplied by the auto-sharding rewrite.
constexpr int kShardHint = -1;
34 
35 // Creates a resource handle with a unique name for the given resource.
36 template <typename T>
CreateHandle(OpKernelContext * ctx,T * resource,const string & container_name,ResourceHandle * handle)37 Status CreateHandle(OpKernelContext* ctx, T* resource,
38                     const string& container_name, ResourceHandle* handle) {
39   static std::atomic<int64> resource_id_counter(0);
40   string unique_name =
41       strings::StrCat(container_name, resource_id_counter.fetch_add(1));
42   ResourceMgr* mgr = ctx->resource_manager();
43   TF_RETURN_IF_ERROR(mgr->Create<T>(container_name, unique_name, resource));
44 
45   *handle = MakeResourceHandle(container_name, unique_name, *ctx->device(),
46                                TypeIndex::Make<T>());
47   return Status::OK();
48 }
49 
50 template <typename T>
51 class AnonymousResourceOp : public OpKernel {
52  public:
AnonymousResourceOp(OpKernelConstruction * context)53   explicit AnonymousResourceOp(OpKernelConstruction* context)
54       : OpKernel(context) {}
55 
Compute(OpKernelContext * ctx)56   void Compute(OpKernelContext* ctx) override {
57     FunctionLibraryRuntime* lib;
58     std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
59     std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
60     OP_REQUIRES_OK(
61         ctx, ctx->function_library()->Clone(&flib_def, &pflr, &lib, true));
62     T* resource;
63     OP_REQUIRES_OK(ctx, CreateResource(ctx, std::move(flib_def),
64                                        std::move(pflr), lib, &resource));
65 
66     ResourceHandle handle;
67     OP_REQUIRES_OK(ctx, CreateHandle(ctx, resource, name(), &handle));
68     Tensor* handle_t;
69     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle_t));
70     handle_t->scalar<ResourceHandle>()() = handle;
71 
72     if (create_deleter_) {
73       Tensor* deleter_t;
74       AllocatorAttributes attr;
75       attr.set_on_host(true);
76       OP_REQUIRES_OK(
77           ctx, ctx->allocate_output(1, TensorShape({}), &deleter_t, attr));
78       deleter_t->scalar<Variant>()() =
79           ResourceDeleter(handle, ctx->resource_manager());
80     }
81   }
82 
83  protected:
84   virtual string name() = 0;
85 
86   virtual Status CreateResource(
87       OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
88       std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
89       FunctionLibraryRuntime* lib, T** resource) = 0;
90 
91   bool create_deleter_ = true;
92 };
93 
// Returns Status::OK() if `expected` and `received` types match,
// errors::InvalidArgument otherwise.
Status VerifyTypesMatch(const DataTypeVector& expected,
                        const DataTypeVector& received);

// Overload that compares `expected` against the dtypes of the `received`
// tensors.
Status VerifyTypesMatch(const DataTypeVector& expected,
                        const std::vector<Tensor>& received);

// Returns Status::OK() if `expected` and `received` shapes are compatible,
// errors::InvalidArgument otherwise.
Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<PartialTensorShape>& received);

// Overload that compares `expected` against the shapes of the `received`
// tensors.
Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<Tensor>& received);
109 
// Dataset op level determinism policy.
class DeterminismPolicy {
 public:
  enum class Type : int {
    // The op must produce elements deterministically.
    kDeterministic,
    // The op may relax determinism to improve performance.
    kNondeterministic,
    // The determinism policy is not specified at the op level. In this case we
    // use the experimental_deterministic dataset option to determine the
    // determinism policy.
    kDefault,
  };
  // String representations of the policy types; `String()` returns one of
  // these values.
  static constexpr const char* const kDeterministic = "true";
  static constexpr const char* const kNondeterministic = "false";
  static constexpr const char* const kDefault = "default";

  // Defaults to the kDefault policy.
  DeterminismPolicy() : determinism_(Type::kDefault) {}
  explicit DeterminismPolicy(Type determinism) : determinism_(determinism) {}
  // Creates a DeterminismPolicy with Type kDeterministic or
  // kNondeterministic, depending on the values of `is_deterministic`.
  explicit DeterminismPolicy(bool is_deterministic);

  // Parses `s` into `*out`; returns an error status on failure.
  static Status FromString(const std::string& s, DeterminismPolicy* out);

  // Returns the string representing the determinism policy. This will be one of
  // the string constants defined above.
  std::string String() const;

  /// Convenience methods for checking the DeterminismPolicy::Type.
  bool IsDeterministic() const { return determinism_ == Type::kDeterministic; }
  bool IsNondeterministic() const {
    return determinism_ == Type::kNondeterministic;
  }
  bool IsDefault() const { return determinism_ == Type::kDefault; }

 private:
  Type determinism_;
};
149 
// Resolves non-deterministic seeds if necessary, returning either the original
// seeds or the resolved seeds.
//
// By TensorFlow convention, if both seeds are 0, they should be replaced with
// non-deterministically chosen seeds.
std::pair<int64, int64> MaybeOverrideSeeds(std::pair<int64, int64> seeds);

// Adds the functions in `to_add` to `base`. If a function with a matching
// signature already exists in `base`, replaces it with the function from
// `to_add`.
Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionLibraryDefinition& to_add);
Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionDefLibrary& to_add);

// Determines whether the given function is stateful.
Status IsFunctionStateful(const FunctionLibraryDefinition& library,
                          const FunctionDef& function_def);

// Determines whether the given node is stateful.
Status IsNodeStateful(const FunctionLibraryDefinition& library,
                      const NodeDef& node);

// Creates a runner that runs functions with limited parallelism.
std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
    std::function<void(std::function<void()>)> runner, int max_parallelism);
176 
177 // Op for creating a typed dummy resource.
178 //
179 // This op is used to provide a resource "placeholder" for ops such as
180 // `CacheDatasetV2` or `ShuffleDatasetV2` that expects a resource input.
181 // Originally, the lifetime of the resources passed into these ops was managed
182 // externally. After the implementation changed to manage the lifetime of the
183 // resources (including creation) by the ops themselves, the resource input is
184 // only needed to pass a resource handle through graph rewrites. When they are
185 // invoked from user code, the implementation passes in a dummy resource.
186 template <typename ResourceType>
187 class DummyResourceOp : public OpKernel {
188  public:
DummyResourceOp(OpKernelConstruction * ctx)189   explicit DummyResourceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
190 
Compute(OpKernelContext * ctx)191   void Compute(OpKernelContext* ctx) override {
192     Tensor* tensor;
193     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &tensor));
194     tensor->scalar<ResourceHandle>()() = MakeResourceHandle<ResourceType>(
195         ctx, /*container=*/"", /*name=*/"dummy_resource");
196   }
197 };
198 
// Given an op prefix and an op to match, returns whether the op to match
// is a match for any version of the op prefix. For example,
// MatchesAnyVersion("BatchDataset", "BatchDataset") == true
// MatchesAnyVersion("BatchDataset", "BatchDatasetV2") == true
// MatchesAnyVersion("BatchDataset", "BatchDatasetV3") == true
// MatchesAnyVersion("PaddedBatchDataset", "BatchDataset") == false
bool MatchesAnyVersion(StringPiece op_prefix, StringPiece op_to_match);

// Removes device placements from the ops of all functions in `library`.
void StripDevicePlacement(FunctionDefLibrary* library);
209 
// Copies the first `num_elements` rows of `value` into `output` when only a
// partial batch was produced.
Status CopyPartialBatch(int64_t num_elements, const Tensor& value,
                        Tensor* output);

// Reads a batch when restoring the iterator.
Status ReadBatch(IteratorContext* ctx, IteratorStateReader* reader,
                 int64_t batch_size, const string& iterator_prefix,
                 const string& batch_prefix, std::vector<Tensor>* batch);

// Writes a batch when saving the iterator.
Status WriteBatch(int64_t batch_size, int64_t num_elements,
                  const string& iterator_prefix, const string& batch_prefix,
                  IteratorStateWriter* writer, std::vector<Tensor>* batch);

// Reads a status when restoring the iterator.
Status ReadStatus(const string& iterator_prefix, const string& prefix,
                  IteratorStateReader* reader, Status* status);

// Writes a status when saving the iterator.
Status WriteStatus(const string& iterator_prefix, const string& prefix,
                   const Status& status, IteratorStateWriter* writer);

// Processes a batch to output. In the case a partial batch is encountered,
// copies only that part of the batch.
Status ProcessBatch(int64_t batch_size, int64_t num_elements,
                    bool drop_remainder, const Status& status,
                    IteratorContext* ctx, std::vector<Tensor>* output,
                    bool* end_of_sequence, std::vector<Tensor>* batch);

// Copies the input elements to a batch.
//
// The `batch_elements` argument contains the individual elements to copy into a
// batch. The `parallel_copy` argument indicates whether to parallelize the
// copy. The `allocation_callback` argument can be used to pass a callback to
// invoke upon successful allocation of the memory for the batch. The
// `out_tensors` argument will be used to store the resulting batch (one for
// each component of the input).
Status CopyBatch(IteratorContext* ctx,
                 const std::vector<std::vector<Tensor>>& batch_elements,
                 bool parallel_copy,
                 std::function<Status()> allocation_callback,
                 std::vector<Tensor>* out_tensors);
252 
// Computes the set of experiments to apply based on the job name, rollout
// percentage of registered experiments, and the TF_DATA_EXPERIMENT_OPT_IN and
// TF_DATA_EXPERIMENT_OPT_OUT environment variables.
absl::flat_hash_set<string> GetExperiments();
// Overload allowing the job name and the hash function used for rollout
// selection to be supplied explicitly (e.g. for testing).
absl::flat_hash_set<string> GetExperiments(
    const string& job_name, std::function<uint64(const string&)> hash_func);

// Logs and records the experiments that will be applied.
void LogAndRecordExperiments(const absl::flat_hash_set<string>& experiments);

// Computes the set of enabled, disabled, and default optimizations based on the
// given options.
void GetOptimizations(const Options& options,
                      absl::flat_hash_set<tstring>* optimizations_enabled,
                      absl::flat_hash_set<tstring>* optimizations_disabled,
                      absl::flat_hash_set<tstring>* optimizations_default);

// Determines which optimizations should be applied.
absl::flat_hash_set<tstring> SelectOptimizations(
    const absl::flat_hash_set<string>& experiments,
    const absl::flat_hash_set<tstring>& optimizations_enabled,
    const absl::flat_hash_set<tstring>& optimizations_disabled,
    const absl::flat_hash_set<tstring>& optimizations_default);

// Creates graph rewrite configs based on the given options.
absl::flat_hash_set<tstring> CreateGraphRewriteConfigs(const Options& options);

// Determines whether max intra-op parallelism should be configured.
bool ShouldConfigureMaxIntraOpParallelism(const Options& options);

// Determines whether private threadpool should be used.
bool ShouldUsePrivateThreadPool(const Options& options);

// Determines whether autotuning should be used.
bool ShouldUseAutotuning(const Options& options);

// Determines whether optimizations should be applied.
bool ShouldApplyOptimizations(
    const Options& options,
    const absl::flat_hash_set<tstring>& optimizations_enabled,
    const absl::flat_hash_set<tstring>& optimizations_default);
294 
// Registry of tf.data experiments.
class DatasetExperimentRegistry {
 public:
  // Registers the experiment with its rollout percentage.
  static void Register(const string& experiment, int64_t rollout_pct);

  // Returns all registered experiments, mapped to their rollout percentages.
  static absl::flat_hash_map<string, int64> Experiments();
};
304 
// Helper class to register a dataset experiment. Constructing an instance
// (typically as a file-scope static) registers the experiment with
// `DatasetExperimentRegistry` at static-initialization time.
class DatasetExperimentRegistrar {
 public:
  explicit DatasetExperimentRegistrar(const string& experiment,
                                      int64_t rollout_pct) {
    DatasetExperimentRegistry::Register(experiment, rollout_pct);
  }
};
313 
// Macro that can be used to register a dataset experiment.
#define REGISTER_DATASET_EXPERIMENT(experiment, rollout_pct) \
  REGISTER_DATASET_OP_NAME_UNIQ_HELPER(__COUNTER__, experiment, rollout_pct)

// Extra level of indirection so that `__COUNTER__` is expanded to its value
// before being pasted into the registrar object's name.
#define REGISTER_DATASET_OP_NAME_UNIQ_HELPER(ctr, experiment, rollout_pct) \
  REGISTER_DATASET_OP_NAME_UNIQ(ctr, experiment, rollout_pct)

// Defines a uniquely named file-scope registrar whose constructor registers
// the experiment.
#define REGISTER_DATASET_OP_NAME_UNIQ(ctr, experiment, rollout_pct) \
  static ::tensorflow::data::DatasetExperimentRegistrar             \
      registrar__body__##ctr##__object(experiment, rollout_pct)
324 
325 }  // namespace data
326 }  // namespace tensorflow
327 
328 #endif  // TENSORFLOW_CORE_DATA_DATASET_UTILS_H_
329