• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/util/tensor_slice_writer.h"
17 
18 #include <utility>
19 
20 #include "tensorflow/core/framework/versions.pb.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/io/table_builder.h"
23 #include "tensorflow/core/lib/random/random.h"
24 #include "tensorflow/core/lib/strings/strcat.h"
25 #include "tensorflow/core/platform/env.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/public/version.h"
28 #include "tensorflow/core/util/saved_tensor_slice_util.h"
29 
30 namespace tensorflow {
31 
32 namespace checkpoint {
33 
34 namespace {
35 
36 class TableBuilder : public TensorSliceWriter::Builder {
37  public:
TableBuilder(const string & name,WritableFile * f)38   TableBuilder(const string& name, WritableFile* f) : name_(name), file_(f) {
39     table::Options option;
40     option.compression = table::kNoCompression;
41     builder_.reset(new table::TableBuilder(option, f));
42   }
Add(StringPiece key,StringPiece val)43   void Add(StringPiece key, StringPiece val) override {
44     builder_->Add(key, val);
45   }
Finish(int64_t * file_size)46   Status Finish(int64_t* file_size) override {
47     *file_size = -1;
48     Status s = builder_->Finish();
49     if (s.ok()) {
50       s = file_->Close();
51       if (s.ok()) {
52         *file_size = builder_->FileSize();
53       }
54     }
55     if (!s.ok()) {
56       s = errors::Internal("Error writing (tmp) checkpoint file: ", name_, ": ",
57                            s.error_message());
58     }
59     builder_.reset();
60     file_.reset();
61     return s;
62   }
63 
64  private:
65   string name_;
66   std::unique_ptr<WritableFile> file_;
67   std::unique_ptr<table::TableBuilder> builder_;
68 };
69 }  // anonymous namespace
70 
CreateTableTensorSliceBuilder(const string & name,TensorSliceWriter::Builder ** builder)71 Status CreateTableTensorSliceBuilder(const string& name,
72                                      TensorSliceWriter::Builder** builder) {
73   *builder = nullptr;
74   std::unique_ptr<WritableFile> f;
75   Status s = Env::Default()->NewWritableFile(name, &f);
76   if (s.ok()) {
77     *builder = new TableBuilder(name, f.release());
78     return OkStatus();
79   } else {
80     return s;
81   }
82 }
83 
TensorSliceWriter(const string & filename,CreateBuilderFunction create_builder)84 TensorSliceWriter::TensorSliceWriter(const string& filename,
85                                      CreateBuilderFunction create_builder)
86     : filename_(filename),
87       create_builder_(std::move(create_builder)),
88       tmpname_(strings::StrCat(filename, ".tempstate", random::New64())),
89       slices_(0) {
90   VersionDef* versions = sts_.mutable_meta()->mutable_versions();
91   versions->set_producer(TF_CHECKPOINT_VERSION);
92   versions->set_min_consumer(TF_CHECKPOINT_VERSION_MIN_CONSUMER);
93 }
94 
Finish()95 Status TensorSliceWriter::Finish() {
96   Builder* b;
97   Status s = create_builder_(tmpname_, &b);
98   if (!s.ok()) {
99     delete b;
100     return s;
101   }
102   std::unique_ptr<Builder> builder(b);
103 
104   // We save the saved tensor slice metadata as the first element.
105   string meta;
106   sts_.AppendToString(&meta);
107   builder->Add(kSavedTensorSlicesKey, meta);
108 
109   // Go through all the data and add them
110   for (const auto& x : data_) {
111     builder->Add(x.first, x.second);
112   }
113 
114   int64_t file_size;
115   s = builder->Finish(&file_size);
116   // We need to rename the file to the proper name
117   if (s.ok()) {
118     s = Env::Default()->RenameFile(tmpname_, filename_);
119     if (s.ok()) {
120       VLOG(1) << "Written " << slices_ << " slices for "
121               << sts_.meta().tensor_size() << " tensors (" << file_size
122               << " bytes) to " << filename_;
123     } else {
124       LOG(ERROR) << "Failed to rename file " << tmpname_ << " to " << filename_;
125     }
126   } else {
127     Env::Default()->DeleteFile(tmpname_).IgnoreError();
128   }
129   return s;
130 }
131 
132 /* static */
MaxBytesPerElement(DataType dt)133 size_t TensorSliceWriter::MaxBytesPerElement(DataType dt) {
134   size_t max_bytes_per_element =
135       TensorSliceWriter::MaxBytesPerElementOrZero(dt);
136   if (max_bytes_per_element == 0) {
137     LOG(FATAL) << "MaxBytesPerElement not implemented for dtype: " << dt;
138   }
139   return max_bytes_per_element;
140 }
141 
142 /* static */
MaxBytesPerElementOrZero(DataType dt)143 size_t TensorSliceWriter::MaxBytesPerElementOrZero(DataType dt) {
144   switch (dt) {
145     case DT_FLOAT:
146       return 4;
147     case DT_DOUBLE:
148       return 8;
149     case DT_INT32:
150       return 10;
151     case DT_UINT8:
152       return 2;
153     case DT_INT16:
154       return 10;
155     case DT_INT8:
156       return 10;
157     case DT_COMPLEX64:
158       return 8;
159     case DT_INT64:
160       return 10;
161     case DT_BOOL:
162       return 1;
163     case DT_QINT8:
164       return 10;
165     case DT_QUINT8:
166       return 2;
167     case DT_QINT32:
168       return 10;
169     case DT_QINT16:
170       return 10;
171     case DT_QUINT16:
172       return 3;
173     case DT_UINT16:
174       return 3;
175     case DT_COMPLEX128:
176       return 16;
177     case DT_HALF:
178       return 3;
179     case DT_INVALID:
180     case DT_STRING:
181     case DT_BFLOAT16:
182     default:
183       return 0;
184   }
185 }
186 
187 template <>
SaveData(const tstring * data,int64_t num_elements,SavedSlice * ss)188 Status TensorSliceWriter::SaveData(const tstring* data, int64_t num_elements,
189                                    SavedSlice* ss) {
190   size_t size_bound = ss->ByteSize() + kTensorProtoHeaderBytes +
191                       (num_elements * MaxBytesPerElement(DT_INT32));
192   for (int64_t i = 0; i < num_elements; ++i) {
193     size_bound += data[i].size();
194   }
195   if (size_bound > kMaxMessageBytes) {
196     return errors::InvalidArgument(
197         "Tensor slice is too large to serialize (conservative estimate: ",
198         size_bound, " bytes)");
199   }
200   Fill(data, num_elements, ss->mutable_data());
201   DCHECK_GE(ss->ByteSize(), 0);
202   DCHECK_LE(ss->ByteSize(), size_bound);
203   return OkStatus();
204 }
205 
206 }  // namespace checkpoint
207 
208 }  // namespace tensorflow
209