• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/data/serialization_utils.h"
17 
18 #include <string>
19 #include <utility>
20 
21 #include "tensorflow/core/common_runtime/graph_constructor.h"
22 #include "tensorflow/core/common_runtime/graph_runner.h"
23 #include "tensorflow/core/data/dataset_utils.h"
24 #include "tensorflow/core/framework/dataset.h"
25 #include "tensorflow/core/framework/function.h"
26 #include "tensorflow/core/graph/graph_def_builder.h"
27 
28 namespace tensorflow {
29 namespace data {
30 namespace {
31 
// Separator placed between the iterator name and each key when the key list is
// packed into (and unpacked from) a VariantTensorData metadata string.
constexpr char kDelimiter[] = "@@";

// Key fragments used when checkpointing a list of elements.
constexpr char kNumElements[] = "num_elements";
constexpr char kNumComponents[] = "num_components";
constexpr char kComponent[] = "component";

// Key suffixes written next to a tensor that holds a dataset.
constexpr char kIsDataset[] = ".is_dataset";
constexpr char kOutputNode[] = ".output_node";
38 
39 // We assume that all keys are of the form <iterator_prefix>:<name>. We extract
40 // the iterator name by getting rid of everything post the final colon.
GetIteratorName(StringPiece key,string * name)41 Status GetIteratorName(StringPiece key, string* name) {
42   if (!str_util::StartsWith(key, data::kFullNameRandomHex)) {
43     return errors::InvalidArgument("Save key: ", key,
44                                    " not generated using full_name.");
45   }
46   std::vector<string> split_keys = str_util::Split(key, data::kPipe);
47   if (split_keys.size() != 2) {
48     return errors::InvalidArgument("Save key: ", key,
49                                    " not generated using full_name.");
50   }
51   string real_key = split_keys[1];
52   const int pos = real_key.rfind(kColon);
53   *name = real_key.substr(0, pos);
54   return Status::OK();
55 }
56 
FromGraphDef(FunctionLibraryRuntime * flr,const GraphDef & graph_def,const std::vector<std::pair<string,Tensor>> & input_list,const string & output_node,Tensor * result)57 Status FromGraphDef(FunctionLibraryRuntime* flr, const GraphDef& graph_def,
58                     const std::vector<std::pair<string, Tensor>>& input_list,
59                     const string& output_node, Tensor* result) {
60   FunctionLibraryRuntime* cloned_flr = nullptr;
61   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr = nullptr;
62   std::unique_ptr<FunctionLibraryDefinition> lib_def = nullptr;
63   TF_RETURN_IF_ERROR(flr->Clone(&lib_def, &pflr, &cloned_flr, true));
64   TF_RETURN_IF_ERROR(AddToFunctionLibrary(lib_def.get(), graph_def.library()));
65   Graph graph(OpRegistry::Global());
66   TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
67   std::vector<Tensor> outputs;
68   GraphRunner graph_runner(cloned_flr->device());
69   TF_RETURN_IF_ERROR(graph_runner.Run(&graph, cloned_flr, input_list,
70                                       {output_node}, &outputs));
71   *result = outputs[0];
72   return Status::OK();
73 }
74 
75 // FindStatefulOps searches `graph_def` for all of its stateful ops storing
76 // their names in `stateful_op_names`.
FindStatefulOps(const GraphDef & graph_def,std::vector<string> * stateful_op_names)77 Status FindStatefulOps(const GraphDef& graph_def,
78                        std::vector<string>* stateful_op_names) {
79   FunctionLibraryDefinition lib_def(OpRegistry::Global(), graph_def.library());
80 
81   // Iterate over all nodes in the graph.
82   for (const auto& node : graph_def.node()) {
83     // Each Dataset graph has a _Retval op in the end which is marked stateful
84     if (node.op() == FunctionLibraryDefinition::kRetOp) continue;
85     if (!IsNodeStateful(lib_def, node).ok()) {
86       stateful_op_names->push_back(node.op());
87     }
88   }
89 
90   // Iterate over all functions.
91   for (const auto& fdef : graph_def.library().function()) {
92     if (!fdef.signature().is_stateful()) continue;
93     for (const auto& node : fdef.node_def()) {
94       if (!IsNodeStateful(lib_def, node).ok()) {
95         stateful_op_names->push_back(
96             absl::StrCat(node.op(), " in function: ", fdef.signature().name()));
97       }
98     }
99   }
100   return Status::OK();
101 }
102 
103 }  // namespace
104 
ReadElementsFromCheckpoint(IteratorContext * ctx,IteratorStateReader * reader,StringPiece key_prefix,std::vector<std::vector<Tensor>> * elements)105 Status ReadElementsFromCheckpoint(IteratorContext* ctx,
106                                   IteratorStateReader* reader,
107                                   StringPiece key_prefix,
108                                   std::vector<std::vector<Tensor>>* elements) {
109   int64_t num_elements;
110   TF_RETURN_IF_ERROR(
111       reader->ReadScalar(key_prefix, kNumElements, &num_elements));
112   DCHECK(elements->empty());
113   elements->reserve(num_elements);
114   for (int i = 0; i < num_elements; ++i) {
115     std::string element_prefix = absl::StrCat(key_prefix, "::", i);
116     int64_t num_components;
117     TF_RETURN_IF_ERROR(
118         reader->ReadScalar(element_prefix, kNumComponents, &num_components));
119     elements->emplace_back();
120     std::vector<Tensor>& element = elements->at(i);
121     element.reserve(num_components);
122     for (int j = 0; j < num_components; ++j) {
123       element.emplace_back();
124       TF_RETURN_IF_ERROR(reader->ReadTensor(
125           ctx->flr(), element_prefix, absl::StrCat(kComponent, "[", j, "]"),
126           &element.back()));
127     }
128   }
129   return Status::OK();
130 }
131 
WriteElementsToCheckpoint(IteratorStateWriter * writer,StringPiece key_prefix,const std::vector<std::vector<Tensor>> & elements)132 Status WriteElementsToCheckpoint(
133     IteratorStateWriter* writer, StringPiece key_prefix,
134     const std::vector<std::vector<Tensor>>& elements) {
135   TF_RETURN_IF_ERROR(
136       writer->WriteScalar(key_prefix, kNumElements, elements.size()));
137   for (int i = 0; i < elements.size(); ++i) {
138     const std::vector<Tensor>& element = elements[i];
139     std::string element_prefix = absl::StrCat(key_prefix, "::", i);
140     TF_RETURN_IF_ERROR(
141         writer->WriteScalar(element_prefix, kNumComponents, element.size()));
142     for (int j = 0; j < elements[i].size(); ++j) {
143       TF_RETURN_IF_ERROR(writer->WriteTensor(
144           element_prefix, absl::StrCat(kComponent, "[", j, "]"), element[j]));
145     }
146   }
147   return Status::OK();
148 }
149 
VariantTensorDataReader(const std::vector<const tensorflow::VariantTensorData * > & data)150 VariantTensorDataReader::VariantTensorDataReader(
151     const std::vector<const tensorflow::VariantTensorData*>& data) {
152   for (const auto& d : data) {
153     string metadata;
154     d->get_metadata(&metadata);
155     auto keys = str_util::Split(metadata, kDelimiter, str_util::SkipEmpty());
156     const string name = keys[0];
157     data_[name] = d;
158     map_[name] = std::map<string, size_t>();
159     for (size_t i = 1; i < keys.size(); ++i) {
160       map_[name][keys[i]] = i - 1;
161     }
162   }
163 }
164 
ReadScalar(StringPiece key,int64 * val) const165 Status VariantTensorDataReader::ReadScalar(StringPiece key, int64* val) const {
166   string name;
167   TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
168   return ReadScalar(name, key, val);
169 }
170 
ReadScalar(StringPiece name,StringPiece key,int64 * val) const171 Status VariantTensorDataReader::ReadScalar(StringPiece name, StringPiece key,
172                                            int64* val) const {
173   return ReadScalarInternal(name, key, val);
174 }
175 
ReadScalar(StringPiece key,tstring * val) const176 Status VariantTensorDataReader::ReadScalar(StringPiece key,
177                                            tstring* val) const {
178   string name;
179   TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
180   return ReadScalar(name, key, val);
181 }
182 
ReadScalar(StringPiece name,StringPiece key,tstring * val) const183 Status VariantTensorDataReader::ReadScalar(StringPiece name, StringPiece key,
184                                            tstring* val) const {
185   return ReadScalarInternal(name, key, val);
186 }
187 
ReadTensor(StringPiece key,Tensor * val) const188 Status VariantTensorDataReader::ReadTensor(StringPiece key, Tensor* val) const {
189   string name;
190   TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
191   return ReadTensor(name, key, val);
192 }
193 
ReadTensor(FunctionLibraryRuntime * flr,StringPiece key,Tensor * val) const194 Status VariantTensorDataReader::ReadTensor(FunctionLibraryRuntime* flr,
195                                            StringPiece key, Tensor* val) const {
196   string name;
197   TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
198   return ReadTensorInternal(flr, name, key, val);
199 }
200 
ReadTensor(StringPiece name,StringPiece key,Tensor * val) const201 Status VariantTensorDataReader::ReadTensor(StringPiece name, StringPiece key,
202                                            Tensor* val) const {
203   return ReadTensor(/*flr=*/nullptr, name, key, val);
204 }
205 
ReadTensor(FunctionLibraryRuntime * flr,StringPiece name,StringPiece key,Tensor * val) const206 Status VariantTensorDataReader::ReadTensor(FunctionLibraryRuntime* flr,
207                                            StringPiece name, StringPiece key,
208                                            Tensor* val) const {
209   return ReadTensorInternal(flr, name, key, val);
210 }
211 
Contains(StringPiece key) const212 bool VariantTensorDataReader::Contains(StringPiece key) const {
213   string name;
214   if (!GetIteratorName(key, &name).ok()) {
215     return false;
216   }
217   return Contains(name, key);
218 }
219 
Contains(StringPiece n,StringPiece key) const220 bool VariantTensorDataReader::Contains(StringPiece n, StringPiece key) const {
221   string name(n);
222   auto it = map_.find(name);
223   if (it == map_.end()) {
224     return false;
225   }
226   const auto& bucket = it->second;
227   return bucket.find(string(key)) != bucket.end();
228 }
229 
230 template <typename T>
ReadScalarInternal(StringPiece n,StringPiece key,T * val) const231 Status VariantTensorDataReader::ReadScalarInternal(StringPiece n,
232                                                    StringPiece key,
233                                                    T* val) const {
234   string name(n);
235   auto it = map_.find(name);
236   if (it == map_.end()) {
237     return errors::NotFound(name);
238   }
239   const auto& bucket = it->second;
240   auto key_it = bucket.find(string(key));
241   if (key_it == bucket.end()) {
242     return errors::NotFound(key);
243   }
244   *val = data_.at(name)->tensors(key_it->second).scalar<T>()();
245   return Status::OK();
246 }
247 
ReadTensorInternal(FunctionLibraryRuntime * flr,StringPiece n,StringPiece key,Tensor * val) const248 Status VariantTensorDataReader::ReadTensorInternal(FunctionLibraryRuntime* flr,
249                                                    StringPiece n,
250                                                    StringPiece key,
251                                                    Tensor* val) const {
252   if (Contains(n, strings::StrCat(key, kIsDataset))) {
253     return ReadDatasetInternal(flr, n, key, val);
254   }
255   string name(n);
256   auto it = map_.find(name);
257   if (it == map_.end()) {
258     return errors::NotFound(name);
259   }
260   const auto& bucket = it->second;
261   auto key_it = bucket.find(string(key));
262   if (key_it == bucket.end()) {
263     return errors::NotFound(key);
264   }
265   *val = data_.at(name)->tensors(key_it->second);
266   return Status::OK();
267 }
268 
ReadDatasetInternal(FunctionLibraryRuntime * flr,StringPiece n,StringPiece key,Tensor * val) const269 Status VariantTensorDataReader::ReadDatasetInternal(FunctionLibraryRuntime* flr,
270                                                     StringPiece n,
271                                                     StringPiece key,
272                                                     Tensor* val) const {
273   if (flr == nullptr) {
274     return errors::Internal(
275         "Function library runtime is needed to restore a dataset.");
276   }
277   tstring output_node, serialized_graph_def;
278   TF_RETURN_IF_ERROR(
279       ReadScalar(n, strings::StrCat(key, kOutputNode), &output_node));
280   TF_RETURN_IF_ERROR(
281       ReadScalar(n, strings::StrCat(key), &serialized_graph_def));
282   GraphDef graph_def;
283   graph_def.ParseFromString(serialized_graph_def);
284   TF_RETURN_IF_ERROR(FromGraphDef(flr, graph_def, {}, output_node, val));
285   return Status::OK();
286 }
287 
WriteScalar(StringPiece key,const int64_t val)288 Status VariantTensorDataWriter::WriteScalar(StringPiece key,
289                                             const int64_t val) {
290   string name;
291   TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
292   return WriteScalar(name, key, val);
293 }
294 
WriteScalar(StringPiece name,StringPiece key,const int64_t val)295 Status VariantTensorDataWriter::WriteScalar(StringPiece name, StringPiece key,
296                                             const int64_t val) {
297   return WriteScalarInternal(name, key, val);
298 }
299 
WriteScalar(StringPiece key,const tstring & val)300 Status VariantTensorDataWriter::WriteScalar(StringPiece key,
301                                             const tstring& val) {
302   string name;
303   TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
304   return WriteScalar(name, key, val);
305 }
306 
WriteScalar(StringPiece name,StringPiece key,const tstring & val)307 Status VariantTensorDataWriter::WriteScalar(StringPiece name, StringPiece key,
308                                             const tstring& val) {
309   return WriteScalarInternal(name, key, val);
310 }
311 
WriteTensor(StringPiece key,const Tensor & val)312 Status VariantTensorDataWriter::WriteTensor(StringPiece key,
313                                             const Tensor& val) {
314   string name;
315   TF_RETURN_IF_ERROR(GetIteratorName(key, &name));
316   return WriteTensor(name, key, val);
317 }
318 
WriteTensor(StringPiece name,StringPiece key,const Tensor & val)319 Status VariantTensorDataWriter::WriteTensor(StringPiece name, StringPiece key,
320                                             const Tensor& val) {
321   return WriteTensorInternal(name, key, val);
322 }
323 
MaybeFlush()324 void VariantTensorDataWriter::MaybeFlush() {
325   if (is_flushed_) return;
326   for (auto& keys : keys_) {
327     const string name = keys.first;
328     string metadata = name;
329     for (size_t i = 0; i < keys_[name].size(); ++i) {
330       strings::StrAppend(&metadata, kDelimiter, keys_[name][i]);
331     }
332     data_[name]->set_metadata(metadata);
333   }
334   is_flushed_ = true;
335 }
336 
Reset()337 void VariantTensorDataWriter::Reset() {
338   is_flushed_ = false;
339   data_.clear();
340   keys_.clear();
341 }
342 
ReleaseData(std::vector<std::unique_ptr<VariantTensorData>> * variants)343 void VariantTensorDataWriter::ReleaseData(
344     std::vector<std::unique_ptr<VariantTensorData>>* variants) {
345   MaybeFlush();
346   for (auto& it : data_) {
347     variants->push_back(std::move(it.second));
348   }
349   Reset();
350 }
351 
GetData(std::vector<const VariantTensorData * > * variants)352 void VariantTensorDataWriter::GetData(
353     std::vector<const VariantTensorData*>* variants) {
354   MaybeFlush();
355   for (auto& it : data_) {
356     variants->push_back(it.second.get());
357   }
358 }
359 
360 template <typename T>
WriteScalarInternal(StringPiece name,StringPiece key,const T & val)361 Status VariantTensorDataWriter::WriteScalarInternal(StringPiece name,
362                                                     StringPiece key,
363                                                     const T& val) {
364   if (is_flushed_) {
365     return errors::FailedPrecondition(
366         "Cannot call WriteScalar after GetData or ReleaseData is called");
367   }
368   Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
369   val_t.scalar<T>()() = val;
370   return WriteTensorInternal(name, key, val_t);
371 }
372 
WriteTensorInternal(StringPiece n,StringPiece key,const Tensor & val)373 Status VariantTensorDataWriter::WriteTensorInternal(StringPiece n,
374                                                     StringPiece key,
375                                                     const Tensor& val) {
376   DatasetBase* dataset;
377   if (GetDatasetFromVariantTensor(val, &dataset).ok()) {
378     return WriteDatasetInternal(n, key, dataset);
379   }
380   if (is_flushed_) {
381     return errors::FailedPrecondition(
382         "Cannot call WriteTensor after GetData or ReleaseData is called");
383   }
384   DCHECK_EQ(key.find(kDelimiter), string::npos);
385   string name(n);
386   if (keys_.count(name) == 0) {
387     keys_[name] = std::vector<string>();
388   }
389   keys_[name].push_back(string(key));
390   if (data_.count(name) == 0) {
391     data_[name] = absl::make_unique<VariantTensorData>();
392     data_[name]->set_type_name("tensorflow::Iterator");
393   }
394   *(data_[name]->add_tensors()) = val;
395   return Status::OK();
396 }
397 
WriteDatasetInternal(StringPiece n,StringPiece key,const DatasetBase * dataset)398 Status VariantTensorDataWriter::WriteDatasetInternal(
399     StringPiece n, StringPiece key, const DatasetBase* dataset) {
400   GraphDef graph_def;
401   SerializationContext ctx((SerializationContext::Params()));
402   TF_RETURN_IF_ERROR(AsGraphDef(nullptr, dataset, std::move(ctx), &graph_def));
403   string output_node;
404   for (const auto& node : graph_def.node()) {
405     if (node.op() == "_Retval") {
406       output_node = node.input(0);
407       break;
408     }
409   }
410   string result;
411   graph_def.SerializeToString(&result);
412   TF_RETURN_IF_ERROR(WriteScalar(n, strings::StrCat(key, kIsDataset), ""));
413   TF_RETURN_IF_ERROR(
414       WriteScalar(n, strings::StrCat(key, kOutputNode), output_node));
415   TF_RETURN_IF_ERROR(WriteScalar(n, key, result));
416   return Status::OK();
417 }
418 
AsGraphDefMinimal(OpKernelContext * ctx,const DatasetBase * input,std::vector<std::pair<string,Tensor>> * input_list,GraphDef * result,string * dataset_node)419 Status AsGraphDefMinimal(OpKernelContext* ctx, const DatasetBase* input,
420                          std::vector<std::pair<string, Tensor>>* input_list,
421                          GraphDef* result, string* dataset_node) {
422   SerializationContext::Params params(ctx);
423   params.input_list = input_list;
424   params.external_state_policy =
425       SerializationContext::ExternalStatePolicy::kIgnore;
426   params.fail_if_unimplemented = false;
427   params.serialize_data_tensors = false;
428   params.preserve_random_seeds = false;
429   SerializationContext serialization_ctx(params);
430   TF_RETURN_IF_ERROR(
431       AsGraphDef(ctx, input, std::move(serialization_ctx), result));
432 
433   // Symbolic `_Retval` node indicates which node corresponds to the dataset.
434   for (const auto& node : result->node()) {
435     if (node.op() == "_Retval") {
436       *dataset_node = node.input(0);
437     }
438   }
439   return Status::OK();
440 }
441 
AsGraphDef(OpKernelContext * ctx,const DatasetBase * dataset,SerializationContext && serialization_ctx,GraphDef * graph_def)442 Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
443                   SerializationContext&& serialization_ctx,
444                   GraphDef* graph_def) {
445   if (serialization_ctx.external_state_policy() ==
446       SerializationContext::ExternalStatePolicy::kFail) {
447     TF_RETURN_IF_ERROR(dataset->CheckExternalState());
448   }
449   if (serialization_ctx.external_state_policy() ==
450       SerializationContext::ExternalStatePolicy::kWarn) {
451     std::vector<string> stateful_op_names;
452     TF_RETURN_IF_ERROR(FindStatefulOps(*graph_def, &stateful_op_names));
453     if (!stateful_op_names.empty()) {
454       LOG(WARNING) << "We found the following stateful ops in the dataset "
455                       "construction graph whose state would not be "
456                       "serialized and might "
457                       "cause subtle bugs: "
458                    << absl::StrJoin(stateful_op_names, ", ");
459     }
460   }
461   GraphDefBuilder b;
462   DatasetBase::DatasetGraphDefBuilder db(&b);
463   Node* output_node = nullptr;
464   TF_RETURN_IF_ERROR(
465       db.AddInputDataset(&serialization_ctx, dataset, &output_node));
466   // Insert a purely symbolic _Retval node to indicate to consumers which node
467   // represents `dataset`.
468   ops::UnaryOp("_Retval", output_node,
469                b.opts()
470                    .WithName("dataset")
471                    .WithAttr("T", DT_VARIANT)
472                    .WithAttr("index", 0));
473   TF_RETURN_IF_ERROR(b.ToGraphDef(graph_def));
474   return Status::OK();
475 }
476 
477 }  // namespace data
478 }  // namespace tensorflow
479