• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <queue>
16 
17 #include "tensorflow/core/framework/dataset.h"
18 #include "tensorflow/core/framework/op_kernel.h"
19 #include "tensorflow/core/framework/partial_tensor_shape.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/tensor_shape.h"
22 #include "tensorflow/core/lib/core/blocking_counter.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/threadpool.h"
25 #include "tensorflow/core/lib/io/buffered_inputstream.h"
26 #include "tensorflow/core/lib/io/inputbuffer.h"
27 #include "tensorflow/core/lib/io/path.h"
28 #include "tensorflow/core/lib/io/random_inputstream.h"
29 #include "tensorflow/core/lib/io/record_reader.h"
30 #include "tensorflow/core/lib/io/zlib_compression_options.h"
31 #include "tensorflow/core/lib/io/zlib_inputstream.h"
32 #include "tensorflow/core/platform/env.h"
33 
34 namespace tensorflow {
35 namespace data {
36 namespace experimental {
37 namespace {
38 
39 class MatchingFilesDatasetOp : public DatasetOpKernel {
40  public:
41   using DatasetOpKernel::DatasetOpKernel;
42 
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)43   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
44     const Tensor* patterns_t;
45     OP_REQUIRES_OK(ctx, ctx->input("patterns", &patterns_t));
46     const auto patterns = patterns_t->flat<tstring>();
47     size_t num_patterns = static_cast<size_t>(patterns.size());
48     std::vector<tstring> pattern_strs;
49     pattern_strs.reserve(num_patterns);
50 
51     for (size_t i = 0; i < num_patterns; i++) {
52       pattern_strs.push_back(patterns(i));
53     }
54 
55     *output = new Dataset(ctx, std::move(pattern_strs));
56   }
57 
58  private:
59   class Dataset : public DatasetBase {
60    public:
// Creates a dataset over `patterns`. The vector is taken by value and moved
// into patterns_, so the caller's list is not referenced afterwards.
Dataset(OpKernelContext* ctx, std::vector<tstring> patterns)
    : DatasetBase(DatasetContext(ctx)),
      patterns_(std::move(patterns)) {}
63 
MakeIteratorInternal(const string & prefix) const64     std::unique_ptr<IteratorBase> MakeIteratorInternal(
65         const string& prefix) const override {
66       return absl::make_unique<Iterator>(
67           Iterator::Params{this, strings::StrCat(prefix, "::MatchingFiles")});
68     }
69 
output_dtypes() const70     const DataTypeVector& output_dtypes() const override {
71       static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
72       return *dtypes;
73     }
74 
output_shapes() const75     const std::vector<PartialTensorShape>& output_shapes() const override {
76       static std::vector<PartialTensorShape>* shapes =
77           new std::vector<PartialTensorShape>({{}});
78       return *shapes;
79     }
80 
DebugString() const81     string DebugString() const override {
82       return "MatchingFilesDatasetOp::Dataset";
83     }
84 
InputDatasets(std::vector<const DatasetBase * > * inputs) const85     Status InputDatasets(
86         std::vector<const DatasetBase*>* inputs) const override {
87       return Status::OK();
88     }
89 
CheckExternalState() const90     Status CheckExternalState() const override { return Status::OK(); }
91 
92    protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const93     Status AsGraphDefInternal(SerializationContext* ctx,
94                               DatasetGraphDefBuilder* b,
95                               Node** output) const override {
96       Node* patterns_node = nullptr;
97       TF_RETURN_IF_ERROR(b->AddVector(patterns_, &patterns_node));
98       TF_RETURN_IF_ERROR(b->AddDataset(this, {patterns_node}, output));
99       return Status::OK();
100     }
101 
102    private:
103     class Iterator : public DatasetIterator<Dataset> {
104      public:
Iterator(const Params & params)105       explicit Iterator(const Params& params)
106           : DatasetIterator<Dataset>(params) {}
107 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)108       Status GetNextInternal(IteratorContext* ctx,
109                              std::vector<Tensor>* out_tensors,
110                              bool* end_of_sequence) override {
111         mutex_lock l(mu_);
112         FileSystem* fs;
113 
114         TF_RETURN_IF_ERROR(ctx->env()->GetFileSystemForFile(
115             dataset()->patterns_[(current_pattern_index_ > 0)
116                                      ? current_pattern_index_ - 1
117                                      : 0],
118             &fs));
119 
120         while (!filepath_queue_.empty() ||
121                current_pattern_index_ < dataset()->patterns_.size()) {
122           // All the elements in the heap will be the matched filenames or the
123           // potential directories.
124           if (!filepath_queue_.empty()) {
125             PathStatus current_path = filepath_queue_.top();
126             filepath_queue_.pop();
127 
128             if (!current_path.second) {
129               Tensor filepath_tensor(ctx->allocator({}), DT_STRING, {});
130 
131               // Replace the forward slash with the backslash for Windows path
132               if (isWindows_) {
133                 std::replace(current_path.first.begin(),
134                              current_path.first.end(), '/', '\\');
135               }
136 
137               filepath_tensor.scalar<tstring>()() =
138                   std::move(current_path.first);
139               out_tensors->emplace_back(std::move(filepath_tensor));
140               *end_of_sequence = false;
141               hasMatch_ = true;
142               return Status::OK();
143             }
144 
145             // In this case, current_path is a directory. Then continue the
146             // search.
147             TF_RETURN_IF_ERROR(
148                 UpdateIterator(ctx, fs, current_path.first, current_pattern_));
149           } else {
150             // search a new pattern
151             current_pattern_ = dataset()->patterns_[current_pattern_index_];
152             StringPiece current_pattern_view = StringPiece(current_pattern_);
153 
154             // Windows paths contain backslashes and Windows APIs accept forward
155             // and backslashes equivalently, so we convert the pattern to use
156             // forward slashes exclusively. The backslash is used as the
157             // indicator of Windows paths. Note that this is not ideal, since
158             // the API expects backslash as an escape character, but no code
159             // appears to rely on this behavior
160             if (current_pattern_view.find('\\') != std::string::npos) {
161               isWindows_ = true;
162               std::replace(&current_pattern_[0],
163                            &current_pattern_[0] + current_pattern_.size(), '\\',
164                            '/');
165             } else {
166               isWindows_ = false;
167             }
168 
169             StringPiece fixed_prefix = current_pattern_view.substr(
170                 0, current_pattern_view.find_first_of("*?[\\"));
171             string current_dir(io::Dirname(fixed_prefix));
172 
173             // If current_dir is empty then we need to fix up fixed_prefix and
174             // current_pattern_ to include . as the top level directory.
175             if (current_dir.empty()) {
176               current_dir = ".";
177               current_pattern_ = io::JoinPath(current_dir, current_pattern_);
178             }
179 
180             TF_RETURN_IF_ERROR(
181                 UpdateIterator(ctx, fs, current_dir, current_pattern_));
182             ++current_pattern_index_;
183           }
184         }
185 
186         *end_of_sequence = true;
187         if (hasMatch_) {
188           return Status::OK();
189         } else {
190           return errors::NotFound("Don't find any matched files");
191         }
192       }
193 
194      protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const195       std::shared_ptr<model::Node> CreateNode(
196           IteratorContext* ctx, model::Node::Args args) const override {
197         return model::MakeSourceNode(std::move(args));
198       }
199 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)200       Status SaveInternal(SerializationContext* ctx,
201                           IteratorStateWriter* writer) override {
202         mutex_lock l(mu_);
203         TF_RETURN_IF_ERROR(writer->WriteScalar(
204             full_name("current_pattern_index"), current_pattern_index_));
205 
206         TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_pattern"),
207                                                current_pattern_));
208         TF_RETURN_IF_ERROR(
209             writer->WriteScalar(full_name("hasMatch"), hasMatch_));
210         TF_RETURN_IF_ERROR(
211             writer->WriteScalar(full_name("isWindows"), isWindows_));
212 
213         if (!filepath_queue_.empty()) {
214           TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("queue_size"),
215                                                  filepath_queue_.size()));
216           int i = 0;
217           while (!filepath_queue_.empty()) {
218             TF_RETURN_IF_ERROR(
219                 writer->WriteScalar(full_name(strings::StrCat("path_", i)),
220                                     filepath_queue_.top().first));
221             TF_RETURN_IF_ERROR(writer->WriteScalar(
222                 full_name(strings::StrCat("path_status_", i)),
223                 filepath_queue_.top().second));
224             filepath_queue_.pop();
225             i++;
226           }
227         }
228 
229         return Status::OK();
230       }
231 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)232       Status RestoreInternal(IteratorContext* ctx,
233                              IteratorStateReader* reader) override {
234         mutex_lock l(mu_);
235         int64 current_pattern_index;
236         TF_RETURN_IF_ERROR(reader->ReadScalar(
237             full_name("current_pattern_index"), &current_pattern_index));
238         current_pattern_index_ = size_t(current_pattern_index);
239 
240         tstring current_pattern_tstr;
241         TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_pattern"),
242                                               &current_pattern_tstr));
243         current_pattern_ = current_pattern_tstr;
244 
245         int64 hasMatch;
246         TF_RETURN_IF_ERROR(
247             reader->ReadScalar(full_name("hasMatch"), &hasMatch));
248         hasMatch_ = static_cast<bool>(hasMatch);
249 
250         int64 isWindows;
251         TF_RETURN_IF_ERROR(
252             reader->ReadScalar(full_name("isWindows"), &isWindows));
253         isWindows_ = static_cast<bool>(isWindows);
254 
255         if (reader->Contains(full_name("queue_size"))) {
256           int64 queue_size;
257           TF_RETURN_IF_ERROR(
258               reader->ReadScalar(full_name("queue_size"), &queue_size));
259           for (int i = 0; i < queue_size; i++) {
260             tstring path;
261             int64 path_status;
262             TF_RETURN_IF_ERROR(reader->ReadScalar(
263                 full_name(strings::StrCat("path_", i)), &path));
264             TF_RETURN_IF_ERROR(reader->ReadScalar(
265                 full_name(strings::StrCat("path_status_", i)), &path_status));
266             filepath_queue_.push(
267                 PathStatus(path, static_cast<bool>(path_status)));
268           }
269         }
270 
271         return Status::OK();
272       }
273 
274      private:
UpdateIterator(IteratorContext * ctx,FileSystem * fs,const string & dir,const string & eval_pattern)275       Status UpdateIterator(IteratorContext* ctx, FileSystem* fs,
276                             const string& dir, const string& eval_pattern)
277           TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
278         StringPiece fixed_prefix =
279             StringPiece(eval_pattern)
280                 .substr(0, eval_pattern.find_first_of("*?[\\"));
281 
282         filepath_queue_.push(PathStatus(dir, true));
283         Status ret;  // Status to return
284 
285         // DFS to find the first element in the iterator.
286         while (!filepath_queue_.empty()) {
287           const PathStatus current_path = filepath_queue_.top();
288 
289           // All the files in the heap are matched with the pattern, so finish
290           // the search if current_path is a file.
291           if (!current_path.second) {
292             return Status::OK();
293           }
294 
295           filepath_queue_.pop();
296 
297           // If current_path is a directory, search its children.
298           const string& current_dir = current_path.first;
299           std::vector<string> children;
300           ret.Update(fs->GetChildren(current_dir, &children));
301 
302           // Handle the error cases: 1) continue the search if the status is
303           // NOT_FOUND; 2) return the non-ok status immediately if it is not
304           // NOT_FOUND.
305           if (ret.code() == error::NOT_FOUND) {
306             continue;
307           } else if (!ret.ok()) {
308             return ret;
309           }
310 
311           // children_dir_status holds is_dir status for children. It can have
312           // three possible values: OK for true; FAILED_PRECONDITION for false;
313           // CANCELLED if we don't calculate IsDirectory (we might do that
314           // because there isn't any point in exploring that child path).
315           std::vector<Status> children_dir_status;
316           children_dir_status.resize(children.size());
317 
318           // This IsDirectory call can be expensive for some FS. Parallelizing
319           // it.
320           auto is_directory_fn = [fs, current_dir, &children, &fixed_prefix,
321                                   &children_dir_status](int i) {
322             const string child_path = io::JoinPath(current_dir, children[i]);
323             // In case the child_path doesn't start with the fixed_prefix, then
324             // we don't need to explore this path.
325             if (!absl::StartsWith(child_path, fixed_prefix)) {
326               children_dir_status[i] =
327                   errors::Cancelled("Operation not needed");
328             } else {
329               children_dir_status[i] = fs->IsDirectory(child_path);
330             }
331           };
332 
333           BlockingCounter counter(children.size());
334           for (int i = 0; i < children.size(); i++) {
335             (*ctx->runner())([&is_directory_fn, &counter, i] {
336               is_directory_fn(i);
337               counter.DecrementCount();
338             });
339           }
340           counter.Wait();
341 
342           for (int i = 0; i < children.size(); i++) {
343             const string& child_dir_path =
344                 io::JoinPath(current_dir, children[i]);
345             const Status& child_dir_status = children_dir_status[i];
346 
347             // If the IsDirectory call was cancelled we bail.
348             if (child_dir_status.code() == tensorflow::error::CANCELLED) {
349               continue;
350             }
351 
352             if (child_dir_status.ok()) {
353               // push the child dir for next search
354               filepath_queue_.push(PathStatus(child_dir_path, true));
355             } else {
356               // This case will be a file: if the file matches the pattern, push
357               // it to the heap; otherwise, ignore it.
358               if (ctx->env()->MatchPath(child_dir_path, eval_pattern)) {
359                 filepath_queue_.push(PathStatus(child_dir_path, false));
360               }
361             }
362           }
363         }
364         return ret;
365       }
366 
367       mutex mu_;
368       // True means the path is a directory; False means the path is a filename.
369       typedef std::pair<string, bool> PathStatus;
370       std::priority_queue<PathStatus, std::vector<PathStatus>,
371                           std::greater<PathStatus>>
372           filepath_queue_ TF_GUARDED_BY(mu_);
373       size_t current_pattern_index_ TF_GUARDED_BY(mu_) = 0;
374       tstring current_pattern_ TF_GUARDED_BY(mu_);
375       bool hasMatch_ TF_GUARDED_BY(mu_) = false;
376       bool isWindows_ TF_GUARDED_BY(mu_) = false;
377     };
378 
379     const std::vector<tstring> patterns_;
380   };
381 };
382 
383 REGISTER_KERNEL_BUILDER(Name("MatchingFilesDataset").Device(DEVICE_CPU),
384                         MatchingFilesDatasetOp);
385 REGISTER_KERNEL_BUILDER(
386     Name("ExperimentalMatchingFilesDataset").Device(DEVICE_CPU),
387     MatchingFilesDatasetOp);
388 
389 }  // namespace
390 }  // namespace experimental
391 }  // namespace data
392 }  // namespace tensorflow
393