1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DATA_DATASET_UTILS_H_
16 #define TENSORFLOW_CORE_DATA_DATASET_UTILS_H_
17
18 #include <functional>
19 #include <string>
20
21 #include "absl/container/flat_hash_set.h"
22 #include "tensorflow/core/common_runtime/function.h"
23 #include "tensorflow/core/framework/dataset.h"
24 #include "tensorflow/core/framework/function.h"
25 #include "tensorflow/core/framework/resource_mgr.h"
26 #include "tensorflow/core/framework/tensor.h"
27
28 namespace tensorflow {
29 namespace data {
30
// Constant used for indicating that the argument of tf.data.Dataset.shard
// should be supplied by the auto-sharding rewrite.
// -1 is a sentinel: it can never be a valid (non-negative) shard index.
constexpr int kShardHint = -1;
34
35 // Creates a resource handle with a unique name for the given resource.
36 template <typename T>
CreateHandle(OpKernelContext * ctx,T * resource,const string & container_name,ResourceHandle * handle)37 Status CreateHandle(OpKernelContext* ctx, T* resource,
38 const string& container_name, ResourceHandle* handle) {
39 static std::atomic<int64> resource_id_counter(0);
40 string unique_name =
41 strings::StrCat(container_name, resource_id_counter.fetch_add(1));
42 ResourceMgr* mgr = ctx->resource_manager();
43 TF_RETURN_IF_ERROR(mgr->Create<T>(container_name, unique_name, resource));
44
45 *handle = MakeResourceHandle(container_name, unique_name, *ctx->device(),
46 TypeIndex::Make<T>());
47 return Status::OK();
48 }
49
// Base kernel for ops that create an "anonymous" resource, i.e. a resource
// registered under a generated unique name rather than a user-supplied one.
// Subclasses provide the resource construction logic via `CreateResource()`
// and the name prefix via `name()`.
template <typename T>
class AnonymousResourceOp : public OpKernel {
 public:
  explicit AnonymousResourceOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    FunctionLibraryRuntime* lib;
    std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
    // Clone the kernel's function library state so the resource gets its own
    // copies of the definition and runtime.
    OP_REQUIRES_OK(
        ctx, ctx->function_library()->Clone(&flib_def, &pflr, &lib, true));
    T* resource;
    // Delegate construction of the concrete resource to the subclass; the
    // cloned library objects are handed over by move.
    OP_REQUIRES_OK(ctx, CreateResource(ctx, std::move(flib_def),
                                       std::move(pflr), lib, &resource));

    // Register the resource under a unique name and emit its handle as
    // scalar output 0.
    ResourceHandle handle;
    OP_REQUIRES_OK(ctx, CreateHandle(ctx, resource, name(), &handle));
    Tensor* handle_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle_t));
    handle_t->scalar<ResourceHandle>()() = handle;

    if (create_deleter_) {
      // Output 1 holds a variant-wrapped ResourceDeleter tied to the handle
      // (presumably cleans up the resource when destroyed — see
      // ResourceDeleter in resource_mgr.h).
      Tensor* deleter_t;
      AllocatorAttributes attr;
      attr.set_on_host(true);  // the variant must be host-memory resident
      OP_REQUIRES_OK(
          ctx, ctx->allocate_output(1, TensorShape({}), &deleter_t, attr));
      deleter_t->scalar<Variant>()() =
          ResourceDeleter(handle, ctx->resource_manager());
    }
  }

 protected:
  // Returns the container/name prefix under which the resource is registered.
  virtual string name() = 0;

  // Creates the resource. Implementations take ownership of `flib_def` and
  // `pflr`, and return the newly created resource via `*resource`.
  virtual Status CreateResource(
      OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
      std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
      FunctionLibraryRuntime* lib, T** resource) = 0;

  // When true (the default), the kernel produces a second output carrying a
  // ResourceDeleter for the created resource.
  bool create_deleter_ = true;
};
93
// Returns Status::OK() if `expected` and `received` types match,
// errors::InvalidArgument otherwise.
Status VerifyTypesMatch(const DataTypeVector& expected,
                        const DataTypeVector& received);

// Overload that checks the dtypes of the given tensors against `expected`.
Status VerifyTypesMatch(const DataTypeVector& expected,
                        const std::vector<Tensor>& received);

// Returns Status::OK() if `expected` and `received` shapes are compatible,
// errors::InvalidArgument otherwise.
Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<PartialTensorShape>& received);

// Overload that checks the shapes of the given tensors against `expected`.
Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                              const std::vector<Tensor>& received);
109
// Dataset op level determinism policy.
//
// Wraps a `Type` value together with its canonical string representation
// ("true" / "false" / "default"), used when (de)serializing the policy.
class DeterminismPolicy {
 public:
  enum class Type : int {
    // The op must produce elements deterministically.
    kDeterministic,
    // The op may relax determinism to improve performance.
    kNondeterministic,
    // The determinism policy is not specified at the op level. In this case we
    // use the experimental_deterministic dataset option to determine the
    // determinism policy.
    kDefault,
  };
  // Canonical string forms of the three policy types (see String()).
  static constexpr const char* const kDeterministic = "true";
  static constexpr const char* const kNondeterministic = "false";
  static constexpr const char* const kDefault = "default";

  // Default-constructs to kDefault, i.e. "not specified at the op level".
  DeterminismPolicy() : determinism_(Type::kDefault) {}
  explicit DeterminismPolicy(Type determinism) : determinism_(determinism) {}
  // Creates a DeterminismPolicy with Type kDeterministic or
  // kNondeterministic, depending on the values of `is_deterministic`.
  explicit DeterminismPolicy(bool is_deterministic);

  // Parses `s` (one of the string constants above) into `*out`.
  static Status FromString(const std::string& s, DeterminismPolicy* out);

  // Returns the string representing the determinism policy. This will be one
  // of the string constants defined above.
  std::string String() const;

  /// Convenience methods for checking the DeterminismPolicy::Type.
  bool IsDeterministic() const { return determinism_ == Type::kDeterministic; }
  bool IsNondeterministic() const {
    return determinism_ == Type::kNondeterministic;
  }
  bool IsDefault() const { return determinism_ == Type::kDefault; }

 private:
  Type determinism_;
};
149
// Resolves non-deterministic seeds if necessary, returning either the original
// seeds or the resolved seeds.
//
// By TensorFlow convention, if both seeds are 0, they should be replaced with
// non-deterministically chosen seeds.
std::pair<int64, int64> MaybeOverrideSeeds(std::pair<int64, int64> seeds);

// Adds the functions in `to_add` to `base`. If a function with a matching
// signature already exists in `base`, replaces it with the function from
// `to_add`.
Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionLibraryDefinition& to_add);
Status AddToFunctionLibrary(FunctionLibraryDefinition* base,
                            const FunctionDefLibrary& to_add);

// Determines whether the given function is stateful.
// NOTE(review): returns a Status rather than a bool — presumably OK means
// "stateless" and a non-OK status identifies the stateful op; confirm against
// the implementation in dataset_utils.cc.
Status IsFunctionStateful(const FunctionLibraryDefinition& library,
                          const FunctionDef& function_def);

// Determines whether the given node is stateful.
// NOTE(review): same Status-as-answer convention as IsFunctionStateful.
Status IsNodeStateful(const FunctionLibraryDefinition& library,
                      const NodeDef& node);

// Creates a runner that runs functions with limited parallelism, wrapping the
// given `runner` so that at most `max_parallelism` functions run concurrently.
std::function<void(std::function<void()>)> RunnerWithMaxParallelism(
    std::function<void(std::function<void()>)> runner, int max_parallelism);
176
177 // Op for creating a typed dummy resource.
178 //
179 // This op is used to provide a resource "placeholder" for ops such as
180 // `CacheDatasetV2` or `ShuffleDatasetV2` that expects a resource input.
181 // Originally, the lifetime of the resources passed into these ops was managed
182 // externally. After the implementation changed to manage the lifetime of the
183 // resources (including creation) by the ops themselves, the resource input is
184 // only needed to pass a resource handle through graph rewrites. When they are
185 // invoked from user code, the implementation passes in a dummy resource.
186 template <typename ResourceType>
187 class DummyResourceOp : public OpKernel {
188 public:
DummyResourceOp(OpKernelConstruction * ctx)189 explicit DummyResourceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
190
Compute(OpKernelContext * ctx)191 void Compute(OpKernelContext* ctx) override {
192 Tensor* tensor;
193 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &tensor));
194 tensor->scalar<ResourceHandle>()() = MakeResourceHandle<ResourceType>(
195 ctx, /*container=*/"", /*name=*/"dummy_resource");
196 }
197 };
198
// Given an op prefix and an op to match, returns whether the op to match
// is a match for any version of the op prefix. For example,
// MatchesAnyVersion("BatchDataset", "BatchDataset") == true
// MatchesAnyVersion("BatchDataset", "BatchDatasetV2") == true
// MatchesAnyVersion("BatchDataset", "BatchDatasetV3") == true
// MatchesAnyVersion("PaddedBatchDataset", "BatchDataset") == false
bool MatchesAnyVersion(StringPiece op_prefix, StringPiece op_to_match);

// Removes device placements from the ops of all functions in `library`.
void StripDevicePlacement(FunctionDefLibrary* library);

// Copies the first `num_elements` elements of the batch tensor `value` into
// `output` (used when only part of a batch was produced).
Status CopyPartialBatch(int64_t num_elements, const Tensor& value,
                        Tensor* output);

// Reads a batch of `batch_size` elements into `*batch` when restoring the
// iterator identified by `iterator_prefix`/`batch_prefix`.
Status ReadBatch(IteratorContext* ctx, IteratorStateReader* reader,
                 int64_t batch_size, const string& iterator_prefix,
                 const string& batch_prefix, std::vector<Tensor>* batch);

// Writes a batch when saving the iterator. `num_elements` may be smaller than
// `batch_size` for a partial (final) batch.
Status WriteBatch(int64_t batch_size, int64_t num_elements,
                  const string& iterator_prefix, const string& batch_prefix,
                  IteratorStateWriter* writer, std::vector<Tensor>* batch);

// Reads a previously saved Status into `*status` when restoring the iterator.
Status ReadStatus(const string& iterator_prefix, const string& prefix,
                  IteratorStateReader* reader, Status* status);

// Writes `status` when saving the iterator, keyed by
// `iterator_prefix`/`prefix`.
Status WriteStatus(const string& iterator_prefix, const string& prefix,
                   const Status& status, IteratorStateWriter* writer);

// Processes a batch into `*output`. If a partial batch is encountered
// (`num_elements` < `batch_size`), only the filled portion of the batch is
// copied; `drop_remainder` and `status` control how such a batch is handled.
Status ProcessBatch(int64_t batch_size, int64_t num_elements,
                    bool drop_remainder, const Status& status,
                    IteratorContext* ctx, std::vector<Tensor>* output,
                    bool* end_of_sequence, std::vector<Tensor>* batch);

// Copies the input elements to a batch.
//
// The `batch_elements` argument contains the individual elements to copy into
// a batch. The `parallel_copy` argument indicates whether to parallelize the
// copy. The `allocation_callback` argument can be used to pass a callback to
// invoke upon successful allocation of the memory for the batch. The
// `out_tensors` argument will be used to store the resulting batch (one tensor
// for each component of the input).
Status CopyBatch(IteratorContext* ctx,
                 const std::vector<std::vector<Tensor>>& batch_elements,
                 bool parallel_copy,
                 std::function<Status()> allocation_callback,
                 std::vector<Tensor>* out_tensors);

// Computes the set of experiments to apply based on the job name, rollout
// percentage of registered experiments, and the TF_DATA_EXPERIMENT_OPT_IN and
// TF_DATA_EXPERIMENT_OPT_OUT environment variables.
absl::flat_hash_set<string> GetExperiments();
// Overload allowing the caller to supply the job name and the hash function
// used to bucket jobs into rollout percentages (useful for testing).
absl::flat_hash_set<string> GetExperiments(
    const string& job_name, std::function<uint64(const string&)> hash_func);

// Logs and records (e.g. for metrics) the experiments that will be applied.
void LogAndRecordExperiments(const absl::flat_hash_set<string>& experiments);

// Computes the set of enabled, disabled, and default optimizations based on
// the given options. "Default" optimizations are those the options neither
// explicitly enable nor disable.
void GetOptimizations(const Options& options,
                      absl::flat_hash_set<tstring>* optimizations_enabled,
                      absl::flat_hash_set<tstring>* optimizations_disabled,
                      absl::flat_hash_set<tstring>* optimizations_default);

// Determines which optimizations should be applied, combining the active
// experiments with the enabled/disabled/default sets from GetOptimizations().
absl::flat_hash_set<tstring> SelectOptimizations(
    const absl::flat_hash_set<string>& experiments,
    const absl::flat_hash_set<tstring>& optimizations_enabled,
    const absl::flat_hash_set<tstring>& optimizations_disabled,
    const absl::flat_hash_set<tstring>& optimizations_default);

// Creates graph rewrite configs based on the given options.
absl::flat_hash_set<tstring> CreateGraphRewriteConfigs(const Options& options);

// Determines whether max intra-op parallelism should be configured.
bool ShouldConfigureMaxIntraOpParallelism(const Options& options);

// Determines whether a private threadpool should be used.
bool ShouldUsePrivateThreadPool(const Options& options);

// Determines whether autotuning should be used.
bool ShouldUseAutotuning(const Options& options);

// Determines whether optimizations should be applied.
bool ShouldApplyOptimizations(
    const Options& options,
    const absl::flat_hash_set<tstring>& optimizations_enabled,
    const absl::flat_hash_set<tstring>& optimizations_default);
294
// Registry of tf.data experiments.
class DatasetExperimentRegistry {
 public:
  // Registers the experiment with the given rollout percentage (the
  // percentage of jobs the experiment is rolled out to — see GetExperiments).
  static void Register(const string& experiment, int64_t rollout_pct);

  // Returns all registered experiments, mapping experiment name to its
  // rollout percentage.
  static absl::flat_hash_map<string, int64> Experiments();
};
304
// Helper class to register a dataset experiment.
//
// Constructing a static instance performs the registration at program startup;
// this is what the REGISTER_DATASET_EXPERIMENT macro in this header expands
// to.
class DatasetExperimentRegistrar {
 public:
  explicit DatasetExperimentRegistrar(const string& experiment,
                                      int64_t rollout_pct) {
    DatasetExperimentRegistry::Register(experiment, rollout_pct);
  }
};
313
// Macro that can be used to register a dataset experiment.
#define REGISTER_DATASET_EXPERIMENT(experiment, rollout_pct) \
  REGISTER_DATASET_OP_NAME_UNIQ_HELPER(__COUNTER__, experiment, rollout_pct)

// Extra level of indirection so that __COUNTER__ is expanded to its value
// before being token-pasted into the registrar variable name below.
#define REGISTER_DATASET_OP_NAME_UNIQ_HELPER(ctr, experiment, rollout_pct) \
  REGISTER_DATASET_OP_NAME_UNIQ(ctr, experiment, rollout_pct)

// Defines a uniquely named file-local registrar object whose constructor
// registers the experiment with DatasetExperimentRegistry.
#define REGISTER_DATASET_OP_NAME_UNIQ(ctr, experiment, rollout_pct) \
  static ::tensorflow::data::DatasetExperimentRegistrar              \
      registrar__body__##ctr##__object(experiment, rollout_pct)
324
325 } // namespace data
326 } // namespace tensorflow
327
328 #endif // TENSORFLOW_CORE_DATA_DATASET_UTILS_H_
329