1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/data/serialization_utils.h"
17
18 #include <string>
19 #include <utility>
20
21 #include "tensorflow/core/common_runtime/graph_constructor.h"
22 #include "tensorflow/core/common_runtime/graph_runner.h"
23 #include "tensorflow/core/data/dataset_utils.h"
24 #include "tensorflow/core/framework/dataset.h"
25 #include "tensorflow/core/framework/function.h"
26 #include "tensorflow/core/graph/graph_def_builder.h"
27
28 namespace tensorflow {
29 namespace data {
30 namespace {
31
// Delimiter used when joining an iterator's keys into its metadata string.
constexpr char kDelimiter[] = "@@";
// Key-name fragments used by Read/WriteElements{From,To}Checkpoint.
constexpr char kComponent[] = "component";
constexpr char kNumComponents[] = "num_components";
constexpr char kNumElements[] = "num_elements";
// Key suffixes marking a serialized dataset and its output node name.
constexpr char kIsDataset[] = ".is_dataset";
constexpr char kOutputNode[] = ".output_node";
38
39 // We assume that all keys are of the form <iterator_prefix>:<name>. We extract
40 // the iterator name by getting rid of everything post the final colon.
GetIteratorName(StringPiece key,string * name)41 Status GetIteratorName(StringPiece key, string* name) {
42 if (!str_util::StartsWith(key, data::kFullNameRandomHex)) {
43 return errors::InvalidArgument("Save key: ", key,
44 " not generated using full_name.");
45 }
46 std::vector<string> split_keys = str_util::Split(key, data::kPipe);
47 if (split_keys.size() != 2) {
48 return errors::InvalidArgument("Save key: ", key,
49 " not generated using full_name.");
50 }
51 string real_key = split_keys[1];
52 const int pos = real_key.rfind(kColon);
53 *name = real_key.substr(0, pos);
54 return Status::OK();
55 }
56
FromGraphDef(FunctionLibraryRuntime * flr,const GraphDef & graph_def,const std::vector<std::pair<string,Tensor>> & input_list,const string & output_node,Tensor * result)57 Status FromGraphDef(FunctionLibraryRuntime* flr, const GraphDef& graph_def,
58 const std::vector<std::pair<string, Tensor>>& input_list,
59 const string& output_node, Tensor* result) {
60 FunctionLibraryRuntime* cloned_flr = nullptr;
61 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr = nullptr;
62 std::unique_ptr<FunctionLibraryDefinition> lib_def = nullptr;
63 TF_RETURN_IF_ERROR(flr->Clone(&lib_def, &pflr, &cloned_flr, true));
64 TF_RETURN_IF_ERROR(AddToFunctionLibrary(lib_def.get(), graph_def.library()));
65 Graph graph(OpRegistry::Global());
66 TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
67 std::vector<Tensor> outputs;
68 GraphRunner graph_runner(cloned_flr->device());
69 TF_RETURN_IF_ERROR(graph_runner.Run(&graph, cloned_flr, input_list,
70 {output_node}, &outputs));
71 *result = outputs[0];
72 return Status::OK();
73 }
74
75 // FindStatefulOps searches `graph_def` for all of its stateful ops storing
76 // their names in `stateful_op_names`.
FindStatefulOps(const GraphDef & graph_def,std::vector<string> * stateful_op_names)77 Status FindStatefulOps(const GraphDef& graph_def,
78 std::vector<string>* stateful_op_names) {
79 FunctionLibraryDefinition lib_def(OpRegistry::Global(), graph_def.library());
80
81 // Iterate over all nodes in the graph.
82 for (const auto& node : graph_def.node()) {
83 // Each Dataset graph has a _Retval op in the end which is marked stateful
84 if (node.op() == FunctionLibraryDefinition::kRetOp) continue;
85 if (!IsNodeStateful(lib_def, node).ok()) {
86 stateful_op_names->push_back(node.op());
87 }
88 }
89
90 // Iterate over all functions.
91 for (const auto& fdef : graph_def.library().function()) {
92 if (!fdef.signature().is_stateful()) continue;
93 for (const auto& node : fdef.node_def()) {
94 if (!IsNodeStateful(lib_def, node).ok()) {
95 stateful_op_names->push_back(
96 absl::StrCat(node.op(), " in function: ", fdef.signature().name()));
97 }
98 }
99 }
100 return Status::OK();
101 }
102
103 } // namespace
104
ReadElementsFromCheckpoint(IteratorContext * ctx,IteratorStateReader * reader,StringPiece key_prefix,std::vector<std::vector<Tensor>> * elements)105 Status ReadElementsFromCheckpoint(IteratorContext* ctx,
106 IteratorStateReader* reader,
107 StringPiece key_prefix,
108 std::vector<std::vector<Tensor>>* elements) {
109 int64_t num_elements;
110 TF_RETURN_IF_ERROR(
111 reader->ReadScalar(key_prefix, kNumElements, &num_elements));
112 DCHECK(elements->empty());
113 elements->reserve(num_elements);
114 for (int i = 0; i < num_elements; ++i) {
115 std::string element_prefix = absl::StrCat(key_prefix, "::", i);
116 int64_t num_components;
117 TF_RETURN_IF_ERROR(
118 reader->ReadScalar(element_prefix, kNumComponents, &num_components));
119 elements->emplace_back();
120 std::vector<Tensor>& element = elements->at(i);
121 element.reserve(num_components);
122 for (int j = 0; j < num_components; ++j) {
123 element.emplace_back();
124 TF_RETURN_IF_ERROR(reader->ReadTensor(
125 ctx->flr(), element_prefix, absl::StrCat(kComponent, "[", j, "]"),
126 &element.back()));
127 }
128 }
129 return Status::OK();
130 }
131
WriteElementsToCheckpoint(IteratorStateWriter * writer,StringPiece key_prefix,const std::vector<std::vector<Tensor>> & elements)132 Status WriteElementsToCheckpoint(
133 IteratorStateWriter* writer, StringPiece key_prefix,
134 const std::vector<std::vector<Tensor>>& elements) {
135 TF_RETURN_IF_ERROR(
136 writer->WriteScalar(key_prefix, kNumElements, elements.size()));
137 for (int i = 0; i < elements.size(); ++i) {
138 const std::vector<Tensor>& element = elements[i];
139 std::string element_prefix = absl::StrCat(key_prefix, "::", i);
140 TF_RETURN_IF_ERROR(
141 writer->WriteScalar(element_prefix, kNumComponents, element.size()));
142 for (int j = 0; j < elements[i].size(); ++j) {
143 TF_RETURN_IF_ERROR(writer->WriteTensor(
144 element_prefix, absl::StrCat(kComponent, "[", j, "]"), element[j]));
145 }
146 }
147 return Status::OK();
148 }
149
VariantTensorDataReader(const std::vector<const tensorflow::VariantTensorData * > & data)150 VariantTensorDataReader::VariantTensorDataReader(
151 const std::vector<const tensorflow::VariantTensorData*>& data) {
152 for (const auto& d : data) {
153 string metadata;
154 d->get_metadata(&metadata);
155 auto keys = str_util::Split(metadata, kDelimiter, str_util::SkipEmpty());
156 const string name = keys[0];
157 data_[name] = d;
158 map_[name] = std::map<string, size_t>();
159 for (size_t i = 1; i < keys.size(); ++i) {
160 map_[name][keys[i]] = i - 1;
161 }
162 }
163 }
164
ReadScalar(StringPiece key,int64 * val) const165 Status VariantTensorDataReader::ReadScalar(StringPiece key, int64* val) const {
166 string name;
167 TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
168 return ReadScalar(name, key, val);
169 }
170
ReadScalar(StringPiece name,StringPiece key,int64 * val) const171 Status VariantTensorDataReader::ReadScalar(StringPiece name, StringPiece key,
172 int64* val) const {
173 return ReadScalarInternal(name, key, val);
174 }
175
ReadScalar(StringPiece key,tstring * val) const176 Status VariantTensorDataReader::ReadScalar(StringPiece key,
177 tstring* val) const {
178 string name;
179 TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
180 return ReadScalar(name, key, val);
181 }
182
ReadScalar(StringPiece name,StringPiece key,tstring * val) const183 Status VariantTensorDataReader::ReadScalar(StringPiece name, StringPiece key,
184 tstring* val) const {
185 return ReadScalarInternal(name, key, val);
186 }
187
ReadTensor(StringPiece key,Tensor * val) const188 Status VariantTensorDataReader::ReadTensor(StringPiece key, Tensor* val) const {
189 string name;
190 TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
191 return ReadTensor(name, key, val);
192 }
193
ReadTensor(FunctionLibraryRuntime * flr,StringPiece key,Tensor * val) const194 Status VariantTensorDataReader::ReadTensor(FunctionLibraryRuntime* flr,
195 StringPiece key, Tensor* val) const {
196 string name;
197 TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
198 return ReadTensorInternal(flr, name, key, val);
199 }
200
ReadTensor(StringPiece name,StringPiece key,Tensor * val) const201 Status VariantTensorDataReader::ReadTensor(StringPiece name, StringPiece key,
202 Tensor* val) const {
203 return ReadTensor(/*flr=*/nullptr, name, key, val);
204 }
205
ReadTensor(FunctionLibraryRuntime * flr,StringPiece name,StringPiece key,Tensor * val) const206 Status VariantTensorDataReader::ReadTensor(FunctionLibraryRuntime* flr,
207 StringPiece name, StringPiece key,
208 Tensor* val) const {
209 return ReadTensorInternal(flr, name, key, val);
210 }
211
Contains(StringPiece key) const212 bool VariantTensorDataReader::Contains(StringPiece key) const {
213 string name;
214 if (!GetIteratorName(key, &name).ok()) {
215 return false;
216 }
217 return Contains(name, key);
218 }
219
Contains(StringPiece n,StringPiece key) const220 bool VariantTensorDataReader::Contains(StringPiece n, StringPiece key) const {
221 string name(n);
222 auto it = map_.find(name);
223 if (it == map_.end()) {
224 return false;
225 }
226 const auto& bucket = it->second;
227 return bucket.find(string(key)) != bucket.end();
228 }
229
230 template <typename T>
ReadScalarInternal(StringPiece n,StringPiece key,T * val) const231 Status VariantTensorDataReader::ReadScalarInternal(StringPiece n,
232 StringPiece key,
233 T* val) const {
234 string name(n);
235 auto it = map_.find(name);
236 if (it == map_.end()) {
237 return errors::NotFound(name);
238 }
239 const auto& bucket = it->second;
240 auto key_it = bucket.find(string(key));
241 if (key_it == bucket.end()) {
242 return errors::NotFound(key);
243 }
244 *val = data_.at(name)->tensors(key_it->second).scalar<T>()();
245 return Status::OK();
246 }
247
ReadTensorInternal(FunctionLibraryRuntime * flr,StringPiece n,StringPiece key,Tensor * val) const248 Status VariantTensorDataReader::ReadTensorInternal(FunctionLibraryRuntime* flr,
249 StringPiece n,
250 StringPiece key,
251 Tensor* val) const {
252 if (Contains(n, strings::StrCat(key, kIsDataset))) {
253 return ReadDatasetInternal(flr, n, key, val);
254 }
255 string name(n);
256 auto it = map_.find(name);
257 if (it == map_.end()) {
258 return errors::NotFound(name);
259 }
260 const auto& bucket = it->second;
261 auto key_it = bucket.find(string(key));
262 if (key_it == bucket.end()) {
263 return errors::NotFound(key);
264 }
265 *val = data_.at(name)->tensors(key_it->second);
266 return Status::OK();
267 }
268
ReadDatasetInternal(FunctionLibraryRuntime * flr,StringPiece n,StringPiece key,Tensor * val) const269 Status VariantTensorDataReader::ReadDatasetInternal(FunctionLibraryRuntime* flr,
270 StringPiece n,
271 StringPiece key,
272 Tensor* val) const {
273 if (flr == nullptr) {
274 return errors::Internal(
275 "Function library runtime is needed to restore a dataset.");
276 }
277 tstring output_node, serialized_graph_def;
278 TF_RETURN_IF_ERROR(
279 ReadScalar(n, strings::StrCat(key, kOutputNode), &output_node));
280 TF_RETURN_IF_ERROR(
281 ReadScalar(n, strings::StrCat(key), &serialized_graph_def));
282 GraphDef graph_def;
283 graph_def.ParseFromString(serialized_graph_def);
284 TF_RETURN_IF_ERROR(FromGraphDef(flr, graph_def, {}, output_node, val));
285 return Status::OK();
286 }
287
WriteScalar(StringPiece key,const int64_t val)288 Status VariantTensorDataWriter::WriteScalar(StringPiece key,
289 const int64_t val) {
290 string name;
291 TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
292 return WriteScalar(name, key, val);
293 }
294
WriteScalar(StringPiece name,StringPiece key,const int64_t val)295 Status VariantTensorDataWriter::WriteScalar(StringPiece name, StringPiece key,
296 const int64_t val) {
297 return WriteScalarInternal(name, key, val);
298 }
299
WriteScalar(StringPiece key,const tstring & val)300 Status VariantTensorDataWriter::WriteScalar(StringPiece key,
301 const tstring& val) {
302 string name;
303 TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
304 return WriteScalar(name, key, val);
305 }
306
WriteScalar(StringPiece name,StringPiece key,const tstring & val)307 Status VariantTensorDataWriter::WriteScalar(StringPiece name, StringPiece key,
308 const tstring& val) {
309 return WriteScalarInternal(name, key, val);
310 }
311
WriteTensor(StringPiece key,const Tensor & val)312 Status VariantTensorDataWriter::WriteTensor(StringPiece key,
313 const Tensor& val) {
314 string name;
315 TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
316 return WriteTensor(name, key, val);
317 }
318
WriteTensor(StringPiece name,StringPiece key,const Tensor & val)319 Status VariantTensorDataWriter::WriteTensor(StringPiece name, StringPiece key,
320 const Tensor& val) {
321 return WriteTensorInternal(name, key, val);
322 }
323
MaybeFlush()324 void VariantTensorDataWriter::MaybeFlush() {
325 if (is_flushed_) return;
326 for (auto& keys : keys_) {
327 const string name = keys.first;
328 string metadata = name;
329 for (size_t i = 0; i < keys_[name].size(); ++i) {
330 strings::StrAppend(&metadata, kDelimiter, keys_[name][i]);
331 }
332 data_[name]->set_metadata(metadata);
333 }
334 is_flushed_ = true;
335 }
336
Reset()337 void VariantTensorDataWriter::Reset() {
338 is_flushed_ = false;
339 data_.clear();
340 keys_.clear();
341 }
342
ReleaseData(std::vector<std::unique_ptr<VariantTensorData>> * variants)343 void VariantTensorDataWriter::ReleaseData(
344 std::vector<std::unique_ptr<VariantTensorData>>* variants) {
345 MaybeFlush();
346 for (auto& it : data_) {
347 variants->push_back(std::move(it.second));
348 }
349 Reset();
350 }
351
GetData(std::vector<const VariantTensorData * > * variants)352 void VariantTensorDataWriter::GetData(
353 std::vector<const VariantTensorData*>* variants) {
354 MaybeFlush();
355 for (auto& it : data_) {
356 variants->push_back(it.second.get());
357 }
358 }
359
360 template <typename T>
WriteScalarInternal(StringPiece name,StringPiece key,const T & val)361 Status VariantTensorDataWriter::WriteScalarInternal(StringPiece name,
362 StringPiece key,
363 const T& val) {
364 if (is_flushed_) {
365 return errors::FailedPrecondition(
366 "Cannot call WriteScalar after GetData or ReleaseData is called");
367 }
368 Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
369 val_t.scalar<T>()() = val;
370 return WriteTensorInternal(name, key, val_t);
371 }
372
WriteTensorInternal(StringPiece n,StringPiece key,const Tensor & val)373 Status VariantTensorDataWriter::WriteTensorInternal(StringPiece n,
374 StringPiece key,
375 const Tensor& val) {
376 DatasetBase* dataset;
377 if (GetDatasetFromVariantTensor(val, &dataset).ok()) {
378 return WriteDatasetInternal(n, key, dataset);
379 }
380 if (is_flushed_) {
381 return errors::FailedPrecondition(
382 "Cannot call WriteTensor after GetData or ReleaseData is called");
383 }
384 DCHECK_EQ(key.find(kDelimiter), string::npos);
385 string name(n);
386 if (keys_.count(name) == 0) {
387 keys_[name] = std::vector<string>();
388 }
389 keys_[name].push_back(string(key));
390 if (data_.count(name) == 0) {
391 data_[name] = absl::make_unique<VariantTensorData>();
392 data_[name]->set_type_name("tensorflow::Iterator");
393 }
394 *(data_[name]->add_tensors()) = val;
395 return Status::OK();
396 }
397
WriteDatasetInternal(StringPiece n,StringPiece key,const DatasetBase * dataset)398 Status VariantTensorDataWriter::WriteDatasetInternal(
399 StringPiece n, StringPiece key, const DatasetBase* dataset) {
400 GraphDef graph_def;
401 SerializationContext ctx((SerializationContext::Params()));
402 TF_RETURN_IF_ERROR(AsGraphDef(nullptr, dataset, std::move(ctx), &graph_def));
403 string output_node;
404 for (const auto& node : graph_def.node()) {
405 if (node.op() == "_Retval") {
406 output_node = node.input(0);
407 break;
408 }
409 }
410 string result;
411 graph_def.SerializeToString(&result);
412 TF_RETURN_IF_ERROR(WriteScalar(n, strings::StrCat(key, kIsDataset), ""));
413 TF_RETURN_IF_ERROR(
414 WriteScalar(n, strings::StrCat(key, kOutputNode), output_node));
415 TF_RETURN_IF_ERROR(WriteScalar(n, key, result));
416 return Status::OK();
417 }
418
AsGraphDefMinimal(OpKernelContext * ctx,const DatasetBase * input,std::vector<std::pair<string,Tensor>> * input_list,GraphDef * result,string * dataset_node)419 Status AsGraphDefMinimal(OpKernelContext* ctx, const DatasetBase* input,
420 std::vector<std::pair<string, Tensor>>* input_list,
421 GraphDef* result, string* dataset_node) {
422 SerializationContext::Params params(ctx);
423 params.input_list = input_list;
424 params.external_state_policy =
425 SerializationContext::ExternalStatePolicy::kIgnore;
426 params.fail_if_unimplemented = false;
427 params.serialize_data_tensors = false;
428 params.preserve_random_seeds = false;
429 SerializationContext serialization_ctx(params);
430 TF_RETURN_IF_ERROR(
431 AsGraphDef(ctx, input, std::move(serialization_ctx), result));
432
433 // Symbolic `_Retval` node indicates which node corresponds to the dataset.
434 for (const auto& node : result->node()) {
435 if (node.op() == "_Retval") {
436 *dataset_node = node.input(0);
437 }
438 }
439 return Status::OK();
440 }
441
AsGraphDef(OpKernelContext * ctx,const DatasetBase * dataset,SerializationContext && serialization_ctx,GraphDef * graph_def)442 Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
443 SerializationContext&& serialization_ctx,
444 GraphDef* graph_def) {
445 if (serialization_ctx.external_state_policy() ==
446 SerializationContext::ExternalStatePolicy::kFail) {
447 TF_RETURN_IF_ERROR(dataset->CheckExternalState());
448 }
449 if (serialization_ctx.external_state_policy() ==
450 SerializationContext::ExternalStatePolicy::kWarn) {
451 std::vector<string> stateful_op_names;
452 TF_RETURN_IF_ERROR(FindStatefulOps(*graph_def, &stateful_op_names));
453 if (!stateful_op_names.empty()) {
454 LOG(WARNING) << "We found the following stateful ops in the dataset "
455 "construction graph whose state would not be "
456 "serialized and might "
457 "cause subtle bugs: "
458 << absl::StrJoin(stateful_op_names, ", ");
459 }
460 }
461 GraphDefBuilder b;
462 DatasetBase::DatasetGraphDefBuilder db(&b);
463 Node* output_node = nullptr;
464 TF_RETURN_IF_ERROR(
465 db.AddInputDataset(&serialization_ctx, dataset, &output_node));
466 // Insert a purely symbolic _Retval node to indicate to consumers which node
467 // represents `dataset`.
468 ops::UnaryOp("_Retval", output_node,
469 b.opts()
470 .WithName("dataset")
471 .WithAttr("T", DT_VARIANT)
472 .WithAttr("index", 0));
473 TF_RETURN_IF_ERROR(b.ToGraphDef(graph_def));
474 return Status::OK();
475 }
476
477 } // namespace data
478 } // namespace tensorflow
479