• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import io
16import os
17
18import numpy
19from tensorflow.python.checkpoint import checkpoint as util
20from tensorflow.python.client import session
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import test_util
23from tensorflow.python.module import module
24from tensorflow.python.platform import test
25from tensorflow.python.trackable import python_state
26
27
class _NumpyState(module.Module):
  """A checkpointable object whose NumPy array attributes are saved/restored.

  Example usage:

  ```python
  arrays = _NumpyState()
  checkpoint = tf.train.Checkpoint(numpy_arrays=arrays)
  arrays.x = numpy.zeros([3, 4])
  save_path = checkpoint.save("/tmp/ckpt")
  arrays.x[1, 1] = 4.
  checkpoint.restore(save_path)
  assert (arrays.x == numpy.zeros([3, 4])).all()

  second_checkpoint = tf.train.Checkpoint(
      numpy_arrays=_NumpyState())
  # Attributes of _NumpyState objects are created automatically by restore()
  second_checkpoint.restore(save_path)
  assert (second_checkpoint.numpy_arrays.x == numpy.zeros([3, 4])).all()
  ```

  Note that `_NumpyState` objects re-create the attributes of the previously
  saved object on `restore()`. This is in contrast to TensorFlow variables, for
  which a `Variable` object must be created and assigned to an attribute.

  This snippet works both when graph building and when executing eagerly. On
  save, the NumPy array(s) are fed as strings to be saved in the checkpoint (via
  a placeholder when graph building, or as a string constant when executing
  eagerly). When restoring they skip the TensorFlow graph entirely, and so no
  restore ops need be run. This means that restoration always happens eagerly,
  rather than waiting for `checkpoint.restore(...).run_restore_ops()` like
  TensorFlow variables when graph building.

  Implementation note: arrays are not stored on this object directly. Each one
  is held in a `_NumpyWrapper` attached to the private `_arrays` module, and
  the `__getattribute__` / `__setattr__` overrides below translate between the
  wrapper and the raw array.
  """

  def __init__(self):
    """Creates the `_arrays` container that holds the `_NumpyWrapper` objects."""
    # Go through the base-class __setattr__ explicitly: `_arrays` is installed
    # with the parent's attribute handling rather than the override below.
    super(_NumpyState, self).__setattr__("_arrays", module.Module())

  def __getattribute__(self, name):
    """Un-wrap `_NumpyWrapper` objects when accessing attributes.

    Args:
      name: Name of the attribute being read.

    Returns:
      The wrapped NumPy array when `name` refers to a `_NumpyWrapper` stored on
      `_arrays`; otherwise the result of the normal attribute lookup.

    Raises:
      AttributeError: Propagated from the normal lookup when `name` is neither
        a stored array nor a regular attribute of this object.
    """
    try:
      # Use the base-class lookup for `_arrays` itself so that this override
      # is not re-entered while fetching its own bookkeeping attribute.
      arrays = super(_NumpyState, self).__getattribute__("_arrays")
    except AttributeError:
      # _arrays hasn't been assigned yet
      return super(_NumpyState, self).__getattribute__(name)
    try:
      value = getattr(arrays, name)
    except AttributeError:
      # Nothing is stored under `name` yet. Attach a throwaway wrapper around a
      # fresh empty array and read it back.
      # NOTE(review): presumably attaching the wrapper gives a pending
      # checkpoint restore for this name the chance to run and replace
      # `.array` -- confirm against the trackable restore logic.
      dummy_array = numpy.array([])
      setattr(arrays, name, _NumpyWrapper(dummy_array))
      value = getattr(arrays, name)
      if value.array is dummy_array:
        # No set or restored attribute with this name
        # Undo the probe so `_arrays` is left as it was, then defer to the
        # normal lookup (which raises AttributeError if `name` is unknown).
        delattr(arrays, name)
        return super(_NumpyState, self).__getattribute__(name)

    # Only `_NumpyWrapper` values are unwrapped; anything else found on
    # `_arrays` is ignored in favour of the normal lookup on `self`.
    if isinstance(value, _NumpyWrapper):
      return value.array
    return super(_NumpyState, self).__getattribute__(name)

  def __setattr__(self, name, value):
    """Automatically wrap NumPy arrays assigned to attributes.

    Args:
      name: Name of the attribute being assigned.
      value: The value to store. `numpy.ndarray` and `numpy.generic` (NumPy
        scalar) values are stored on `_arrays` inside a `_NumpyWrapper`; any
        other value is assigned with the base-class `__setattr__`.
    """
    if isinstance(value, (numpy.ndarray, numpy.generic)):
      try:
        # If something is already stored under this name, keep that object
        # and only swap the array it holds.
        existing = getattr(self._arrays, name)
        existing.array = value
        return
      except AttributeError:
        # First assignment under this name: create a new wrapper for it.
        value = _NumpyWrapper(value)
      setattr(self._arrays, name, value)
      return
    # Non-NumPy values take the ordinary attribute-assignment path.
    super(_NumpyState, self).__setattr__(name, value)
99
100
class _NumpyWrapper(python_state.PythonState):
  """Holds one NumPy array so an object-based checkpoint can save/restore it.

  The array is exposed as the public `array` attribute. `serialize` turns it
  into bytes with `numpy.save`, and `deserialize` replaces it with whatever
  `numpy.load` reads back from such bytes.
  """

  def __init__(self, array):
    """Stores the array to be checkpointed.

    Args:
      array: The NumPy array to save and restore. The attribute is rebound to
        a new array whenever `deserialize` is called.
    """
    self.array = array

  def serialize(self):
    """Returns the current array encoded as bytes by `numpy.save`.

    Pickling is disabled (`allow_pickle=False`).
    """
    # The in-memory buffer is closed on leaving the `with` block; the bytes
    # are copied out by `getvalue()` before that happens.
    with io.BytesIO() as buffer:
      numpy.save(buffer, self.array, allow_pickle=False)
      return buffer.getvalue()

  def deserialize(self, string_value):
    """Replaces `self.array` with the array decoded from `string_value`.

    Args:
      string_value: Bytes previously produced by `serialize`.
    """
    with io.BytesIO(string_value) as buffer:
      self.array = numpy.load(buffer, allow_pickle=False)  # pylint: disable=unexpected-keyword-arg
129
130
class NumpyStateTests(test.TestCase):
  """Checkpoint save/restore tests for `_NumpyWrapper` and `_NumpyState`."""

  def _checkpoint_prefix(self):
    """Returns a checkpoint path prefix inside this test's temp directory."""
    return os.path.join(self.get_temp_dir(), "ckpt")

  def testWrapper(self):
    """A bare `_NumpyWrapper` is rolled back to its saved value on restore."""
    prefix = self._checkpoint_prefix()
    root = util.Checkpoint(numpy=_NumpyWrapper(numpy.array([1.])))
    save_path = root.save(prefix)
    root.numpy.array *= 2.
    # assertAllEqual compares element-wise; assertEqual on an ndarray only
    # works while the array happens to have exactly one element.
    self.assertAllEqual([2.], root.numpy.array)
    root.restore(save_path)
    self.assertAllEqual([1.], root.numpy.array)

  @test_util.run_in_graph_and_eager_modes
  def testSaveRestoreNumpyState(self):
    """Arrays and NumPy scalars round-trip through two successive saves."""
    prefix = self._checkpoint_prefix()
    save_state = _NumpyState()
    saver = util.Checkpoint(numpy=save_state)
    save_state.a = numpy.ones([2, 2])
    save_state.b = numpy.ones([2, 2])
    # Re-assigning an existing attribute must replace the stored array.
    save_state.b = numpy.zeros([2, 2])
    # NumPy scalars (numpy.generic) are handled as well as ndarrays.
    save_state.c = numpy.int64(3)
    self.assertAllEqual(numpy.ones([2, 2]), save_state.a)
    self.assertAllEqual(numpy.zeros([2, 2]), save_state.b)
    self.assertEqual(3, save_state.c)
    first_save_path = saver.save(prefix)
    # Mutate in place and by re-assignment so the second save differs from
    # the first.
    save_state.a[1, 1] = 2.
    save_state.c = numpy.int64(4)
    second_save_path = saver.save(prefix)

    # A fresh object has none of the attributes until restore creates them.
    load_state = _NumpyState()
    loader = util.Checkpoint(numpy=load_state)
    loader.restore(first_save_path).initialize_or_restore()
    self.assertAllEqual(numpy.ones([2, 2]), load_state.a)
    self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
    self.assertEqual(3, load_state.c)
    # Local edits are visible until the next restore overwrites them.
    load_state.a[0, 0] = 42.
    self.assertAllEqual([[42., 1.], [1., 1.]], load_state.a)
    loader.restore(first_save_path).run_restore_ops()
    self.assertAllEqual(numpy.ones([2, 2]), load_state.a)
    loader.restore(second_save_path).run_restore_ops()
    self.assertAllEqual([[1., 1.], [1., 2.]], load_state.a)
    self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
    self.assertEqual(4, load_state.c)

  def testNoGraphPollution(self):
    """Saving and restoring keeps working after the graph is finalized."""
    graph = ops.Graph()
    with graph.as_default(), session.Session():
      prefix = self._checkpoint_prefix()
      save_state = _NumpyState()
      saver = util.Checkpoint(numpy=save_state)
      save_state.a = numpy.ones([2, 2])
      save_path = saver.save(prefix)
      saver.restore(save_path)
      # From here on, adding any op to the graph would raise; the remaining
      # saves and restores must therefore reuse what was already built.
      graph.finalize()
      saver.save(prefix)
      save_state.a = numpy.zeros([2, 2])
      saver.save(prefix)
      saver.restore(save_path)

  @test_util.run_in_graph_and_eager_modes
  def testDocstringExample(self):
    """Runs the usage example from the `_NumpyState` class docstring."""
    arrays = _NumpyState()
    checkpoint = util.Checkpoint(numpy_arrays=arrays)
    arrays.x = numpy.zeros([3, 4])
    save_path = checkpoint.save(self._checkpoint_prefix())
    arrays.x[1, 1] = 4.
    checkpoint.restore(save_path)
    self.assertAllEqual(numpy.zeros([3, 4]), arrays.x)

    # Restoring into a brand-new object creates attribute `x` automatically.
    second_checkpoint = util.Checkpoint(numpy_arrays=_NumpyState())
    second_checkpoint.restore(save_path)
    self.assertAllEqual(numpy.zeros([3, 4]), second_checkpoint.numpy_arrays.x)
205
206
# Script entry point: turn on eager execution first, then hand control to the
# TensorFlow test runner for the test cases defined in this file.
if __name__ == "__main__":
  ops.enable_eager_execution()
  test.main()
210