• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// Caching wrapper around TensorSliceReader, the utility that reads checkpoints
// for Google Brain tensor ops and v3 checkpoints for dist_belief.
18 
19 #ifndef TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_CACHE_H_
20 #define TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_CACHE_H_
21 
#include <set>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/tensor_slice_reader.h"
28 
29 namespace tensorflow {
30 
31 namespace checkpoint {
32 
33 class TensorSliceReaderCache;
34 
35 // Wrapper to a lazily allocated TensorSliceReaderCache.
36 class TensorSliceReaderCacheWrapper {
37  public:
38   TensorSliceReaderCacheWrapper();
39   ~TensorSliceReaderCacheWrapper();
40 
41   // Same as TensorSliceReaderCache::GetReader().
42   const TensorSliceReader* GetReader(
43       const string& filepattern,
44       TensorSliceReader::OpenTableFunction open_function,
45       int preferred_shard) const;
46 
47  private:
48   mutable mutex mu_;
49   mutable TensorSliceReaderCache* cache_ = nullptr;
50 };
51 
52 // A cache of TensorSliceReaders.
53 class TensorSliceReaderCache {
54  public:
55   TensorSliceReaderCache();
56   ~TensorSliceReaderCache();
57 
58   // Returns the TensorSliceReader corresponding to 'filepattern' and the
59   // open_function.  May return nullptr if we can not create a new
60   // TensorSliceReader for the filepattern/open_function combination.
61   const TensorSliceReader* GetReader(
62       const string& filepattern,
63       TensorSliceReader::OpenTableFunction open_function, int preferred_shard);
64 
65  private:
66   // Need to use a regular function type in the key map as std::function does
67   // not support ==.
68   typedef Status (*OpenFuncType)(const string&, TensorSliceReader::Table**);
69 
70   // Protects attributes below.
71   mutex mu_;
72 
73   // Maps of opened readers.
74   std::unordered_map<string, std::pair<OpenFuncType, TensorSliceReader*>>
75       readers_;
76 
77   // Set of keys that a previous GetReader() call is still trying to populate.
78   std::set<string> still_opening_;
79 
80   // Condition variable to notify when a reader has been created.
81   condition_variable cv_;
82 };
83 
84 }  // namespace checkpoint
85 
86 }  // namespace tensorflow
87 
88 #endif  // TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_CACHE_H_
89