# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import io
import os

import numpy
from tensorflow.python.checkpoint import checkpoint as util
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.module import module
from tensorflow.python.platform import test
from tensorflow.python.trackable import python_state


class _NumpyState(module.Module):
  """A checkpointable object whose NumPy array attributes are saved/restored.

  Example usage:

  ```python
  arrays = _NumpyState()
  checkpoint = tf.train.Checkpoint(numpy_arrays=arrays)
  arrays.x = numpy.zeros([3, 4])
  save_path = checkpoint.save("/tmp/ckpt")
  arrays.x[1, 1] = 4.
  checkpoint.restore(save_path)
  assert (arrays.x == numpy.zeros([3, 4])).all()

  second_checkpoint = tf.train.Checkpoint(
      numpy_arrays=_NumpyState())
  # Attributes of _NumpyState objects are re-created automatically when they
  # are first accessed after restore().
  second_checkpoint.restore(save_path)
  assert (second_checkpoint.numpy_arrays.x == numpy.zeros([3, 4])).all()
  ```

  Note that `_NumpyState` objects re-create the attributes of the previously
  saved object on `restore()`. This is in contrast to TensorFlow variables, for
  which a `Variable` object must be created and assigned to an attribute.

  This snippet works both when graph building and when executing eagerly. On
  save, the NumPy array(s) are fed as strings to be saved in the checkpoint (via
  a placeholder when graph building, or as a string constant when executing
  eagerly). When restoring they skip the TensorFlow graph entirely, and so no
  restore ops need be run. This means that restoration always happens eagerly,
  rather than waiting for `checkpoint.restore(...).run_restore_ops()` like
  TensorFlow variables when graph building.
  """

  def __init__(self):
    # Bypass this class's own __setattr__ and store the inner container Module
    # that holds one `_NumpyWrapper` per array-valued attribute.
    super(_NumpyState, self).__setattr__("_arrays", module.Module())

  def __getattribute__(self, name):
    """Un-wrap `_NumpyWrapper` objects when accessing attributes.

    If `name` is an array attribute stored in the inner `_arrays` container,
    the wrapped NumPy value is returned instead of the wrapper.

    If `name` is not yet present in the container, a placeholder
    `_NumpyWrapper` is attached under that name. Attaching it lets a
    checkpoint value already requested by `restore()` for `name` be loaded
    into the placeholder. If nothing replaced the placeholder's array, the
    placeholder is removed again and normal attribute lookup proceeds.
    """
    try:
      arrays = super(_NumpyState, self).__getattribute__("_arrays")
    except AttributeError:
      # _arrays hasn't been assigned yet
      return super(_NumpyState, self).__getattribute__(name)
    if name == "_arrays":
      # `_arrays` always names the container itself. Skipping the placeholder
      # round trip below avoids attaching/detaching a dummy `_arrays` child on
      # every access (which happens on each array assignment).
      return arrays
    try:
      value = getattr(arrays, name)
    except AttributeError:
      dummy_array = numpy.array([])
      setattr(arrays, name, _NumpyWrapper(dummy_array))
      value = getattr(arrays, name)
      if value.array is dummy_array:
        # No set or restored attribute with this name: undo the placeholder.
        delattr(arrays, name)
        return super(_NumpyState, self).__getattribute__(name)

    if isinstance(value, _NumpyWrapper):
      return value.array
    return super(_NumpyState, self).__getattribute__(name)

  def __setattr__(self, name, value):
    """Automatically wrap NumPy arrays assigned to attributes.

    NumPy arrays and NumPy scalars are stored in the inner `_arrays` container
    as `_NumpyWrapper` objects. Re-assigning an attribute that already holds a
    wrapper updates that wrapper in place. All other values are set normally.
    """
    if isinstance(value, (numpy.ndarray, numpy.generic)):
      # Read the container directly so the placeholder logic in
      # __getattribute__ is not triggered.
      arrays = super(_NumpyState, self).__getattribute__("_arrays")
      existing = getattr(arrays, name, None)
      if isinstance(existing, _NumpyWrapper):
        # Reuse the wrapper already attached under `name`; only the array it
        # wraps changes.
        existing.array = value
      else:
        # Only a real wrapper is updated in place; anything else under this
        # name is replaced rather than having `.array` set on it.
        setattr(arrays, name, _NumpyWrapper(value))
      return
    super(_NumpyState, self).__setattr__(name, value)


class _NumpyWrapper(python_state.PythonState):
  """Wraps a NumPy array for storage in an object-based checkpoint."""

  def __init__(self, array):
    """Specify a NumPy array to wrap.

    Args:
      array: The NumPy array to save and restore (may be overwritten).
    """
    self.array = array

  def serialize(self):
    """Callback to serialize the array.

    Returns:
      `bytes` holding `self.array` in NumPy's `.npy` format.
    """
    with io.BytesIO() as string_file:
      # allow_pickle=False: object arrays cannot be saved this way.
      numpy.save(string_file, self.array, allow_pickle=False)
      return string_file.getvalue()

  def deserialize(self, string_value):
    """Callback to deserialize the array.

    Args:
      string_value: The serialized array, as produced by `serialize()`.
    """
    with io.BytesIO(string_value) as string_file:
      # allow_pickle=False: never unpickle data read from a checkpoint.
      self.array = numpy.load(string_file, allow_pickle=False)  # pylint: disable=unexpected-keyword-arg


class NumpyStateTests(test.TestCase):

  def testWrapper(self):
    directory = self.get_temp_dir()
    prefix = os.path.join(directory, "ckpt")
    root = util.Checkpoint(numpy=_NumpyWrapper(numpy.array([1.])))
    save_path = root.save(prefix)
    root.numpy.array *= 2.
    # Compare element-wise; plain assertEqual on an ndarray only works by
    # accident for single-element arrays.
    self.assertAllEqual([2.], root.numpy.array)
    root.restore(save_path)
    self.assertAllEqual([1.], root.numpy.array)

  @test_util.run_in_graph_and_eager_modes
  def testSaveRestoreNumpyState(self):
    directory = self.get_temp_dir()
    prefix = os.path.join(directory, "ckpt")
    save_state = _NumpyState()
    saver = util.Checkpoint(numpy=save_state)
    save_state.a = numpy.ones([2, 2])
    save_state.b = numpy.ones([2, 2])
    save_state.b = numpy.zeros([2, 2])
    save_state.c = numpy.int64(3)
    self.assertAllEqual(numpy.ones([2, 2]), save_state.a)
    self.assertAllEqual(numpy.zeros([2, 2]), save_state.b)
    self.assertEqual(3, save_state.c)
    first_save_path = saver.save(prefix)
    save_state.a[1, 1] = 2.
    save_state.c = numpy.int64(4)
    second_save_path = saver.save(prefix)

    load_state = _NumpyState()
    loader = util.Checkpoint(numpy=load_state)
    loader.restore(first_save_path).initialize_or_restore()
    self.assertAllEqual(numpy.ones([2, 2]), load_state.a)
    self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
    self.assertEqual(3, load_state.c)
    load_state.a[0, 0] = 42.
    self.assertAllEqual([[42., 1.], [1., 1.]], load_state.a)
    loader.restore(first_save_path).run_restore_ops()
    self.assertAllEqual(numpy.ones([2, 2]), load_state.a)
    loader.restore(second_save_path).run_restore_ops()
    self.assertAllEqual([[1., 1.], [1., 2.]], load_state.a)
    self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
    self.assertEqual(4, load_state.c)

  def testNoGraphPollution(self):
    graph = ops.Graph()
    with graph.as_default(), session.Session():
      directory = self.get_temp_dir()
      prefix = os.path.join(directory, "ckpt")
      save_state = _NumpyState()
      saver = util.Checkpoint(numpy=save_state)
      save_state.a = numpy.ones([2, 2])
      save_path = saver.save(prefix)
      saver.restore(save_path)
      # After finalizing, any attempt to add ops to the graph raises, so the
      # following saves/restores prove no new ops are created.
      graph.finalize()
      saver.save(prefix)
      save_state.a = numpy.zeros([2, 2])
      saver.save(prefix)
      saver.restore(save_path)

  @test_util.run_in_graph_and_eager_modes
  def testDocstringExample(self):
    arrays = _NumpyState()
    checkpoint = util.Checkpoint(numpy_arrays=arrays)
    arrays.x = numpy.zeros([3, 4])
    save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
    arrays.x[1, 1] = 4.
    checkpoint.restore(save_path)
    self.assertAllEqual(numpy.zeros([3, 4]), arrays.x)

    second_checkpoint = util.Checkpoint(numpy_arrays=_NumpyState())
    second_checkpoint.restore(save_path)
    self.assertAllEqual(numpy.zeros([3, 4]), second_checkpoint.numpy_arrays.x)


if __name__ == "__main__":
  ops.enable_eager_execution()
  test.main()