/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/data/split_utils.h"

#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace data {
namespace {
constexpr char kNumToSkip[] = "num_to_skip";
constexpr char kSplitProvider[] = "split_provider";
constexpr char kSlash[] = "/";
constexpr char kIndex[] = "index";
}  // namespace

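// IndexSplitProvider hands out the integer splits 0, 1, ..., n - 1 in
// order, each as a scalar DT_INT64 tensor.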
IndexSplitProvider::IndexSplitProvider(int64_t n) : i_(0), n_(n) {}

Status IndexSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
  mutex_lock l(mu_);
  if (i_ >= n_) {
    *end_of_splits = true;
    return OkStatus();
  }
  *end_of_splits = false;
  *split = Tensor(DT_INT64, TensorShape{});
  split->scalar<int64_t>()() = i_++;
  return OkStatus();
}

Status IndexSplitProvider::Reset() {
  mutex_lock l(mu_);
  i_ = 0;
  return OkStatus();
}

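// Save and Restore checkpoint the provider's position by writing and
// reading the current index under the "index" key.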
Status IndexSplitProvider::Save(
    std::function<std::string(std::string)> full_name,
    IteratorStateWriter* writer) {
  mutex_lock l(mu_);
  return writer->WriteScalar(full_name(kIndex), i_);
}

Status IndexSplitProvider::Restore(
    std::function<std::string(std::string)> full_name,
    IteratorStateReader* reader) {
  mutex_lock l(mu_);
  return reader->ReadScalar(full_name(kIndex), &i_);
}

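// ShardingSplitProvider wraps another split provider and yields every
// num_shards-th split, starting at shard_index. For example, with
// num_shards = 3 and shard_index = 1, it yields the wrapped provider's
// splits 1, 4, 7, ...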
ShardingSplitProvider::ShardingSplitProvider(
    int64_t num_shards, int64_t shard_index,
    std::shared_ptr<SplitProvider> split_provider)
    : num_shards_(num_shards),
      shard_index_(shard_index),
      split_provider_(split_provider),
      num_to_skip_(shard_index_) {}

Status ShardingSplitProvider::GetNext(Tensor* split, bool* end_of_splits) {
  mutex_lock l(mu_);
  while (num_to_skip_ > 0) {
    TF_RETURN_IF_ERROR(split_provider_->GetNext(split, end_of_splits));
    if (*end_of_splits) {
      return OkStatus();
    }
    num_to_skip_--;
  }
  num_to_skip_ = num_shards_ - 1;
  TF_RETURN_IF_ERROR(split_provider_->GetNext(split, end_of_splits));
  return OkStatus();
}

Status ShardingSplitProvider::Reset() {
  mutex_lock l(mu_);
  TF_RETURN_IF_ERROR(split_provider_->Reset());
  num_to_skip_ = shard_index_;
  return OkStatus();
}

Status ShardingSplitProvider::Save(
    std::function<std::string(std::string)> full_name,
    IteratorStateWriter* writer) {
  mutex_lock l(mu_);
  TF_RETURN_IF_ERROR(split_provider_->Save(
      [&](const std::string& key) {
        return full_name(absl::StrCat(kSplitProvider, kSlash, key));
      },
      writer));
  return writer->WriteScalar(full_name(kNumToSkip), num_to_skip_);
}

Status ShardingSplitProvider::Restore(
    std::function<std::string(std::string)> full_name,
    IteratorStateReader* reader) {
  mutex_lock l(mu_);
  TF_RETURN_IF_ERROR(split_provider_->Restore(
      [&](const std::string& key) {
        return full_name(absl::StrCat(kSplitProvider, kSlash, key));
      },
      reader));
  TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNumToSkip), &num_to_skip_));
  return OkStatus();
}

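// Returns the context's sole split provider, or FailedPrecondition if the
// context does not hold exactly one.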
StatusOr<std::shared_ptr<SplitProvider>> GetSingleSplitProvider(
    IteratorContext* ctx, const DatasetBase* dataset) {
  if (ctx->split_providers().size() != 1) {
    return errors::FailedPrecondition(
        "Failed to get single split provider for dataset ",
        dataset->DebugString(), ". Found ", ctx->split_providers().size(),
        " split providers");
  }
  return ctx->split_providers()[0];
}

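// Flattens the split providers created by each of the dataset's input
// datasets into a single list.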
StatusOr<std::vector<std::unique_ptr<SplitProvider>>> GetSplitProviders(
    const DatasetBase* dataset) {
  std::vector<std::unique_ptr<SplitProvider>> result;
  std::vector<const DatasetBase*> inputs;
  TF_RETURN_IF_ERROR(dataset->InputDatasets(&inputs));
  for (const auto& input : inputs) {
    std::vector<std::unique_ptr<SplitProvider>> providers;
    TF_RETURN_IF_ERROR(input->MakeSplitProviders(&providers));
    for (auto& provider : providers) {
      result.push_back(std::move(provider));
    }
  }
  return result;
}

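// Builds one IteratorContext per input dataset. When the parent context
// holds split providers, each input context receives the contiguous slice
// of providers matching that input's number of sources; the total provider
// count must equal the total source count, or FailedPrecondition is
// returned.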
StatusOr<std::vector<IteratorContext>> CreateInputIteratorContexts(
    IteratorContext* ctx, const DatasetBase* dataset) {
  std::vector<const DatasetBase*> inputs;
  TF_RETURN_IF_ERROR(dataset->InputDatasets(&inputs));
  std::vector<IteratorContext> result;
  if (ctx->split_providers().empty()) {
    for (size_t i = 0; i < inputs.size(); ++i) {
      result.emplace_back(ctx);
    }
    return result;
  }
  int64_t num_sources = 0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (inputs[i]->num_sources() < 0) {
      return errors::FailedPrecondition(
          "Failed to determine the number of sources for dataset of type ",
          inputs[i]->type_string());
    }
    num_sources += inputs[i]->num_sources();
  }
  if (num_sources != ctx->split_providers().size()) {
    return errors::FailedPrecondition(
        "Attempted to feed ", ctx->split_providers().size(),
        " split providers into a dataset with ", num_sources, " sources");
  }
  int64_t split_provider_index = 0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    IteratorContext::Params params(ctx);
    params.split_providers.clear();
    for (int j = 0; j < inputs[i]->num_sources(); ++j) {
      params.split_providers.push_back(
          ctx->split_providers()[split_provider_index + j]);
    }
    split_provider_index += inputs[i]->num_sources();
    result.emplace_back(std::move(params));
  }
  return result;
}

}  // namespace data
}  // namespace tensorflow