1"""Utility for creating multiple dependencies with synchronized save/restore.""" 2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import functools 21 22from tensorflow.python.ops import control_flow_ops 23from tensorflow.python.training import saver as saver_lib 24from tensorflow.python.training.tracking import base as trackable 25 26 27class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject): 28 """Wraps save and restore callbacks as a `SaveableObject`.""" 29 30 def __init__(self, name, dtype, save_callback, restore_callback): 31 self._restore_callback = restore_callback 32 spec = saver_lib.BaseSaverBuilder.SaveSpec( 33 tensor=save_callback, 34 slice_spec="", 35 name=name, 36 dtype=dtype) 37 super(_CallbackSaveable, self).__init__( 38 save_callback, [spec], name) 39 40 def restore(self, restored_tensors, restored_shapes): 41 """Restore the same value into both variables.""" 42 tensor, = restored_tensors 43 return self._restore_callback(tensor) 44 45 46class _SplitDependency(trackable.Trackable): 47 """Looks like a regular variable while synchronizing save/restores.""" 48 49 def __init__(self, save_buffer, restore_buffer, name, dtype, num_components, 50 fill_save_buffer_fn, consume_restore_buffer_fn): 51 self._save_buffer = save_buffer 52 self._restore_buffer = restore_buffer 53 self._name = name 54 self._dtype = dtype 55 self._num_components = num_components 56 self._fill_save_buffer_fn = fill_save_buffer_fn 57 self._consume_restore_buffer_fn = consume_restore_buffer_fn 58 59 def _save(self): 60 """Pull from the shared buffer, populating it if necessary.""" 61 if self._name not in self._save_buffer: 62 if self._save_buffer: 63 raise AssertionError( 64 ("Split dependency %s (%s) unsynchronized. Split dependencies must " 65 "be saved together.") % (self._name, self)) 66 self._fill_save_buffer_fn(self._save_buffer) 67 return self._save_buffer.pop(self._name) 68 69 def _restore(self, tensor): 70 """Push into the shared buffer, flushing it if necessary.""" 71 if self._name in self._restore_buffer: 72 raise AssertionError( 73 ("Split dependency %s (%s) unsynchronized. Split dependencies must " 74 "be restored together.") % (self._name, self)) 75 self._restore_buffer[self._name] = tensor 76 if len(self._restore_buffer) == self._num_components: 77 op = self._consume_restore_buffer_fn(self._restore_buffer) 78 self._restore_buffer.clear() 79 return op 80 else: 81 return control_flow_ops.no_op() 82 83 def _gather_saveables_for_checkpoint(self): 84 """Looks to Trackable like a regular variable.""" 85 return { 86 trackable.VARIABLE_VALUE_KEY: 87 functools.partial(_CallbackSaveable, 88 dtype=self._dtype, 89 save_callback=self._save, 90 restore_callback=self._restore) 91 } 92 93 94def split_dependency(component_names, component_dtypes, 95 fill_save_buffer_fn, consume_restore_buffer_fn): 96 """Creates multiple dependencies with a synchronized save/restore. 97 98 Useful when a single op produces `Tensor`s which should each be saved under 99 different objects, or when `Tensor`s saved with many different objects need to 100 be restored together as inputs to a single op (i.e. an object which uses a 101 single fused op may be swapped out for a subgraph of objects, and these two 102 programs are checkpoint compatible). 103 104 Args: 105 component_names: A sequence of names for the split 106 dependencies. `fill_save_buffer_fn` must add these keys to the dictionary 107 it is passed, and `consume_restore_buffer_fn` will receive a dictionary 108 with these keys. 109 component_dtypes: Data types for the `Tensor`s being saved and restored, a 110 sequence corresponding to `component_names`. 111 fill_save_buffer_fn: A function which takes an empty dictionary as an 112 argument and adds `Tensor`s with `component_names` as keys. These 113 `Tensor`s will be saved as if they were individual variables. 114 consume_restore_buffer_fn: A function which takes a dictionary with 115 `component_names` as keys mapping to restored individual `Tensor`s and 116 returns a restore op (or if executing eagerly, runs the restoration and 117 may return `None`). 118 119 Returns: 120 A dictionary mapping from names to Trackable objects. If one is 121 reachable from an object as a dependency, the others should be too; adding 122 dependencies on some but not all of the objects will result in errors. 123 """ 124 save_buffer = {} 125 restore_buffer = {} 126 split_dependencies = {} 127 for name, dtype in zip(component_names, component_dtypes): 128 split_dependencies[name] = _SplitDependency( 129 save_buffer=save_buffer, 130 restore_buffer=restore_buffer, 131 name=name, 132 dtype=dtype, 133 num_components=len(component_names), 134 fill_save_buffer_fn=fill_save_buffer_fn, 135 consume_restore_buffer_fn=consume_restore_buffer_fn) 136 return split_dependencies 137