1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // The utility to read checkpoints for google brain tensor ops and v3 17 // checkpoints for dist_belief. 18 19 #ifndef TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_CACHE_H_ 20 #define TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_CACHE_H_ 21 22 #include <unordered_map> 23 24 #include "tensorflow/core/lib/core/status.h" 25 #include "tensorflow/core/platform/mutex.h" 26 #include "tensorflow/core/platform/types.h" 27 #include "tensorflow/core/util/tensor_slice_reader.h" 28 29 namespace tensorflow { 30 31 namespace checkpoint { 32 33 class TensorSliceReaderCache; 34 35 // Wrapper to a lazily allocated TensorSliceReaderCache. 36 class TensorSliceReaderCacheWrapper { 37 public: 38 TensorSliceReaderCacheWrapper(); 39 ~TensorSliceReaderCacheWrapper(); 40 41 // Same as TensorSliceReaderCache::GetReader(). 42 const TensorSliceReader* GetReader( 43 const string& filepattern, 44 TensorSliceReader::OpenTableFunction open_function, 45 int preferred_shard) const; 46 47 private: 48 mutable mutex mu_; 49 mutable TensorSliceReaderCache* cache_ = nullptr; 50 }; 51 52 // A cache of TensorSliceReaders. 53 class TensorSliceReaderCache { 54 public: 55 TensorSliceReaderCache(); 56 ~TensorSliceReaderCache(); 57 58 // Returns the TensorSliceReader corresponding to 'filepattern' and the 59 // open_function. May return nullptr if we can not create a new 60 // TensorSliceReader for the filepattern/open_function combination. 61 const TensorSliceReader* GetReader( 62 const string& filepattern, 63 TensorSliceReader::OpenTableFunction open_function, int preferred_shard); 64 65 private: 66 // Need to use a regular function type in the key map as std::function does 67 // not support ==. 68 typedef Status (*OpenFuncType)(const string&, TensorSliceReader::Table**); 69 70 // Protects attributes below. 71 mutex mu_; 72 73 // Maps of opened readers. 74 std::unordered_map<string, std::pair<OpenFuncType, TensorSliceReader*>> 75 readers_; 76 77 // Set of keys that a previous GetReader() call is still trying to populate. 78 std::set<string> still_opening_; 79 80 // Condition variable to notify when a reader has been created. 81 condition_variable cv_; 82 }; 83 84 } // namespace checkpoint 85 86 } // namespace tensorflow 87 88 #endif // TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_CACHE_H_ 89