# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for saveable_object_util."""

import os

from tensorflow.python.checkpoint import checkpoint
from tensorflow.python.checkpoint import saveable_compat
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.trackable import base
from tensorflow.python.trackable import resource
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util


class _VarSaveable(saveable_object.SaveableObject):
  """Saveable that checkpoints one variable through a single `SaveSpec`.

  The spec captures `var.read_value()` under `name`; restoring assigns the
  first restored tensor back into the variable and returns that assign op.
  """

  def __init__(self, var, slice_spec, name):
    value_spec = saveable_object.SaveSpec(var.read_value(), slice_spec, name)
    super().__init__(var, [value_spec], name)

  def restore(self, restored_tensors, restored_shapes):
    del restored_shapes  # Unused.
    return self.op.assign(restored_tensors[0])


class SaveableCompatibilityConverterTest(test.TestCase):
  """Tests `SaveableCompatibilityConverter` on single-spec legacy saveables."""

  def test_convert_no_saveable(self):
    """An object with no saveables serializes to an empty dict."""
    plain = base.Trackable()
    wrapper = saveable_object_util.SaveableCompatibilityConverter(plain)
    self.assertEmpty(wrapper._serialize_to_tensors())
    wrapper._restore_from_tensors({})

    # Any unexpected key is rejected.
    with self.assertRaisesRegex(ValueError, "Could not restore object"):
      wrapper._restore_from_tensors({"": 0})

  def test_convert_single_saveable(self):
    """One saveable is keyed by the name handed to its factory."""

    class MyTrackable(base.Trackable):

      def __init__(self):
        self.a = variables.Variable(5.0)

      def _gather_saveables_for_checkpoint(self):
        return {"a": lambda name: _VarSaveable(self.a, "", name)}

    obj = MyTrackable()
    wrapper = saveable_object_util.SaveableCompatibilityConverter(obj)

    tensors = wrapper._serialize_to_tensors()
    self.assertLen(tensors, 1)
    self.assertIn("a", tensors)
    self.assertEqual(5, self.evaluate(tensors["a"]))

    # A missing key and a misnamed key must both fail to restore.
    for bad_input in ({}, {"not_a": 1.}):
      with self.assertRaisesRegex(ValueError, "Could not restore object"):
        wrapper._restore_from_tensors(bad_input)

    self.assertEqual(5, self.evaluate(obj.a))
    wrapper._restore_from_tensors({"a": 123.})
    self.assertEqual(123, self.evaluate(obj.a))

  def test_convert_single_saveable_renamed(self):
    """The spec name chosen by the saveable, not the dict key, is used."""

    class MyTrackable(base.Trackable):

      def __init__(self):
        self.a = variables.Variable(15.0)

      def _gather_saveables_for_checkpoint(self):
        return {"a": lambda name: _VarSaveable(self.a, "", name + "-value")}

    obj = MyTrackable()
    wrapper = saveable_object_util.SaveableCompatibilityConverter(obj)

    tensors = wrapper._serialize_to_tensors()

    self.assertLen(tensors, 1)
    self.assertEqual(15, self.evaluate(tensors["a-value"]))

    # The bare gather key is not accepted on restore.
    with self.assertRaisesRegex(ValueError, "Could not restore object"):
      wrapper._restore_from_tensors({"a": 1.})

    self.assertEqual(15, self.evaluate(obj.a))
    wrapper._restore_from_tensors({"a-value": 456.})
    self.assertEqual(456, self.evaluate(obj.a))

  def test_convert_multiple_saveables(self):
    """Several single-spec saveables are flattened into one tensor dict."""

    class MyTrackable(base.Trackable):

      def __init__(self):
        self.a = variables.Variable(15.0)
        self.b = variables.Variable(20.0)

      def _gather_saveables_for_checkpoint(self):
        return {
            "a": lambda name: _VarSaveable(self.a, "", name + "-1"),
            "b": lambda name: _VarSaveable(self.b, "", name + "-2"),
        }

    obj = MyTrackable()
    wrapper = saveable_object_util.SaveableCompatibilityConverter(obj)

    tensors = wrapper._serialize_to_tensors()
    self.assertLen(tensors, 2)
    self.assertEqual(15, self.evaluate(tensors["a-1"]))
    self.assertEqual(20, self.evaluate(tensors["b-2"]))

    # Wrong names, and a dict missing one of the two specs, are rejected.
    for bad_input in ({"a": 1., "b": 2.}, {"b-2": 2.}):
      with self.assertRaisesRegex(ValueError, "Could not restore object"):
        wrapper._restore_from_tensors(bad_input)

    wrapper._restore_from_tensors({"a-1": -123., "b-2": -456.})
    self.assertEqual(-123, self.evaluate(obj.a))
    self.assertEqual(-456, self.evaluate(obj.b))

  def test_convert_variables(self):
    """Variables returned directly (instead of factories) are supported."""
    # `_gather_saveables_for_checkpoint` historically allowed returning
    # Variables in place of Saveables; the converter must handle both.

    class MyTrackable(base.Trackable):

      def __init__(self):
        self.a = variables.Variable(25.)
        self.b = resource_variable_ops.UninitializedVariable(
            dtype=dtypes.float32)

      def _gather_saveables_for_checkpoint(self):
        return {"a": self.a, "b": self.b}

    obj = MyTrackable()
    wrapper = saveable_object_util.SaveableCompatibilityConverter(obj)
    tensors = wrapper._serialize_to_tensors()

    self.assertLen(tensors, 2)
    self.assertEqual(25, self.evaluate(tensors["a"].tensor))
    # The uninitialized variable serializes without a tensor.
    self.assertIsNone(tensors["b"].tensor)

    with self.assertRaisesRegex(ValueError, "Could not restore object"):
      wrapper._restore_from_tensors({"a": 5.})

    wrapper._restore_from_tensors({"a": 5., "b": 6.})
    self.assertEqual(5, self.evaluate(obj.a))
    self.assertEqual(6, self.evaluate(obj.b))


class _MultiSpecSaveable(saveable_object.SaveableObject):
  """Saveable that exposes `obj.a` and `obj.b` as `<name>-a` / `<name>-b`."""

  def __init__(self, obj, name):
    self.obj = obj
    spec_a = saveable_object.SaveSpec(obj.a, "", name + "-a")
    spec_b = saveable_object.SaveSpec(obj.b, "", name + "-b")
    super().__init__(None, [spec_a, spec_b], name)

  def restore(self, restored_tensors, restored_shapes):
    del restored_shapes  # Unused.
    # Restored tensors arrive in the same order as the specs above.
    value_a = restored_tensors[0]
    value_b = restored_tensors[1]
    self.obj.a.assign(value_a)
    self.obj.b.assign(value_b)


class MultipleSpecConverterTest(test.TestCase):
  """Tests the converter on saveables that carry more than one spec."""

  def test_multiple_specs_single_saveable(self):
    """A single multi-spec saveable round-trips and keeps its legacy name."""

    class MyTrackable(base.Trackable):

      def __init__(self):
        self.a = variables.Variable(35.0)
        self.b = variables.Variable(40.0)

      def _gather_saveables_for_checkpoint(self):
        return {"foo": lambda name: _MultiSpecSaveable(self, name)}

    obj = MyTrackable()
    wrapper = saveable_object_util.SaveableCompatibilityConverter(obj)
    tensors = wrapper._serialize_to_tensors()

    self.assertLen(tensors, 2)
    self.assertEqual(35, self.evaluate(tensors["foo-a"]))
    self.assertEqual(40, self.evaluate(tensors["foo-b"]))
    wrapper._restore_from_tensors({"foo-a": 5., "foo-b": 6.})
    self.assertEqual(5, self.evaluate(obj.a))
    self.assertEqual(6, self.evaluate(obj.b))

    # The legacy saveable name must be recorded on the converter.
    self.assertEqual("foo", saveable_compat.get_saveable_name(wrapper))

  def test_multiple_specs_multiple_saveables(self):
    """Several multi-spec saveables are unsupported and must raise."""

    class MyTrackable(base.Trackable):

      def __init__(self):
        self.a = variables.Variable(45.0)
        self.b = variables.Variable(50.0)

      def _gather_saveables_for_checkpoint(self):
        return {
            "foo": lambda name: _MultiSpecSaveable(self, name),
            "bar": lambda name: _MultiSpecSaveable(self, name),
        }

    obj = MyTrackable()
    with self.assertRaises(saveable_compat.CheckpointConversionError):
      saveable_object_util.SaveableCompatibilityConverter(obj)


class State(resource.TrackableResource):
  """Scalar float32 resource variable wrapped as a `TrackableResource`.

  `SaveableState` and `TrackableState` below add the legacy-saveable and the
  `_serialize_to_tensors` checkpointing APIs on top of this class.
  """

  def __init__(self, initial_value):
    super().__init__()
    # Must be set before `_initialize`, which reads it.
    self._initial_value = initial_value
    self._initialize()

  def _create_resource(self):
    return gen_resource_variable_ops.var_handle_op(
        shape=[],
        dtype=dtypes.float32,
        shared_name=context.anonymous_name(),
        name="StateVar",
        container="")

  def _initialize(self):
    gen_resource_variable_ops.assign_variable_op(
        self.resource_handle, self._initial_value)

  def _destroy_resource(self):
    gen_resource_variable_ops.destroy_resource_op(
        self.resource_handle, ignore_lookup_error=True)

  def read(self):
    return gen_resource_variable_ops.read_variable_op(
        self.resource_handle, dtypes.float32)

  def assign(self, value):
    gen_resource_variable_ops.assign_variable_op(self.resource_handle, value)


class _StateSaveable(saveable_object.SaveableObject):
  """Legacy saveable that checkpoints the value held by a `State`."""

  def __init__(self, obj, name):
    self.obj = obj
    value_spec = saveable_object.SaveSpec(obj.read(), "", name)
    super().__init__(obj, [value_spec], name)

  def restore(self, restored_tensors, restored_shapes):
    del restored_shapes  # Unused.
    self.obj.assign(restored_tensors[0])


class SaveableState(State):
  """`State` checkpointed through the legacy saveables API."""

  def _gather_saveables_for_checkpoint(self):

    def make_saveable(name):
      return _StateSaveable(self, name)

    return {"value": make_saveable}


class TrackableState(State):
  """`State` checkpointed through `_serialize_to_tensors`."""

  def _serialize_to_tensors(self):
    return {"value": self.read()}

  def _restore_from_tensors(self, restored_tensors):
    self.assign(restored_tensors["value"])


class SaveableCompatibilityEndToEndTest(test.TestCase):
  """Checks that legacy and new-style checkpoints are interchangeable."""

  def test_checkpoint_comparison(self):
    saveable_state = SaveableState(5.)
    trackable_state = TrackableState(10.)

    # First show SaveableState and TrackableState are equivalent: write both
    # objects to one checkpoint, then read it back with their slots swapped.
    self.assertEqual(5, self.evaluate(saveable_state.read()))
    self.assertEqual(10, self.evaluate(trackable_state.read()))

    ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
    checkpoint.Checkpoint(a=saveable_state, b=trackable_state).write(ckpt_path)

    swapped = checkpoint.Checkpoint(b=saveable_state, a=trackable_state)
    status = swapped.read(ckpt_path)
    status.assert_consumed()

    self.assertEqual(10, self.evaluate(saveable_state.read()))
    self.assertEqual(5, self.evaluate(trackable_state.read()))

    # The converted SaveableState must load from either slot of the
    # checkpoint written above.
    to_convert = SaveableState(0.0)
    converted = saveable_object_util.SaveableCompatibilityConverter(to_convert)

    for slot, expected in (("a", 5), ("b", 10)):
      partial_status = checkpoint.Checkpoint(**{slot: converted}).read(
          ckpt_path)
      partial_status.assert_existing_objects_matched().expect_partial()
      self.assertEqual(expected, self.evaluate(to_convert.read()))


if __name__ == "__main__":
  test.main()