1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/util/tensor_slice_reader.h"
17
18 #include <utility>
19 #include <vector>
20
21 #include "tensorflow/core/framework/types.pb.h"
22 #include "tensorflow/core/framework/versions.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/io/iterator.h"
25 #include "tensorflow/core/lib/io/table.h"
26 #include "tensorflow/core/lib/io/table_options.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/platform/logging.h"
29 #include "tensorflow/core/platform/protobuf.h"
30 #include "tensorflow/core/public/version.h"
31 #include "tensorflow/core/util/saved_tensor_slice_util.h"
32 #include "tensorflow/core/util/tensor_slice_util.h"
33
34 namespace tensorflow {
35
36 namespace checkpoint {
37
~Table()38 TensorSliceReader::Table::~Table() {}
39
40 namespace {
41 class TensorSliceReaderTable : public TensorSliceReader::Table {
42 public:
43 // Takes ownership of 'f'.
TensorSliceReaderTable(RandomAccessFile * f,table::Table * t)44 explicit TensorSliceReaderTable(RandomAccessFile* f, table::Table* t)
45 : file_(f), table_(t) {}
46
~TensorSliceReaderTable()47 ~TensorSliceReaderTable() override {
48 delete table_;
49 delete file_;
50 }
51
Get(const string & key,string * value)52 bool Get(const string& key, string* value) override {
53 std::unique_ptr<table::Iterator> iter(table_->NewIterator());
54 iter->Seek(key);
55 if (iter->Valid() && iter->key() == key) {
56 StringPiece v = iter->value();
57 value->assign(v.data(), v.size());
58 return true;
59 } else {
60 return false;
61 }
62 }
63
64 private:
65 RandomAccessFile* file_; // Owns.
66 table::Table* table_;
67 };
68 } // namespace
69
OpenTableTensorSliceReader(const string & fname,TensorSliceReader::Table ** result)70 Status OpenTableTensorSliceReader(const string& fname,
71 TensorSliceReader::Table** result) {
72 *result = nullptr;
73 Env* env = Env::Default();
74 std::unique_ptr<RandomAccessFile> f;
75 Status s = env->NewRandomAccessFile(fname, &f);
76 if (s.ok()) {
77 uint64 file_size;
78 s = env->GetFileSize(fname, &file_size);
79 if (s.ok()) {
80 table::Options options;
81 table::Table* table;
82 s = table::Table::Open(options, f.get(), file_size, &table);
83 if (s.ok()) {
84 *result = new TensorSliceReaderTable(f.release(), table);
85 return Status::OK();
86 } else {
87 s = Status(s.code(),
88 strings::StrCat(s.error_message(),
89 ": perhaps your file is in a different "
90 "file format and you need to use a "
91 "different restore operator?"));
92 }
93 }
94 }
95 LOG(WARNING) << "Could not open " << fname << ": " << s;
96 return s;
97 }
98
TensorSliceReader(const string & filepattern)99 TensorSliceReader::TensorSliceReader(const string& filepattern)
100 : TensorSliceReader(filepattern, OpenTableTensorSliceReader,
101 kLoadAllShards) {}
102
TensorSliceReader(const string & filepattern,OpenTableFunction open_function)103 TensorSliceReader::TensorSliceReader(const string& filepattern,
104 OpenTableFunction open_function)
105 : TensorSliceReader(filepattern, std::move(open_function), kLoadAllShards) {
106 }
107
TensorSliceReader(const string & filepattern,OpenTableFunction open_function,int preferred_shard)108 TensorSliceReader::TensorSliceReader(const string& filepattern,
109 OpenTableFunction open_function,
110 int preferred_shard)
111 : filepattern_(filepattern), open_function_(std::move(open_function)) {
112 VLOG(1) << "TensorSliceReader for " << filepattern;
113 Status s = Env::Default()->GetMatchingPaths(filepattern, &fnames_);
114 if (!s.ok()) {
115 status_ = errors::InvalidArgument(
116 "Unsuccessful TensorSliceReader constructor: "
117 "Failed to get matching files on ",
118 filepattern, ": ", s.ToString());
119 return;
120 }
121 if (fnames_.empty()) {
122 status_ = errors::NotFound(
123 "Unsuccessful TensorSliceReader constructor: "
124 "Failed to find any matching files for ",
125 filepattern);
126 return;
127 }
128 sss_.resize(fnames_.size());
129 for (size_t shard = 0; shard < fnames_.size(); ++shard) {
130 fname_to_index_.insert(std::make_pair(fnames_[shard], shard));
131 }
132 if (preferred_shard == kLoadAllShards || fnames_.size() == 1 ||
133 static_cast<size_t>(preferred_shard) >= fnames_.size()) {
134 LoadAllShards();
135 } else {
136 VLOG(1) << "Loading shard " << preferred_shard << " for " << filepattern_;
137 LoadShard(preferred_shard);
138 }
139 }
140
LoadShard(int shard) const141 void TensorSliceReader::LoadShard(int shard) const {
142 CHECK_LT(shard, sss_.size());
143 if (sss_[shard] || !status_.ok()) {
144 return; // Already loaded, or invalid.
145 }
146 string value;
147 SavedTensorSlices sts;
148 const string fname = fnames_[shard];
149 VLOG(1) << "Reading meta data from file " << fname << "...";
150 Table* table;
151 Status s = open_function_(fname, &table);
152 if (!s.ok()) {
153 status_ = errors::DataLoss("Unable to open table file ", fname, ": ",
154 s.ToString());
155 return;
156 }
157 sss_[shard].reset(table);
158 if (!(table->Get(kSavedTensorSlicesKey, &value) &&
159 ParseProtoUnlimited(&sts, value))) {
160 status_ = errors::Internal(
161 "Failed to find the saved tensor slices at the beginning of the "
162 "checkpoint file: ",
163 fname);
164 return;
165 }
166 status_ = CheckVersions(sts.meta().versions(), TF_CHECKPOINT_VERSION,
167 TF_CHECKPOINT_VERSION_MIN_PRODUCER, "Checkpoint",
168 "checkpoint");
169 if (!status_.ok()) return;
170 for (const SavedSliceMeta& ssm : sts.meta().tensor()) {
171 TensorShape ssm_shape(ssm.shape());
172 for (const TensorSliceProto& tsp : ssm.slice()) {
173 TensorSlice ss_slice(tsp);
174 status_ = RegisterTensorSlice(ssm.name(), ssm_shape, ssm.type(), fname,
175 ss_slice, &tensors_);
176 if (!status_.ok()) return;
177 }
178 }
179 }
180
LoadAllShards() const181 void TensorSliceReader::LoadAllShards() const {
182 VLOG(1) << "Loading all shards for " << filepattern_;
183 for (size_t i = 0; i < fnames_.size() && status_.ok(); ++i) {
184 LoadShard(i);
185 }
186 all_shards_loaded_ = true;
187 }
188
FindTensorSlice(const string & name,const TensorSlice & slice,std::vector<std::pair<TensorSlice,string>> * details) const189 const TensorSliceSet* TensorSliceReader::FindTensorSlice(
190 const string& name, const TensorSlice& slice,
191 std::vector<std::pair<TensorSlice, string>>* details) const {
192 const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
193 if (tss && !tss->QueryMeta(slice, details)) {
194 return nullptr;
195 }
196 return tss;
197 }
198
~TensorSliceReader()199 TensorSliceReader::~TensorSliceReader() {
200 for (auto& temp : tensors_) {
201 delete temp.second;
202 }
203 tensors_.clear();
204 }
205
HasTensor(const string & name,TensorShape * shape,DataType * type) const206 bool TensorSliceReader::HasTensor(const string& name, TensorShape* shape,
207 DataType* type) const {
208 mutex_lock l(mu_);
209 const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
210 if (!tss && !all_shards_loaded_) {
211 VLOG(1) << "Did not find tensor in preferred shard, loading all shards: "
212 << name;
213 LoadAllShards();
214 tss = gtl::FindPtrOrNull(tensors_, name);
215 }
216 if (tss) {
217 if (shape) {
218 *shape = tss->shape();
219 }
220 if (type) {
221 *type = tss->type();
222 }
223 return true;
224 } else {
225 return false;
226 }
227 }
228
GetTensor(const string & name,std::unique_ptr<tensorflow::Tensor> * out_tensor) const229 Status TensorSliceReader::GetTensor(
230 const string& name, std::unique_ptr<tensorflow::Tensor>* out_tensor) const {
231 DataType type;
232 TensorShape shape;
233 TensorSlice slice;
234 {
235 mutex_lock l(mu_);
236 const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
237 if (tss == nullptr) {
238 return errors::NotFound(name, " not found in checkpoint file");
239 }
240
241 if (tss->Slices().size() > 1) {
242 // TODO(sherrym): Support multi-slice checkpoints.
243 return errors::Unimplemented("Sliced checkpoints are not supported");
244 }
245
246 type = tss->type();
247 shape = tss->shape();
248 slice = tss->Slices().begin()->second.slice;
249 }
250
251 std::unique_ptr<tensorflow::Tensor> t(new tensorflow::Tensor(type, shape));
252 bool success = false;
253
254 #define READER_COPY(dt) \
255 case dt: \
256 success = CopySliceData(name, slice, \
257 t->flat<EnumToDataType<dt>::Type>().data()); \
258 break;
259
260 switch (type) {
261 READER_COPY(DT_FLOAT);
262 READER_COPY(DT_DOUBLE);
263 READER_COPY(DT_INT32);
264 READER_COPY(DT_UINT8);
265 READER_COPY(DT_INT16);
266 READER_COPY(DT_INT8);
267 READER_COPY(DT_INT64);
268 READER_COPY(DT_STRING);
269 default:
270 return errors::Unimplemented("Data type not supported");
271 }
272 #undef READER_COPY
273
274 if (!success) {
275 return errors::NotFound(name, " not found in checkpoint file");
276 }
277 std::swap(*out_tensor, t);
278
279 return Status::OK();
280 }
281
GetVariableToShapeMap() const282 TensorSliceReader::VarToShapeMap TensorSliceReader::GetVariableToShapeMap()
283 const {
284 VarToShapeMap name_to_shape;
285 if (status().ok()) {
286 for (auto& e : Tensors()) {
287 name_to_shape[e.first] = e.second->shape();
288 }
289 }
290 return name_to_shape;
291 }
292
293 TensorSliceReader::VarToDataTypeMap
GetVariableToDataTypeMap() const294 TensorSliceReader::GetVariableToDataTypeMap() const {
295 VarToDataTypeMap name_to_dtype;
296 if (status().ok()) {
297 for (auto& e : Tensors()) {
298 name_to_dtype[e.first] = e.second->type();
299 }
300 }
301 return name_to_dtype;
302 }
303
DebugString() const304 const string TensorSliceReader::DebugString() const {
305 string shape_str;
306 if (status().ok()) {
307 for (const auto& e : Tensors()) {
308 strings::StrAppend(&shape_str, e.first, " (",
309 DataType_Name(e.second->type()), ") ",
310 e.second->shape().DebugString());
311 // Indicates if a tensor has more than 1 slice (i.e., it's partitioned).
312 const int num_slices = e.second->Slices().size();
313 if (num_slices > 1) {
314 strings::StrAppend(&shape_str, ", ", num_slices, " slices");
315 }
316 strings::StrAppend(&shape_str, "\n");
317 }
318 }
319 return shape_str;
320 }
321
322 } // namespace checkpoint
323
324 } // namespace tensorflow
325