• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/util/tensor_slice_reader.h"
17 
18 #include <utility>
19 #include <vector>
20 
21 #include "tensorflow/core/framework/types.pb.h"
22 #include "tensorflow/core/framework/versions.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/io/iterator.h"
25 #include "tensorflow/core/lib/io/table.h"
26 #include "tensorflow/core/lib/io/table_options.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/platform/logging.h"
29 #include "tensorflow/core/platform/protobuf.h"
30 #include "tensorflow/core/public/version.h"
31 #include "tensorflow/core/util/saved_tensor_slice_util.h"
32 #include "tensorflow/core/util/tensor_slice_util.h"
33 
34 namespace tensorflow {
35 
36 namespace checkpoint {
37 
~Table()38 TensorSliceReader::Table::~Table() {}
39 
40 namespace {
41 class TensorSliceReaderTable : public TensorSliceReader::Table {
42  public:
43   // Takes ownership of 'f'.
TensorSliceReaderTable(RandomAccessFile * f,table::Table * t)44   explicit TensorSliceReaderTable(RandomAccessFile* f, table::Table* t)
45       : file_(f), table_(t) {}
46 
~TensorSliceReaderTable()47   ~TensorSliceReaderTable() override {
48     delete table_;
49     delete file_;
50   }
51 
Get(const string & key,string * value)52   bool Get(const string& key, string* value) override {
53     std::unique_ptr<table::Iterator> iter(table_->NewIterator());
54     iter->Seek(key);
55     if (iter->Valid() && iter->key() == key) {
56       StringPiece v = iter->value();
57       value->assign(v.data(), v.size());
58       return true;
59     } else {
60       return false;
61     }
62   }
63 
64  private:
65   RandomAccessFile* file_;  // Owns.
66   table::Table* table_;
67 };
68 }  // namespace
69 
OpenTableTensorSliceReader(const string & fname,TensorSliceReader::Table ** result)70 Status OpenTableTensorSliceReader(const string& fname,
71                                   TensorSliceReader::Table** result) {
72   *result = nullptr;
73   Env* env = Env::Default();
74   std::unique_ptr<RandomAccessFile> f;
75   Status s = env->NewRandomAccessFile(fname, &f);
76   if (s.ok()) {
77     uint64 file_size;
78     s = env->GetFileSize(fname, &file_size);
79     if (s.ok()) {
80       table::Options options;
81       table::Table* table;
82       s = table::Table::Open(options, f.get(), file_size, &table);
83       if (s.ok()) {
84         *result = new TensorSliceReaderTable(f.release(), table);
85         return Status::OK();
86       } else {
87         s = Status(s.code(),
88                    strings::StrCat(s.error_message(),
89                                    ": perhaps your file is in a different "
90                                    "file format and you need to use a "
91                                    "different restore operator?"));
92       }
93     }
94   }
95   LOG(WARNING) << "Could not open " << fname << ": " << s;
96   return s;
97 }
98 
TensorSliceReader(const string & filepattern)99 TensorSliceReader::TensorSliceReader(const string& filepattern)
100     : TensorSliceReader(filepattern, OpenTableTensorSliceReader,
101                         kLoadAllShards) {}
102 
TensorSliceReader(const string & filepattern,OpenTableFunction open_function)103 TensorSliceReader::TensorSliceReader(const string& filepattern,
104                                      OpenTableFunction open_function)
105     : TensorSliceReader(filepattern, std::move(open_function), kLoadAllShards) {
106 }
107 
TensorSliceReader(const string & filepattern,OpenTableFunction open_function,int preferred_shard)108 TensorSliceReader::TensorSliceReader(const string& filepattern,
109                                      OpenTableFunction open_function,
110                                      int preferred_shard)
111     : filepattern_(filepattern), open_function_(std::move(open_function)) {
112   VLOG(1) << "TensorSliceReader for " << filepattern;
113   Status s = Env::Default()->GetMatchingPaths(filepattern, &fnames_);
114   if (!s.ok()) {
115     status_ = errors::InvalidArgument(
116         "Unsuccessful TensorSliceReader constructor: "
117         "Failed to get matching files on ",
118         filepattern, ": ", s.ToString());
119     return;
120   }
121   if (fnames_.empty()) {
122     status_ = errors::NotFound(
123         "Unsuccessful TensorSliceReader constructor: "
124         "Failed to find any matching files for ",
125         filepattern);
126     return;
127   }
128   sss_.resize(fnames_.size());
129   for (size_t shard = 0; shard < fnames_.size(); ++shard) {
130     fname_to_index_.insert(std::make_pair(fnames_[shard], shard));
131   }
132   if (preferred_shard == kLoadAllShards || fnames_.size() == 1 ||
133       static_cast<size_t>(preferred_shard) >= fnames_.size()) {
134     LoadAllShards();
135   } else {
136     VLOG(1) << "Loading shard " << preferred_shard << " for " << filepattern_;
137     LoadShard(preferred_shard);
138   }
139 }
140 
LoadShard(int shard) const141 void TensorSliceReader::LoadShard(int shard) const {
142   CHECK_LT(shard, sss_.size());
143   if (sss_[shard] || !status_.ok()) {
144     return;  // Already loaded, or invalid.
145   }
146   string value;
147   SavedTensorSlices sts;
148   const string fname = fnames_[shard];
149   VLOG(1) << "Reading meta data from file " << fname << "...";
150   Table* table;
151   Status s = open_function_(fname, &table);
152   if (!s.ok()) {
153     status_ = errors::DataLoss("Unable to open table file ", fname, ": ",
154                                s.ToString());
155     return;
156   }
157   sss_[shard].reset(table);
158   if (!(table->Get(kSavedTensorSlicesKey, &value) &&
159         ParseProtoUnlimited(&sts, value))) {
160     status_ = errors::Internal(
161         "Failed to find the saved tensor slices at the beginning of the "
162         "checkpoint file: ",
163         fname);
164     return;
165   }
166   status_ = CheckVersions(sts.meta().versions(), TF_CHECKPOINT_VERSION,
167                           TF_CHECKPOINT_VERSION_MIN_PRODUCER, "Checkpoint",
168                           "checkpoint");
169   if (!status_.ok()) return;
170   for (const SavedSliceMeta& ssm : sts.meta().tensor()) {
171     TensorShape ssm_shape(ssm.shape());
172     for (const TensorSliceProto& tsp : ssm.slice()) {
173       TensorSlice ss_slice(tsp);
174       status_ = RegisterTensorSlice(ssm.name(), ssm_shape, ssm.type(), fname,
175                                     ss_slice, &tensors_);
176       if (!status_.ok()) return;
177     }
178   }
179 }
180 
LoadAllShards() const181 void TensorSliceReader::LoadAllShards() const {
182   VLOG(1) << "Loading all shards for " << filepattern_;
183   for (size_t i = 0; i < fnames_.size() && status_.ok(); ++i) {
184     LoadShard(i);
185   }
186   all_shards_loaded_ = true;
187 }
188 
FindTensorSlice(const string & name,const TensorSlice & slice,std::vector<std::pair<TensorSlice,string>> * details) const189 const TensorSliceSet* TensorSliceReader::FindTensorSlice(
190     const string& name, const TensorSlice& slice,
191     std::vector<std::pair<TensorSlice, string>>* details) const {
192   const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
193   if (tss && !tss->QueryMeta(slice, details)) {
194     return nullptr;
195   }
196   return tss;
197 }
198 
~TensorSliceReader()199 TensorSliceReader::~TensorSliceReader() {
200   for (auto& temp : tensors_) {
201     delete temp.second;
202   }
203   tensors_.clear();
204 }
205 
HasTensor(const string & name,TensorShape * shape,DataType * type) const206 bool TensorSliceReader::HasTensor(const string& name, TensorShape* shape,
207                                   DataType* type) const {
208   mutex_lock l(mu_);
209   const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
210   if (!tss && !all_shards_loaded_) {
211     VLOG(1) << "Did not find tensor in preferred shard, loading all shards: "
212             << name;
213     LoadAllShards();
214     tss = gtl::FindPtrOrNull(tensors_, name);
215   }
216   if (tss) {
217     if (shape) {
218       *shape = tss->shape();
219     }
220     if (type) {
221       *type = tss->type();
222     }
223     return true;
224   } else {
225     return false;
226   }
227 }
228 
GetTensor(const string & name,std::unique_ptr<tensorflow::Tensor> * out_tensor) const229 Status TensorSliceReader::GetTensor(
230     const string& name, std::unique_ptr<tensorflow::Tensor>* out_tensor) const {
231   DataType type;
232   TensorShape shape;
233   TensorSlice slice;
234   {
235     mutex_lock l(mu_);
236     const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
237     if (tss == nullptr) {
238       return errors::NotFound(name, " not found in checkpoint file");
239     }
240 
241     if (tss->Slices().size() > 1) {
242       // TODO(sherrym): Support multi-slice checkpoints.
243       return errors::Unimplemented("Sliced checkpoints are not supported");
244     }
245 
246     type = tss->type();
247     shape = tss->shape();
248     slice = tss->Slices().begin()->second.slice;
249   }
250 
251   std::unique_ptr<tensorflow::Tensor> t(new tensorflow::Tensor(type, shape));
252   bool success = false;
253 
254 #define READER_COPY(dt)                                                  \
255   case dt:                                                               \
256     success = CopySliceData(name, slice,                                 \
257                             t->flat<EnumToDataType<dt>::Type>().data()); \
258     break;
259 
260   switch (type) {
261     READER_COPY(DT_FLOAT);
262     READER_COPY(DT_DOUBLE);
263     READER_COPY(DT_INT32);
264     READER_COPY(DT_UINT8);
265     READER_COPY(DT_INT16);
266     READER_COPY(DT_INT8);
267     READER_COPY(DT_INT64);
268     READER_COPY(DT_STRING);
269     default:
270       return errors::Unimplemented("Data type not supported");
271   }
272 #undef READER_COPY
273 
274   if (!success) {
275     return errors::NotFound(name, " not found in checkpoint file");
276   }
277   std::swap(*out_tensor, t);
278 
279   return Status::OK();
280 }
281 
GetVariableToShapeMap() const282 TensorSliceReader::VarToShapeMap TensorSliceReader::GetVariableToShapeMap()
283     const {
284   VarToShapeMap name_to_shape;
285   if (status().ok()) {
286     for (auto& e : Tensors()) {
287       name_to_shape[e.first] = e.second->shape();
288     }
289   }
290   return name_to_shape;
291 }
292 
293 TensorSliceReader::VarToDataTypeMap
GetVariableToDataTypeMap() const294 TensorSliceReader::GetVariableToDataTypeMap() const {
295   VarToDataTypeMap name_to_dtype;
296   if (status().ok()) {
297     for (auto& e : Tensors()) {
298       name_to_dtype[e.first] = e.second->type();
299     }
300   }
301   return name_to_dtype;
302 }
303 
DebugString() const304 const string TensorSliceReader::DebugString() const {
305   string shape_str;
306   if (status().ok()) {
307     for (const auto& e : Tensors()) {
308       strings::StrAppend(&shape_str, e.first, " (",
309                          DataType_Name(e.second->type()), ") ",
310                          e.second->shape().DebugString());
311       // Indicates if a tensor has more than 1 slice (i.e., it's partitioned).
312       const int num_slices = e.second->Slices().size();
313       if (num_slices > 1) {
314         strings::StrAppend(&shape_str, ", ", num_slices, " slices");
315       }
316       strings::StrAppend(&shape_str, "\n");
317     }
318   }
319   return shape_str;
320 }
321 
322 }  // namespace checkpoint
323 
324 }  // namespace tensorflow
325