1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/util/tensor_slice_reader.h"
17
18 #include <utility>
19 #include <vector>
20 #include "tensorflow/core/framework/types.pb_text.h"
21 #include "tensorflow/core/framework/versions.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/gtl/stl_util.h"
24 #include "tensorflow/core/lib/io/iterator.h"
25 #include "tensorflow/core/lib/io/table.h"
26 #include "tensorflow/core/lib/io/table_options.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/platform/logging.h"
29 #include "tensorflow/core/platform/protobuf.h"
30 #include "tensorflow/core/public/version.h"
31 #include "tensorflow/core/util/saved_tensor_slice_util.h"
32 #include "tensorflow/core/util/tensor_slice_util.h"
33
34 namespace tensorflow {
35
36 namespace checkpoint {
37
~Table()38 TensorSliceReader::Table::~Table() {}
39
40 namespace {
41 class TensorSliceReaderTable : public TensorSliceReader::Table {
42 public:
43 // Takes ownership of 'f'.
TensorSliceReaderTable(RandomAccessFile * f,table::Table * t)44 explicit TensorSliceReaderTable(RandomAccessFile* f, table::Table* t)
45 : file_(f), table_(t) {}
46
~TensorSliceReaderTable()47 ~TensorSliceReaderTable() override {
48 delete table_;
49 delete file_;
50 }
51
Get(const string & key,string * value)52 bool Get(const string& key, string* value) override {
53 std::unique_ptr<table::Iterator> iter(table_->NewIterator());
54 iter->Seek(key);
55 if (iter->Valid() && iter->key() == key) {
56 StringPiece v = iter->value();
57 value->assign(v.data(), v.size());
58 return true;
59 } else {
60 return false;
61 }
62 }
63
64 private:
65 RandomAccessFile* file_; // Owns.
66 table::Table* table_;
67 };
68 } // namespace
69
OpenTableTensorSliceReader(const string & fname,TensorSliceReader::Table ** result)70 Status OpenTableTensorSliceReader(const string& fname,
71 TensorSliceReader::Table** result) {
72 *result = nullptr;
73 Env* env = Env::Default();
74 std::unique_ptr<RandomAccessFile> f;
75 Status s = env->NewRandomAccessFile(fname, &f);
76 if (s.ok()) {
77 uint64 file_size;
78 s = env->GetFileSize(fname, &file_size);
79 if (s.ok()) {
80 table::Options options;
81 table::Table* table;
82 s = table::Table::Open(options, f.get(), file_size, &table);
83 if (s.ok()) {
84 *result = new TensorSliceReaderTable(f.release(), table);
85 return Status::OK();
86 } else {
87 s = Status(s.code(),
88 strings::StrCat(s.error_message(),
89 ": perhaps your file is in a different "
90 "file format and you need to use a "
91 "different restore operator?"));
92 }
93 }
94 }
95 LOG(WARNING) << "Could not open " << fname << ": " << s;
96 return s;
97 }
98
TensorSliceReader(const string & filepattern)99 TensorSliceReader::TensorSliceReader(const string& filepattern)
100 : TensorSliceReader(filepattern, OpenTableTensorSliceReader,
101 kLoadAllShards) {}
102
TensorSliceReader(const string & filepattern,OpenTableFunction open_function)103 TensorSliceReader::TensorSliceReader(const string& filepattern,
104 OpenTableFunction open_function)
105 : TensorSliceReader(filepattern, std::move(open_function), kLoadAllShards) {
106 }
107
TensorSliceReader(const string & filepattern,OpenTableFunction open_function,int preferred_shard)108 TensorSliceReader::TensorSliceReader(const string& filepattern,
109 OpenTableFunction open_function,
110 int preferred_shard)
111 : filepattern_(filepattern), open_function_(std::move(open_function)) {
112 VLOG(1) << "TensorSliceReader for " << filepattern;
113 Status s = Env::Default()->GetMatchingPaths(filepattern, &fnames_);
114 if (!s.ok()) {
115 status_ = errors::InvalidArgument(
116 "Unsuccessful TensorSliceReader constructor: "
117 "Failed to get matching files on ",
118 filepattern, ": ", s.ToString());
119 return;
120 }
121 if (fnames_.empty()) {
122 status_ = errors::NotFound(
123 "Unsuccessful TensorSliceReader constructor: "
124 "Failed to find any matching files for ",
125 filepattern);
126 return;
127 }
128 sss_.resize(fnames_.size());
129 for (size_t shard = 0; shard < fnames_.size(); ++shard) {
130 fname_to_index_.insert(std::make_pair(fnames_[shard], shard));
131 }
132 if (preferred_shard == kLoadAllShards || fnames_.size() == 1 ||
133 static_cast<size_t>(preferred_shard) >= fnames_.size()) {
134 LoadAllShards();
135 } else {
136 VLOG(1) << "Loading shard " << preferred_shard << " for " << filepattern_;
137 LoadShard(preferred_shard);
138 }
139 }
140
LoadShard(int shard) const141 void TensorSliceReader::LoadShard(int shard) const {
142 CHECK_LT(shard, sss_.size());
143 if (sss_[shard] || !status_.ok()) {
144 return; // Already loaded, or invalid.
145 }
146 string value;
147 SavedTensorSlices sts;
148 const string fname = fnames_[shard];
149 VLOG(1) << "Reading meta data from file " << fname << "...";
150 Table* table;
151 Status s = open_function_(fname, &table);
152 if (!s.ok()) {
153 status_ = errors::DataLoss("Unable to open table file ", fname, ": ",
154 s.ToString());
155 return;
156 }
157 sss_[shard].reset(table);
158 if (!(table->Get(kSavedTensorSlicesKey, &value) &&
159 ParseProtoUnlimited(&sts, value))) {
160 status_ = errors::Internal(
161 "Failed to find the saved tensor slices at the beginning of the "
162 "checkpoint file: ",
163 fname);
164 return;
165 }
166 status_ = CheckVersions(sts.meta().versions(), TF_CHECKPOINT_VERSION,
167 TF_CHECKPOINT_VERSION_MIN_PRODUCER, "Checkpoint",
168 "checkpoint");
169 if (!status_.ok()) return;
170 for (const SavedSliceMeta& ssm : sts.meta().tensor()) {
171 TensorShape ssm_shape(ssm.shape());
172 for (const TensorSliceProto& tsp : ssm.slice()) {
173 TensorSlice ss_slice(tsp);
174 status_ = RegisterTensorSlice(ssm.name(), ssm_shape, ssm.type(), fname,
175 ss_slice, &tensors_);
176 if (!status_.ok()) return;
177 }
178 }
179 }
180
LoadAllShards() const181 void TensorSliceReader::LoadAllShards() const {
182 VLOG(1) << "Loading all shards for " << filepattern_;
183 for (size_t i = 0; i < fnames_.size() && status_.ok(); ++i) {
184 LoadShard(i);
185 }
186 all_shards_loaded_ = true;
187 }
188
FindTensorSlice(const string & name,const TensorSlice & slice,std::vector<std::pair<TensorSlice,string>> * details) const189 const TensorSliceSet* TensorSliceReader::FindTensorSlice(
190 const string& name, const TensorSlice& slice,
191 std::vector<std::pair<TensorSlice, string>>* details) const {
192 const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
193 if (tss && !tss->QueryMeta(slice, details)) {
194 return nullptr;
195 }
196 return tss;
197 }
198
~TensorSliceReader()199 TensorSliceReader::~TensorSliceReader() { gtl::STLDeleteValues(&tensors_); }
200
HasTensor(const string & name,TensorShape * shape,DataType * type) const201 bool TensorSliceReader::HasTensor(const string& name, TensorShape* shape,
202 DataType* type) const {
203 mutex_lock l(mu_);
204 const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
205 if (!tss && !all_shards_loaded_) {
206 VLOG(1) << "Did not find tensor in preferred shard, loading all shards: "
207 << name;
208 LoadAllShards();
209 tss = gtl::FindPtrOrNull(tensors_, name);
210 }
211 if (tss) {
212 if (shape) {
213 *shape = tss->shape();
214 }
215 if (type) {
216 *type = tss->type();
217 }
218 return true;
219 } else {
220 return false;
221 }
222 }
223
GetTensor(const string & name,std::unique_ptr<tensorflow::Tensor> * out_tensor) const224 Status TensorSliceReader::GetTensor(
225 const string& name, std::unique_ptr<tensorflow::Tensor>* out_tensor) const {
226 DataType type;
227 TensorShape shape;
228 TensorSlice slice;
229 {
230 mutex_lock l(mu_);
231 const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
232 if (tss == nullptr) {
233 return errors::NotFound(name, " not found in checkpoint file");
234 }
235
236 if (tss->Slices().size() > 1) {
237 // TODO(sherrym): Support multi-slice checkpoints.
238 return errors::Unimplemented("Sliced checkpoints are not supported");
239 }
240
241 type = tss->type();
242 shape = tss->shape();
243 slice = tss->Slices().begin()->second.slice;
244 }
245
246 std::unique_ptr<tensorflow::Tensor> t(new tensorflow::Tensor(type, shape));
247 bool success = false;
248
249 #define READER_COPY(dt) \
250 case dt: \
251 success = CopySliceData(name, slice, \
252 t->flat<EnumToDataType<dt>::Type>().data()); \
253 break;
254
255 switch (type) {
256 READER_COPY(DT_FLOAT);
257 READER_COPY(DT_DOUBLE);
258 READER_COPY(DT_INT32);
259 READER_COPY(DT_UINT8);
260 READER_COPY(DT_INT16);
261 READER_COPY(DT_INT8);
262 READER_COPY(DT_INT64);
263 READER_COPY(DT_STRING);
264 default:
265 return errors::Unimplemented("Data type not supported");
266 }
267 #undef READER_COPY
268
269 if (!success) {
270 return errors::NotFound(name, " not found in checkpoint file");
271 }
272 std::swap(*out_tensor, t);
273
274 return Status::OK();
275 }
276
GetVariableToShapeMap() const277 TensorSliceReader::VarToShapeMap TensorSliceReader::GetVariableToShapeMap()
278 const {
279 VarToShapeMap name_to_shape;
280 if (status().ok()) {
281 for (auto& e : Tensors()) {
282 name_to_shape[e.first] = e.second->shape();
283 }
284 }
285 return name_to_shape;
286 }
287
288 TensorSliceReader::VarToDataTypeMap
GetVariableToDataTypeMap() const289 TensorSliceReader::GetVariableToDataTypeMap() const {
290 VarToDataTypeMap name_to_dtype;
291 if (status().ok()) {
292 for (auto& e : Tensors()) {
293 name_to_dtype[e.first] = e.second->type();
294 }
295 }
296 return name_to_dtype;
297 }
298
DebugString() const299 const string TensorSliceReader::DebugString() const {
300 string shape_str;
301 if (status().ok()) {
302 for (auto e : Tensors()) {
303 strings::StrAppend(&shape_str, e.first, " (",
304 EnumName_DataType(e.second->type()), ") ",
305 e.second->shape().DebugString());
306 // Indicates if a tensor has more than 1 slice (i.e., it's partitioned).
307 const int num_slices = e.second->Slices().size();
308 if (num_slices > 1) {
309 strings::StrAppend(&shape_str, ", ", num_slices, " slices");
310 }
311 strings::StrAppend(&shape_str, "\n");
312 }
313 }
314 return shape_str;
315 }
316
317 } // namespace checkpoint
318
319 } // namespace tensorflow
320