• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/util/tensor_slice_writer.h"
17 
18 #include <utility>
19 
20 #include "tensorflow/core/framework/versions.pb.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/io/table_builder.h"
23 #include "tensorflow/core/lib/random/random.h"
24 #include "tensorflow/core/lib/strings/strcat.h"
25 #include "tensorflow/core/platform/env.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/public/version.h"
28 #include "tensorflow/core/util/saved_tensor_slice_util.h"
29 
30 namespace tensorflow {
31 
32 namespace checkpoint {
33 
34 namespace {
35 
36 class TableBuilder : public TensorSliceWriter::Builder {
37  public:
TableBuilder(const string & name,WritableFile * f)38   TableBuilder(const string& name, WritableFile* f) : name_(name), file_(f) {
39     table::Options option;
40     option.compression = table::kNoCompression;
41     builder_.reset(new table::TableBuilder(option, f));
42   }
Add(StringPiece key,StringPiece val)43   void Add(StringPiece key, StringPiece val) override {
44     builder_->Add(key, val);
45   }
Finish(int64 * file_size)46   Status Finish(int64* file_size) override {
47     *file_size = -1;
48     Status s = builder_->Finish();
49     if (s.ok()) {
50       s = file_->Close();
51       if (s.ok()) {
52         *file_size = builder_->FileSize();
53       }
54     }
55     if (!s.ok()) {
56       s = errors::Internal("Error writing (tmp) checkpoint file: ", name_, ": ",
57                            s.ToString());
58     }
59     builder_.reset();
60     file_.reset();
61     return s;
62   }
63 
64  private:
65   string name_;
66   std::unique_ptr<WritableFile> file_;
67   std::unique_ptr<table::TableBuilder> builder_;
68 };
69 }  // anonymous namespace
70 
CreateTableTensorSliceBuilder(const string & name,TensorSliceWriter::Builder ** builder)71 Status CreateTableTensorSliceBuilder(const string& name,
72                                      TensorSliceWriter::Builder** builder) {
73   *builder = nullptr;
74   std::unique_ptr<WritableFile> f;
75   Status s = Env::Default()->NewWritableFile(name, &f);
76   if (s.ok()) {
77     *builder = new TableBuilder(name, f.release());
78     return Status::OK();
79   } else {
80     return s;
81   }
82 }
83 
TensorSliceWriter(const string & filename,CreateBuilderFunction create_builder)84 TensorSliceWriter::TensorSliceWriter(const string& filename,
85                                      CreateBuilderFunction create_builder)
86     : filename_(filename),
87       create_builder_(std::move(create_builder)),
88       tmpname_(strings::StrCat(filename, ".tempstate", random::New64())),
89       slices_(0) {
90   VersionDef* versions = sts_.mutable_meta()->mutable_versions();
91   versions->set_producer(TF_CHECKPOINT_VERSION);
92   versions->set_min_consumer(TF_CHECKPOINT_VERSION_MIN_CONSUMER);
93 }
94 
Finish()95 Status TensorSliceWriter::Finish() {
96   Builder* b;
97   Status s = create_builder_(tmpname_, &b);
98   if (!s.ok()) {
99     delete b;
100     return s;
101   }
102   std::unique_ptr<Builder> builder(b);
103 
104   // We save the saved tensor slice metadata as the first element.
105   string meta;
106   sts_.AppendToString(&meta);
107   builder->Add(kSavedTensorSlicesKey, meta);
108 
109   // Go through all the data and add them
110   for (const auto& x : data_) {
111     builder->Add(x.first, x.second);
112   }
113 
114   int64 file_size;
115   s = builder->Finish(&file_size);
116   // We need to rename the file to the proper name
117   if (s.ok()) {
118     s = Env::Default()->RenameFile(tmpname_, filename_);
119     if (s.ok()) {
120       VLOG(1) << "Written " << slices_ << " slices for "
121               << sts_.meta().tensor_size() << " tensors (" << file_size
122               << " bytes) to " << filename_;
123     } else {
124       LOG(ERROR) << "Failed to rename file " << tmpname_ << " to " << filename_;
125     }
126   } else {
127     Env::Default()->DeleteFile(tmpname_).IgnoreError();
128   }
129   return s;
130 }
131 
132 /* static */
MaxBytesPerElement(DataType dt)133 size_t TensorSliceWriter::MaxBytesPerElement(DataType dt) {
134   switch (dt) {
135     case DT_FLOAT:
136       return 4;
137     case DT_DOUBLE:
138       return 8;
139     case DT_INT32:
140       return 10;
141     case DT_UINT8:
142       return 2;
143     case DT_INT16:
144       return 10;
145     case DT_INT8:
146       return 10;
147     case DT_COMPLEX64:
148       return 8;
149     case DT_INT64:
150       return 10;
151     case DT_BOOL:
152       return 1;
153     case DT_QINT8:
154       return 10;
155     case DT_QUINT8:
156       return 2;
157     case DT_QINT32:
158       return 10;
159     case DT_QINT16:
160       return 10;
161     case DT_QUINT16:
162       return 3;
163     case DT_UINT16:
164       return 3;
165     case DT_COMPLEX128:
166       return 16;
167     case DT_HALF:
168       return 3;
169     case DT_INVALID:
170     case DT_STRING:
171     case DT_BFLOAT16:
172     default:
173       LOG(FATAL) << "MaxBytesPerElement not implemented for dtype: " << dt;
174   }
175   return 0;
176 }
177 
178 template <>
SaveData(const tstring * data,int64 num_elements,SavedSlice * ss)179 Status TensorSliceWriter::SaveData(const tstring* data, int64 num_elements,
180                                    SavedSlice* ss) {
181   size_t size_bound = ss->ByteSize() + kTensorProtoHeaderBytes +
182                       (num_elements * MaxBytesPerElement(DT_INT32));
183   for (int64 i = 0; i < num_elements; ++i) {
184     size_bound += data[i].size();
185   }
186   if (size_bound > kMaxMessageBytes) {
187     return errors::InvalidArgument(
188         "Tensor slice is too large to serialize (conservative estimate: ",
189         size_bound, " bytes)");
190   }
191   Fill(data, num_elements, ss->mutable_data());
192   DCHECK_GE(ss->ByteSize(), 0);
193   DCHECK_LE(ss->ByteSize(), size_bound);
194   return Status::OK();
195 }
196 
197 }  // namespace checkpoint
198 
199 }  // namespace tensorflow
200