1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include <queue> 16 17 #include "tensorflow/core/framework/dataset.h" 18 #include "tensorflow/core/framework/op_kernel.h" 19 #include "tensorflow/core/framework/partial_tensor_shape.h" 20 #include "tensorflow/core/framework/tensor.h" 21 #include "tensorflow/core/framework/tensor_shape.h" 22 #include "tensorflow/core/lib/core/blocking_counter.h" 23 #include "tensorflow/core/lib/core/errors.h" 24 #include "tensorflow/core/lib/core/threadpool.h" 25 #include "tensorflow/core/lib/io/buffered_inputstream.h" 26 #include "tensorflow/core/lib/io/inputbuffer.h" 27 #include "tensorflow/core/lib/io/path.h" 28 #include "tensorflow/core/lib/io/random_inputstream.h" 29 #include "tensorflow/core/lib/io/record_reader.h" 30 #include "tensorflow/core/lib/io/zlib_compression_options.h" 31 #include "tensorflow/core/lib/io/zlib_inputstream.h" 32 #include "tensorflow/core/platform/env.h" 33 34 namespace tensorflow { 35 namespace data { 36 namespace experimental { 37 namespace { 38 39 class MatchingFilesDatasetOp : public DatasetOpKernel { 40 public: 41 using DatasetOpKernel::DatasetOpKernel; 42 MakeDataset(OpKernelContext * ctx,DatasetBase ** output)43 void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { 44 const Tensor* patterns_t; 45 OP_REQUIRES_OK(ctx, ctx->input("patterns", &patterns_t)); 46 const auto patterns = patterns_t->flat<tstring>(); 47 size_t num_patterns = static_cast<size_t>(patterns.size()); 48 std::vector<tstring> pattern_strs; 49 pattern_strs.reserve(num_patterns); 50 51 for (size_t i = 0; i < num_patterns; i++) { 52 pattern_strs.push_back(patterns(i)); 53 } 54 55 *output = new Dataset(ctx, std::move(pattern_strs)); 56 } 57 58 private: 59 class Dataset : public DatasetBase { 60 public: Dataset(OpKernelContext * ctx,std::vector<tstring> patterns)61 Dataset(OpKernelContext* ctx, std::vector<tstring> patterns) 62 : DatasetBase(DatasetContext(ctx)), patterns_(std::move(patterns)) {} 63 MakeIteratorInternal(const string & prefix) const64 std::unique_ptr<IteratorBase> MakeIteratorInternal( 65 const string& prefix) const override { 66 return absl::make_unique<Iterator>( 67 Iterator::Params{this, strings::StrCat(prefix, "::MatchingFiles")}); 68 } 69 output_dtypes() const70 const DataTypeVector& output_dtypes() const override { 71 static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); 72 return *dtypes; 73 } 74 output_shapes() const75 const std::vector<PartialTensorShape>& output_shapes() const override { 76 static std::vector<PartialTensorShape>* shapes = 77 new std::vector<PartialTensorShape>({{}}); 78 return *shapes; 79 } 80 DebugString() const81 string DebugString() const override { 82 return "MatchingFilesDatasetOp::Dataset"; 83 } 84 InputDatasets(std::vector<const DatasetBase * > * inputs) const85 Status InputDatasets( 86 std::vector<const DatasetBase*>* inputs) const override { 87 return Status::OK(); 88 } 89 CheckExternalState() const90 Status CheckExternalState() const override { return Status::OK(); } 91 92 protected: AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const93 Status AsGraphDefInternal(SerializationContext* ctx, 94 DatasetGraphDefBuilder* b, 95 Node** output) const override { 96 Node* patterns_node = nullptr; 97 TF_RETURN_IF_ERROR(b->AddVector(patterns_, &patterns_node)); 98 TF_RETURN_IF_ERROR(b->AddDataset(this, {patterns_node}, output)); 99 return Status::OK(); 100 } 101 102 private: 103 class Iterator : public DatasetIterator<Dataset> { 104 public: Iterator(const Params & params)105 explicit Iterator(const Params& params) 106 : DatasetIterator<Dataset>(params) {} 107 GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)108 Status GetNextInternal(IteratorContext* ctx, 109 std::vector<Tensor>* out_tensors, 110 bool* end_of_sequence) override { 111 mutex_lock l(mu_); 112 FileSystem* fs; 113 114 TF_RETURN_IF_ERROR(ctx->env()->GetFileSystemForFile( 115 dataset()->patterns_[(current_pattern_index_ > 0) 116 ? current_pattern_index_ - 1 117 : 0], 118 &fs)); 119 120 while (!filepath_queue_.empty() || 121 current_pattern_index_ < dataset()->patterns_.size()) { 122 // All the elements in the heap will be the matched filenames or the 123 // potential directories. 124 if (!filepath_queue_.empty()) { 125 PathStatus current_path = filepath_queue_.top(); 126 filepath_queue_.pop(); 127 128 if (!current_path.second) { 129 Tensor filepath_tensor(ctx->allocator({}), DT_STRING, {}); 130 131 // Replace the forward slash with the backslash for Windows path 132 if (isWindows_) { 133 std::replace(current_path.first.begin(), 134 current_path.first.end(), '/', '\\'); 135 } 136 137 filepath_tensor.scalar<tstring>()() = 138 std::move(current_path.first); 139 out_tensors->emplace_back(std::move(filepath_tensor)); 140 *end_of_sequence = false; 141 hasMatch_ = true; 142 return Status::OK(); 143 } 144 145 // In this case, current_path is a directory. Then continue the 146 // search. 147 TF_RETURN_IF_ERROR( 148 UpdateIterator(ctx, fs, current_path.first, current_pattern_)); 149 } else { 150 // search a new pattern 151 current_pattern_ = dataset()->patterns_[current_pattern_index_]; 152 StringPiece current_pattern_view = StringPiece(current_pattern_); 153 154 // Windows paths contain backslashes and Windows APIs accept forward 155 // and backslashes equivalently, so we convert the pattern to use 156 // forward slashes exclusively. The backslash is used as the 157 // indicator of Windows paths. Note that this is not ideal, since 158 // the API expects backslash as an escape character, but no code 159 // appears to rely on this behavior 160 if (current_pattern_view.find('\\') != std::string::npos) { 161 isWindows_ = true; 162 std::replace(¤t_pattern_[0], 163 ¤t_pattern_[0] + current_pattern_.size(), '\\', 164 '/'); 165 } else { 166 isWindows_ = false; 167 } 168 169 StringPiece fixed_prefix = current_pattern_view.substr( 170 0, current_pattern_view.find_first_of("*?[\\")); 171 string current_dir(io::Dirname(fixed_prefix)); 172 173 // If current_dir is empty then we need to fix up fixed_prefix and 174 // current_pattern_ to include . as the top level directory. 175 if (current_dir.empty()) { 176 current_dir = "."; 177 current_pattern_ = io::JoinPath(current_dir, current_pattern_); 178 } 179 180 TF_RETURN_IF_ERROR( 181 UpdateIterator(ctx, fs, current_dir, current_pattern_)); 182 ++current_pattern_index_; 183 } 184 } 185 186 *end_of_sequence = true; 187 if (hasMatch_) { 188 return Status::OK(); 189 } else { 190 return errors::NotFound("Don't find any matched files"); 191 } 192 } 193 194 protected: CreateNode(IteratorContext * ctx,model::Node::Args args) const195 std::shared_ptr<model::Node> CreateNode( 196 IteratorContext* ctx, model::Node::Args args) const override { 197 return model::MakeSourceNode(std::move(args)); 198 } 199 SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)200 Status SaveInternal(SerializationContext* ctx, 201 IteratorStateWriter* writer) override { 202 mutex_lock l(mu_); 203 TF_RETURN_IF_ERROR(writer->WriteScalar( 204 full_name("current_pattern_index"), current_pattern_index_)); 205 206 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_pattern"), 207 current_pattern_)); 208 TF_RETURN_IF_ERROR( 209 writer->WriteScalar(full_name("hasMatch"), hasMatch_)); 210 TF_RETURN_IF_ERROR( 211 writer->WriteScalar(full_name("isWindows"), isWindows_)); 212 213 if (!filepath_queue_.empty()) { 214 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("queue_size"), 215 filepath_queue_.size())); 216 int i = 0; 217 while (!filepath_queue_.empty()) { 218 TF_RETURN_IF_ERROR( 219 writer->WriteScalar(full_name(strings::StrCat("path_", i)), 220 filepath_queue_.top().first)); 221 TF_RETURN_IF_ERROR(writer->WriteScalar( 222 full_name(strings::StrCat("path_status_", i)), 223 filepath_queue_.top().second)); 224 filepath_queue_.pop(); 225 i++; 226 } 227 } 228 229 return Status::OK(); 230 } 231 RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)232 Status RestoreInternal(IteratorContext* ctx, 233 IteratorStateReader* reader) override { 234 mutex_lock l(mu_); 235 int64 current_pattern_index; 236 TF_RETURN_IF_ERROR(reader->ReadScalar( 237 full_name("current_pattern_index"), ¤t_pattern_index)); 238 current_pattern_index_ = size_t(current_pattern_index); 239 240 tstring current_pattern_tstr; 241 TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_pattern"), 242 ¤t_pattern_tstr)); 243 current_pattern_ = current_pattern_tstr; 244 245 int64 hasMatch; 246 TF_RETURN_IF_ERROR( 247 reader->ReadScalar(full_name("hasMatch"), &hasMatch)); 248 hasMatch_ = static_cast<bool>(hasMatch); 249 250 int64 isWindows; 251 TF_RETURN_IF_ERROR( 252 reader->ReadScalar(full_name("isWindows"), &isWindows)); 253 isWindows_ = static_cast<bool>(isWindows); 254 255 if (reader->Contains(full_name("queue_size"))) { 256 int64 queue_size; 257 TF_RETURN_IF_ERROR( 258 reader->ReadScalar(full_name("queue_size"), &queue_size)); 259 for (int i = 0; i < queue_size; i++) { 260 tstring path; 261 int64 path_status; 262 TF_RETURN_IF_ERROR(reader->ReadScalar( 263 full_name(strings::StrCat("path_", i)), &path)); 264 TF_RETURN_IF_ERROR(reader->ReadScalar( 265 full_name(strings::StrCat("path_status_", i)), &path_status)); 266 filepath_queue_.push( 267 PathStatus(path, static_cast<bool>(path_status))); 268 } 269 } 270 271 return Status::OK(); 272 } 273 274 private: UpdateIterator(IteratorContext * ctx,FileSystem * fs,const string & dir,const string & eval_pattern)275 Status UpdateIterator(IteratorContext* ctx, FileSystem* fs, 276 const string& dir, const string& eval_pattern) 277 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { 278 StringPiece fixed_prefix = 279 StringPiece(eval_pattern) 280 .substr(0, eval_pattern.find_first_of("*?[\\")); 281 282 filepath_queue_.push(PathStatus(dir, true)); 283 Status ret; // Status to return 284 285 // DFS to find the first element in the iterator. 286 while (!filepath_queue_.empty()) { 287 const PathStatus current_path = filepath_queue_.top(); 288 289 // All the files in the heap are matched with the pattern, so finish 290 // the search if current_path is a file. 291 if (!current_path.second) { 292 return Status::OK(); 293 } 294 295 filepath_queue_.pop(); 296 297 // If current_path is a directory, search its children. 298 const string& current_dir = current_path.first; 299 std::vector<string> children; 300 ret.Update(fs->GetChildren(current_dir, &children)); 301 302 // Handle the error cases: 1) continue the search if the status is 303 // NOT_FOUND; 2) return the non-ok status immediately if it is not 304 // NOT_FOUND. 305 if (ret.code() == error::NOT_FOUND) { 306 continue; 307 } else if (!ret.ok()) { 308 return ret; 309 } 310 311 // children_dir_status holds is_dir status for children. It can have 312 // three possible values: OK for true; FAILED_PRECONDITION for false; 313 // CANCELLED if we don't calculate IsDirectory (we might do that 314 // because there isn't any point in exploring that child path). 315 std::vector<Status> children_dir_status; 316 children_dir_status.resize(children.size()); 317 318 // This IsDirectory call can be expensive for some FS. Parallelizing 319 // it. 320 auto is_directory_fn = [fs, current_dir, &children, &fixed_prefix, 321 &children_dir_status](int i) { 322 const string child_path = io::JoinPath(current_dir, children[i]); 323 // In case the child_path doesn't start with the fixed_prefix, then 324 // we don't need to explore this path. 325 if (!absl::StartsWith(child_path, fixed_prefix)) { 326 children_dir_status[i] = 327 errors::Cancelled("Operation not needed"); 328 } else { 329 children_dir_status[i] = fs->IsDirectory(child_path); 330 } 331 }; 332 333 BlockingCounter counter(children.size()); 334 for (int i = 0; i < children.size(); i++) { 335 (*ctx->runner())([&is_directory_fn, &counter, i] { 336 is_directory_fn(i); 337 counter.DecrementCount(); 338 }); 339 } 340 counter.Wait(); 341 342 for (int i = 0; i < children.size(); i++) { 343 const string& child_dir_path = 344 io::JoinPath(current_dir, children[i]); 345 const Status& child_dir_status = children_dir_status[i]; 346 347 // If the IsDirectory call was cancelled we bail. 348 if (child_dir_status.code() == tensorflow::error::CANCELLED) { 349 continue; 350 } 351 352 if (child_dir_status.ok()) { 353 // push the child dir for next search 354 filepath_queue_.push(PathStatus(child_dir_path, true)); 355 } else { 356 // This case will be a file: if the file matches the pattern, push 357 // it to the heap; otherwise, ignore it. 358 if (ctx->env()->MatchPath(child_dir_path, eval_pattern)) { 359 filepath_queue_.push(PathStatus(child_dir_path, false)); 360 } 361 } 362 } 363 } 364 return ret; 365 } 366 367 mutex mu_; 368 // True means the path is a directory; False means the path is a filename. 369 typedef std::pair<string, bool> PathStatus; 370 std::priority_queue<PathStatus, std::vector<PathStatus>, 371 std::greater<PathStatus>> 372 filepath_queue_ TF_GUARDED_BY(mu_); 373 size_t current_pattern_index_ TF_GUARDED_BY(mu_) = 0; 374 tstring current_pattern_ TF_GUARDED_BY(mu_); 375 bool hasMatch_ TF_GUARDED_BY(mu_) = false; 376 bool isWindows_ TF_GUARDED_BY(mu_) = false; 377 }; 378 379 const std::vector<tstring> patterns_; 380 }; 381 }; 382 383 REGISTER_KERNEL_BUILDER(Name("MatchingFilesDataset").Device(DEVICE_CPU), 384 MatchingFilesDatasetOp); 385 REGISTER_KERNEL_BUILDER( 386 Name("ExperimentalMatchingFilesDataset").Device(DEVICE_CPU), 387 MatchingFilesDatasetOp); 388 389 } // namespace 390 } // namespace experimental 391 } // namespace data 392 } // namespace tensorflow 393