• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/util/tensor_slice_reader.h"
17 
18 #include <utility>
19 #include <vector>
20 #include "tensorflow/core/framework/types.pb_text.h"
21 #include "tensorflow/core/framework/versions.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/gtl/stl_util.h"
24 #include "tensorflow/core/lib/io/iterator.h"
25 #include "tensorflow/core/lib/io/table.h"
26 #include "tensorflow/core/lib/io/table_options.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/platform/logging.h"
29 #include "tensorflow/core/platform/protobuf.h"
30 #include "tensorflow/core/public/version.h"
31 #include "tensorflow/core/util/saved_tensor_slice_util.h"
32 #include "tensorflow/core/util/tensor_slice_util.h"
33 
34 namespace tensorflow {
35 
36 namespace checkpoint {
37 
~Table()38 TensorSliceReader::Table::~Table() {}
39 
40 namespace {
41 class TensorSliceReaderTable : public TensorSliceReader::Table {
42  public:
43   // Takes ownership of 'f'.
TensorSliceReaderTable(RandomAccessFile * f,table::Table * t)44   explicit TensorSliceReaderTable(RandomAccessFile* f, table::Table* t)
45       : file_(f), table_(t) {}
46 
~TensorSliceReaderTable()47   ~TensorSliceReaderTable() override {
48     delete table_;
49     delete file_;
50   }
51 
Get(const string & key,string * value)52   bool Get(const string& key, string* value) override {
53     std::unique_ptr<table::Iterator> iter(table_->NewIterator());
54     iter->Seek(key);
55     if (iter->Valid() && iter->key() == key) {
56       StringPiece v = iter->value();
57       value->assign(v.data(), v.size());
58       return true;
59     } else {
60       return false;
61     }
62   }
63 
64  private:
65   RandomAccessFile* file_;  // Owns.
66   table::Table* table_;
67 };
68 }  // namespace
69 
OpenTableTensorSliceReader(const string & fname,TensorSliceReader::Table ** result)70 Status OpenTableTensorSliceReader(const string& fname,
71                                   TensorSliceReader::Table** result) {
72   *result = nullptr;
73   Env* env = Env::Default();
74   std::unique_ptr<RandomAccessFile> f;
75   Status s = env->NewRandomAccessFile(fname, &f);
76   if (s.ok()) {
77     uint64 file_size;
78     s = env->GetFileSize(fname, &file_size);
79     if (s.ok()) {
80       table::Options options;
81       table::Table* table;
82       s = table::Table::Open(options, f.get(), file_size, &table);
83       if (s.ok()) {
84         *result = new TensorSliceReaderTable(f.release(), table);
85         return Status::OK();
86       } else {
87         s = Status(s.code(),
88                    strings::StrCat(s.error_message(),
89                                    ": perhaps your file is in a different "
90                                    "file format and you need to use a "
91                                    "different restore operator?"));
92       }
93     }
94   }
95   LOG(WARNING) << "Could not open " << fname << ": " << s;
96   return s;
97 }
98 
TensorSliceReader(const string & filepattern)99 TensorSliceReader::TensorSliceReader(const string& filepattern)
100     : TensorSliceReader(filepattern, OpenTableTensorSliceReader,
101                         kLoadAllShards) {}
102 
TensorSliceReader(const string & filepattern,OpenTableFunction open_function)103 TensorSliceReader::TensorSliceReader(const string& filepattern,
104                                      OpenTableFunction open_function)
105     : TensorSliceReader(filepattern, std::move(open_function), kLoadAllShards) {
106 }
107 
TensorSliceReader(const string & filepattern,OpenTableFunction open_function,int preferred_shard)108 TensorSliceReader::TensorSliceReader(const string& filepattern,
109                                      OpenTableFunction open_function,
110                                      int preferred_shard)
111     : filepattern_(filepattern), open_function_(std::move(open_function)) {
112   VLOG(1) << "TensorSliceReader for " << filepattern;
113   Status s = Env::Default()->GetMatchingPaths(filepattern, &fnames_);
114   if (!s.ok()) {
115     status_ = errors::InvalidArgument(
116         "Unsuccessful TensorSliceReader constructor: "
117         "Failed to get matching files on ",
118         filepattern, ": ", s.ToString());
119     return;
120   }
121   if (fnames_.empty()) {
122     status_ = errors::NotFound(
123         "Unsuccessful TensorSliceReader constructor: "
124         "Failed to find any matching files for ",
125         filepattern);
126     return;
127   }
128   sss_.resize(fnames_.size());
129   for (size_t shard = 0; shard < fnames_.size(); ++shard) {
130     fname_to_index_.insert(std::make_pair(fnames_[shard], shard));
131   }
132   if (preferred_shard == kLoadAllShards || fnames_.size() == 1 ||
133       static_cast<size_t>(preferred_shard) >= fnames_.size()) {
134     LoadAllShards();
135   } else {
136     VLOG(1) << "Loading shard " << preferred_shard << " for " << filepattern_;
137     LoadShard(preferred_shard);
138   }
139 }
140 
LoadShard(int shard) const141 void TensorSliceReader::LoadShard(int shard) const {
142   CHECK_LT(shard, sss_.size());
143   if (sss_[shard] || !status_.ok()) {
144     return;  // Already loaded, or invalid.
145   }
146   string value;
147   SavedTensorSlices sts;
148   const string fname = fnames_[shard];
149   VLOG(1) << "Reading meta data from file " << fname << "...";
150   Table* table;
151   Status s = open_function_(fname, &table);
152   if (!s.ok()) {
153     status_ = errors::DataLoss("Unable to open table file ", fname, ": ",
154                                s.ToString());
155     return;
156   }
157   sss_[shard].reset(table);
158   if (!(table->Get(kSavedTensorSlicesKey, &value) &&
159         ParseProtoUnlimited(&sts, value))) {
160     status_ = errors::Internal(
161         "Failed to find the saved tensor slices at the beginning of the "
162         "checkpoint file: ",
163         fname);
164     return;
165   }
166   status_ = CheckVersions(sts.meta().versions(), TF_CHECKPOINT_VERSION,
167                           TF_CHECKPOINT_VERSION_MIN_PRODUCER, "Checkpoint",
168                           "checkpoint");
169   if (!status_.ok()) return;
170   for (const SavedSliceMeta& ssm : sts.meta().tensor()) {
171     TensorShape ssm_shape(ssm.shape());
172     for (const TensorSliceProto& tsp : ssm.slice()) {
173       TensorSlice ss_slice(tsp);
174       status_ = RegisterTensorSlice(ssm.name(), ssm_shape, ssm.type(), fname,
175                                     ss_slice, &tensors_);
176       if (!status_.ok()) return;
177     }
178   }
179 }
180 
LoadAllShards() const181 void TensorSliceReader::LoadAllShards() const {
182   VLOG(1) << "Loading all shards for " << filepattern_;
183   for (size_t i = 0; i < fnames_.size() && status_.ok(); ++i) {
184     LoadShard(i);
185   }
186   all_shards_loaded_ = true;
187 }
188 
FindTensorSlice(const string & name,const TensorSlice & slice,std::vector<std::pair<TensorSlice,string>> * details) const189 const TensorSliceSet* TensorSliceReader::FindTensorSlice(
190     const string& name, const TensorSlice& slice,
191     std::vector<std::pair<TensorSlice, string>>* details) const {
192   const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
193   if (tss && !tss->QueryMeta(slice, details)) {
194     return nullptr;
195   }
196   return tss;
197 }
198 
~TensorSliceReader()199 TensorSliceReader::~TensorSliceReader() { gtl::STLDeleteValues(&tensors_); }
200 
HasTensor(const string & name,TensorShape * shape,DataType * type) const201 bool TensorSliceReader::HasTensor(const string& name, TensorShape* shape,
202                                   DataType* type) const {
203   mutex_lock l(mu_);
204   const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
205   if (!tss && !all_shards_loaded_) {
206     VLOG(1) << "Did not find tensor in preferred shard, loading all shards: "
207             << name;
208     LoadAllShards();
209     tss = gtl::FindPtrOrNull(tensors_, name);
210   }
211   if (tss) {
212     if (shape) {
213       *shape = tss->shape();
214     }
215     if (type) {
216       *type = tss->type();
217     }
218     return true;
219   } else {
220     return false;
221   }
222 }
223 
GetTensor(const string & name,std::unique_ptr<tensorflow::Tensor> * out_tensor) const224 Status TensorSliceReader::GetTensor(
225     const string& name, std::unique_ptr<tensorflow::Tensor>* out_tensor) const {
226   DataType type;
227   TensorShape shape;
228   TensorSlice slice;
229   {
230     mutex_lock l(mu_);
231     const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
232     if (tss == nullptr) {
233       return errors::NotFound(name, " not found in checkpoint file");
234     }
235 
236     if (tss->Slices().size() > 1) {
237       // TODO(sherrym): Support multi-slice checkpoints.
238       return errors::Unimplemented("Sliced checkpoints are not supported");
239     }
240 
241     type = tss->type();
242     shape = tss->shape();
243     slice = tss->Slices().begin()->second.slice;
244   }
245 
246   std::unique_ptr<tensorflow::Tensor> t(new tensorflow::Tensor(type, shape));
247   bool success = false;
248 
249 #define READER_COPY(dt)                                                  \
250   case dt:                                                               \
251     success = CopySliceData(name, slice,                                 \
252                             t->flat<EnumToDataType<dt>::Type>().data()); \
253     break;
254 
255   switch (type) {
256     READER_COPY(DT_FLOAT);
257     READER_COPY(DT_DOUBLE);
258     READER_COPY(DT_INT32);
259     READER_COPY(DT_UINT8);
260     READER_COPY(DT_INT16);
261     READER_COPY(DT_INT8);
262     READER_COPY(DT_INT64);
263     READER_COPY(DT_STRING);
264     default:
265       return errors::Unimplemented("Data type not supported");
266   }
267 #undef READER_COPY
268 
269   if (!success) {
270     return errors::NotFound(name, " not found in checkpoint file");
271   }
272   std::swap(*out_tensor, t);
273 
274   return Status::OK();
275 }
276 
GetVariableToShapeMap() const277 TensorSliceReader::VarToShapeMap TensorSliceReader::GetVariableToShapeMap()
278     const {
279   VarToShapeMap name_to_shape;
280   if (status().ok()) {
281     for (auto& e : Tensors()) {
282       name_to_shape[e.first] = e.second->shape();
283     }
284   }
285   return name_to_shape;
286 }
287 
288 TensorSliceReader::VarToDataTypeMap
GetVariableToDataTypeMap() const289 TensorSliceReader::GetVariableToDataTypeMap() const {
290   VarToDataTypeMap name_to_dtype;
291   if (status().ok()) {
292     for (auto& e : Tensors()) {
293       name_to_dtype[e.first] = e.second->type();
294     }
295   }
296   return name_to_dtype;
297 }
298 
DebugString() const299 const string TensorSliceReader::DebugString() const {
300   string shape_str;
301   if (status().ok()) {
302     for (auto e : Tensors()) {
303       strings::StrAppend(&shape_str, e.first, " (",
304                          EnumName_DataType(e.second->type()), ") ",
305                          e.second->shape().DebugString());
306       // Indicates if a tensor has more than 1 slice (i.e., it's partitioned).
307       const int num_slices = e.second->Slices().size();
308       if (num_slices > 1) {
309         strings::StrAppend(&shape_str, ", ", num_slices, " slices");
310       }
311       strings::StrAppend(&shape_str, "\n");
312     }
313   }
314   return shape_str;
315 }
316 
317 }  // namespace checkpoint
318 
319 }  // namespace tensorflow
320