• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
"""Utilities for including Python state in TensorFlow checkpoints."""
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import abc
21import functools
22
23import six
24
25from tensorflow.python.training.tracking import base
26from tensorflow.python.util.tf_export import tf_export
27
28
@tf_export("train.experimental.PythonState")
@six.add_metaclass(abc.ABCMeta)
class PythonState(base.Trackable):
  """A mixin for putting Python state in an object-based checkpoint.

  This is an abstract class which allows extensions to TensorFlow's object-based
  checkpointing (see `tf.train.Checkpoint`). For example a wrapper for NumPy
  arrays:

  ```python
  import io
  import numpy

  class NumpyWrapper(tf.train.experimental.PythonState):

    def __init__(self, array):
      self.array = array

    def serialize(self):
      string_file = io.BytesIO()
      try:
        numpy.save(string_file, self.array, allow_pickle=False)
        serialized = string_file.getvalue()
      finally:
        string_file.close()
      return serialized

    def deserialize(self, string_value):
      string_file = io.BytesIO(string_value)
      try:
        self.array = numpy.load(string_file, allow_pickle=False)
      finally:
        string_file.close()
  ```

  Instances of `NumpyWrapper` are checkpointable objects, and will be saved and
  restored from checkpoints along with TensorFlow state like variables.

  ```python
  root = tf.train.Checkpoint(numpy=NumpyWrapper(numpy.array([1.])))
  save_path = root.save(prefix)
  root.numpy.array *= 2.
  assert [2.] == root.numpy.array
  root.restore(save_path)
  assert [1.] == root.numpy.array
  ```

  Subclasses must implement both `serialize` and `deserialize`. Both are
  abstract, so a subclass missing either one cannot be instantiated. The state
  they produce and consume is registered with the checkpoint under the local
  name `"py_state"`.
  """

  @abc.abstractmethod
  def serialize(self):
    """Callback to serialize the object.

    Returns:
      A string holding this object's Python state. It must be something that
      `deserialize` can turn back into the same state.
    """

  @abc.abstractmethod
  def deserialize(self, string_value):
    """Callback to deserialize the object.

    Args:
      string_value: A string in the format produced by `serialize`.
        Implementations restore this object's Python state from it.
    """

  def _gather_saveables_for_checkpoint(self):
    """Specify callbacks for saving and restoring this object's Python state.

    Returns:
      A dict with the single key `"py_state"`. Its value is a factory
      (`functools.partial`) for `base.PythonStringStateSaveable`, configured
      with `state_callback=self.serialize` and
      `restore_callback=self.deserialize`.
    """
    # Only the callbacks are bound here. The remaining constructor arguments
    # (presumably the saveable's name) are expected to be supplied by the
    # checkpointing machinery when it calls the factory. TODO confirm against
    # base.PythonStringStateSaveable.
    return {
        "py_state": functools.partial(
            base.PythonStringStateSaveable,
            state_callback=self.serialize,
            restore_callback=self.deserialize)
        }
93