• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // The utility to read checkpoints for google brain tensor ops and v3
17 // checkpoints for dist_belief.
18 
19 #ifndef TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_H_
20 #define TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_H_
21 
22 #include <unordered_map>
23 
24 #include <vector>
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/tensor_shape.h"
27 #include "tensorflow/core/framework/tensor_slice.h"
28 #include "tensorflow/core/framework/types.pb.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/lib/core/stringpiece.h"
31 #include "tensorflow/core/lib/gtl/map_util.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/macros.h"
34 #include "tensorflow/core/platform/mutex.h"
35 #include "tensorflow/core/platform/protobuf.h"
36 #include "tensorflow/core/platform/types.h"
37 #include "tensorflow/core/util/saved_tensor_slice.pb.h"
38 #include "tensorflow/core/util/saved_tensor_slice_util.h"
39 #include "tensorflow/core/util/tensor_slice_set.h"
40 #include "tensorflow/core/util/tensor_slice_util.h"
41 
42 namespace tensorflow {
43 
44 namespace checkpoint {
45 
46 // The reader reads in all the meta data about all the tensor slices. Then it
47 // will try to read the relevant data on-demand to produce the data for the
48 // slices needed.
49 // NOTE(yangke): another way to do this is to first load a list of the tensor
50 // slices needed and then just selectively read some of the meta data. That
51 // might optimize the loading but makes the logic a bit more complicated. We
52 // might want to revisit that.
53 // TODO(yangke): consider moving to TensorProto.
54 class TensorSliceReader {
55  public:
56   // Abstract interface for reading data out of a tensor slice checkpoint file
57   class Table {
58    public:
59     virtual ~Table();
60     virtual bool Get(const string& key, string* value) = 0;
61   };
62   typedef std::function<Status(const string&, Table**)> OpenTableFunction;
63 
64   static constexpr int kLoadAllShards = -1;
65   TensorSliceReader(const string& filepattern);
66   TensorSliceReader(const string& filepattern, OpenTableFunction open_function);
67   TensorSliceReader(const string& filepattern, OpenTableFunction open_function,
68                     int preferred_shard);
69   virtual ~TensorSliceReader();
70 
71   // Get the filename this reader is attached to.
filepattern()72   const string& filepattern() const { return filepattern_; }
73 
74   // Get the number of files matched.
num_files()75   int num_files() const { return sss_.size(); }
76 
77   // Get the status of the reader.
status()78   const Status status() const { return status_; }
79 
80   // Checks if the reader contains any slice of a tensor. In case the reader
81   // does contain the tensor, if "shape" is not nullptr, fill "shape" with the
82   // shape of the tensor; if "type" is not nullptr, fill "type" with the type
83   // of the tensor.
84   bool HasTensor(const string& name, TensorShape* shape, DataType* type) const;
85 
86   // Checks if the reader contains all the data about a tensor slice, and if
87   // yes, copies the data of the slice to "data". The caller needs to make sure
88   // that "data" points to a buffer that holds enough data.
89   // This is a slow function since it needs to read sstables.
90   template <typename T>
91   bool CopySliceData(const string& name, const TensorSlice& slice,
92                      T* data) const;
93 
94   // Get the tensors.
Tensors()95   const std::unordered_map<string, TensorSliceSet*>& Tensors() const {
96     return tensors_;
97   }
98 
99   // Returns value for one tensor. Only single slice checkpoints are supported
100   // at the moment.
101   Status GetTensor(const string& name,
102                    std::unique_ptr<tensorflow::Tensor>* out_tensor) const;
103 
104   typedef std::unordered_map<string, TensorShape> VarToShapeMap;
105   typedef std::unordered_map<string, DataType> VarToDataTypeMap;
106 
107   // Returns a map from tensor name to shape.
108   VarToShapeMap GetVariableToShapeMap() const;
109 
110   // Returns a map from tensor name to data type.
111   VarToDataTypeMap GetVariableToDataTypeMap() const;
112 
113   // Returns a string containing names and shapes of all the tensors.
114   const string DebugString() const;
115 
116  private:
117   friend class TensorSliceWriteTestHelper;
118 
119   void LoadShard(int shard) const;
120   void LoadAllShards() const;
121 
122   const TensorSliceSet* FindTensorSlice(
123       const string& name, const TensorSlice& slice,
124       std::vector<std::pair<TensorSlice, string>>* details) const;
125 
126   const string filepattern_;
127   const OpenTableFunction open_function_;
128   std::vector<string> fnames_;
129   std::unordered_map<string, int> fname_to_index_;
130 
131   // Guards the attributes below.
132   mutable mutex mu_;
133   mutable bool all_shards_loaded_ = false;
134   mutable std::vector<std::unique_ptr<Table>> sss_;
135   mutable std::unordered_map<string, TensorSliceSet*> tensors_;
136   mutable Status status_;
137 
138   TF_DISALLOW_COPY_AND_ASSIGN(TensorSliceReader);
139 };
140 
141 Status OpenTableTensorSliceReader(const string& fname,
142                                   TensorSliceReader::Table** table);
143 
144 template <typename T>
CopySliceData(const string & name,const TensorSlice & slice,T * data)145 bool TensorSliceReader::CopySliceData(const string& name,
146                                       const TensorSlice& slice, T* data) const {
147   std::vector<std::pair<TensorSlice, string>> details;
148   const TensorSliceSet* tss;
149   {
150     mutex_lock l(mu_);
151     tss = FindTensorSlice(name, slice, &details);
152     if (!tss && !all_shards_loaded_) {
153       VLOG(1) << "Did not find slice in preferred shard, loading all shards."
154               << name << ": " << slice.DebugString();
155       LoadAllShards();
156       tss = FindTensorSlice(name, slice, &details);
157     }
158     if (!tss) {
159       // No such tensor
160       return false;
161     }
162   }
163   // We have the data -- copy it over.
164   string value;
165   for (const auto& x : details) {
166     const TensorSlice& slice_s = x.first;
167     const string& fname = x.second;
168     int idx = gtl::FindWithDefault(fname_to_index_, fname, -1);
169     CHECK_GE(idx, 0) << "Failed to find the index for filename " << fname;
170     // We read a record in the corresponding sstable
171     const string key = EncodeTensorNameSlice(name, slice_s);
172     if (!sss_[idx]->Get(key, &value)) {
173       VLOG(1) << "Failed to seek to the record for tensor " << name
174               << ", slice " << slice_s.DebugString()
175               << ": computed key = " << key;
176       return false;
177     }
178     SavedTensorSlices sts;
179     if (!ParseProtoUnlimited(&sts, value)) {
180       VLOG(1) << "Failed to parse the record for tensor " << name << ", slice "
181               << slice_s.DebugString() << ": computed key = " << key;
182       return false;
183     }
184     CopyDataFromTensorSliceToTensorSlice(
185         tss->shape(), slice_s, slice,
186         checkpoint::TensorProtoData<T>(sts.data().data()), data);
187   }
188   return true;
189 }
190 
191 }  // namespace checkpoint
192 
193 }  // namespace tensorflow
194 
195 #endif  // TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_H_
196