# Registrations

To configure SavedModel or checkpointing beyond the basic saving and loading
steps [documentation TBD], registration is required.

Currently, only TensorFlow-internal registrations are allowed, and each one
must be added to the corresponding allowlist.

* `tensorflow.python.saved_model.registration.register_tf_serializable`
    * Allowlist: tf_serializable_allowlist.txt
* `tensorflow.python.saved_model.registration.register_tf_checkpoint_saver`
    * Allowlist: tf_checkpoint_saver_allowlist.txt

[TOC]

## SavedModel serializable registration

Custom objects must be registered in order to get the correct deserialization
method when loading. The registered name of the class is saved to the proto.

Keras already has a similar mechanism for registering serializables:
[`tf.keras.utils.register_keras_serializable(package, name)`](https://www.tensorflow.org/api_docs/python/tf/keras/utils/register_keras_serializable).
This mechanism has been brought into core TensorFlow:

```python
registration.register_serializable(package, name)
registration.register_tf_serializable(name)  # If TensorFlow-internal.
```

* `package`: The package that this class belongs to.
* `name`: The name of this class. The registered name that is saved in the
  proto is `"{package}.{name}"` (for TensorFlow-internal registration, the
  package name is `tf`).

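To illustrate the naming rule above, here is a minimal pure-Python sketch of a registry keyed by `"{package}.{name}"`. It is not TensorFlow's implementation: the `register_serializable` defined here is a local stand-in, and `_REGISTRY`, `lookup`, and the fallback to the class name when `name` is omitted are assumptions made for illustration only.

```python
# Illustrative stand-in only; NOT tensorflow.python.saved_model.registration.
_REGISTRY = {}  # hypothetical mapping: registered name -> class


def register_serializable(package, name=None):
  """Returns a decorator that records a class under "{package}.{name}"."""

  def decorator(cls):
    # Assumption for this sketch: use the class name when `name` is omitted.
    registered_name = f"{package}.{name or cls.__name__}"
    if registered_name in _REGISTRY:
      raise ValueError(f"{registered_name} is already registered.")
    _REGISTRY[registered_name] = cls
    return cls

  return decorator


def lookup(registered_name):
  """Returns the class stored under `registered_name`, or None if unknown."""
  return _REGISTRY.get(registered_name)


@register_serializable(package="my_package", name="Part")
class Part:
  pass


@register_serializable(package="tf")  # mirrors the TensorFlow-internal package
class Stack:
  pass
```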
## Checkpoint saver registration

If `Trackables` share state or require complicated coordination between multiple
`Trackables` (e.g. `DTensor`), then users may register custom save and restore
functions for these objects.

```
tf.saved_model.register_checkpoint_saver(
    predicate, save_fn=None, restore_fn=None)
```

* `predicate`: A function that returns `True` if a `Trackable` object should
  be saved and restored using the registered `save_fn` and `restore_fn`.
* `save_fn`: A Python function, a `tf.function`, or `None`. If `None`, run the
  default saving process, which calls `Trackable._serialize_to_tensors`.
* `restore_fn`: A `tf.function` or `None`. If `None`, run the default
  restoring process, which calls `Trackable._restore_from_tensors`.

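As a rough illustration of how a predicate selects the objects handed to a registered saver, the sketch below partitions a `{object_prefix: object}` dictionary in plain Python. `split_by_predicate`, `FakeShared`, and `FakePlain` are hypothetical names introduced here; TensorFlow's actual dispatch logic may differ.

```python
class FakeShared:
  """Hypothetical stand-in for a Trackable that needs a custom saver."""


class FakePlain:
  """Hypothetical stand-in for a Trackable saved by the default path."""


def split_by_predicate(trackables, predicate):
  """Splits `{object_prefix: obj}` into (registered, default) dictionaries."""
  registered, default = {}, {}
  for object_prefix, obj in trackables.items():
    # Objects that pass the predicate would go to save_fn / restore_fn;
    # the rest would use the default saving and restoring process.
    target = registered if predicate(obj) else default
    target[object_prefix] = obj
  return registered, default


trackables = {
    "model/a": FakeShared(),
    "model/b": FakePlain(),
    "model/c": FakeShared(),
}
registered, default = split_by_predicate(
    trackables, predicate=lambda obj: isinstance(obj, FakeShared))
```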
**`save_fn` details**

```
@tf.function  # optional decorator
def save_fn(trackables, file_prefix) -> List[str]:  # shard filenames
  ...
```

* `trackables`: A dictionary of `{object_prefix: Trackable}`. The
  `object_prefix` can be used as the object name, and uniquely identifies each
  `Trackable`. `trackables` is the filtered set of trackables that pass the
  predicate.
* `file_prefix`: A string or string tensor of the checkpoint prefix.
* Return value (shard filenames): A list of filenames written using
  `io_ops.save_v2`, which will be merged into the checkpoint data files. These
  should be prefixed by `file_prefix`.

This function can be a Python function, in which case the list of shard
filenames can be empty (if the values are written without the `SaveV2` op).

If this function is a `tf.function`, then the shards must be written using the
`SaveV2` op. This guarantees that the checkpoint format is compatible with
existing checkpoint readers and managers.

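The plain-Python case described above can be sketched as follows: values are written to a side file without the `SaveV2` op, so the function returns an empty list of shards. This is only an illustration under stated assumptions: `FakeCounter` and the `.counters.json` suffix are hypothetical, and `file_prefix` is assumed to be a plain Python string rather than a string tensor.

```python
import json
import os
import tempfile


class FakeCounter:
  """Hypothetical stand-in for a Trackable holding a plain Python value."""

  def __init__(self, value):
    self.value = value


def save_fn(trackables, file_prefix):
  """Writes values to a JSON side file and reports no SaveV2 shards."""
  payload = {prefix: obj.value for prefix, obj in trackables.items()}
  # Hypothetical naming scheme; the file name is prefixed by `file_prefix`.
  side_file = file_prefix + ".counters.json"
  with open(side_file, "w") as f:
    json.dump(payload, f)
  # Nothing was written with SaveV2, so there are no shards to merge.
  return []


# Usage: save one fake object under a temporary checkpoint prefix.
prefix = os.path.join(tempfile.mkdtemp(), "ckpt")
shards = save_fn({"counter/a": FakeCounter(3)}, prefix)
```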
**`restore_fn` details**

```
@tf.function  # required decorator
def restore_fn(trackables, file_prefix) -> None:
  ...
```

`restore_fn` must be a `tf.function` that takes the following arguments:

* `trackables`: A dictionary of `{object_prefix: Trackable}`. The
  `object_prefix` can be used as the object name, and uniquely identifies each
  `Trackable`. The `Trackable` objects are the filtered results of the
  registered predicate.
* `file_prefix`: A string or string tensor of the checkpoint prefix.

**Why are restore functions required to be a `tf.function`?** The short answer
is that the SavedModel format must maintain the invariant that SavedModel
packages can be used for inference on any platform and in any language.
SavedModel inference needs to be able to restore checkpointed values, so the
restore function must be encoded directly in the SavedModel's graph. There are
also security measures over `FunctionDef` and `GraphDef`, so users can check
that the SavedModel will not run arbitrary code (a feature of
`saved_model_cli`).

## Example

The example below shows a `Stack` module that contains multiple `Part` objects
(`Part` is a subclass of `tf.Variable`). When a `Stack` is saved to a
checkpoint, its `Part` objects are stacked together and a single entry is
created in the checkpoint. On restore, the checkpoint value is split and
assigned back to each `Part` in the `Stack`.

```python
# NOTE: These are TensorFlow-internal modules. The import paths below are
# assumptions and may differ between TensorFlow versions (in particular, the
# module that provides `tracking.AutoTrackable` has moved between releases).
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.saved_model import registration
from tensorflow.python.training.tracking import tracking


@registration.register_serializable()
class Part(resource_variable_ops.ResourceVariable):

  def __init__(self, value):
    self._init_from_args(value)

  @classmethod
  def _deserialize_from_proto(cls, **kwargs):
    return cls([0, 0])


@registration.register_serializable()
class Stack(tracking.AutoTrackable):

  def __init__(self, parts=None):
    self.parts = parts

  @def_function.function(input_signature=[])
  def value(self):
    return array_ops.stack(self.parts)


def get_tensor_slices(trackables):
  """Collects the names and stacked values of every `Stack` in `trackables`."""
  tensor_names = []
  shapes_and_slices = []
  tensors = []
  restored_trackables = []
  for obj_prefix, obj in trackables.items():
    if isinstance(obj, Part):
      continue  # only save stacks
    tensor_names.append(obj_prefix + "/value")
    shapes_and_slices.append("")
    x = obj.value()
    with ops.device("/device:CPU:0"):
      tensors.append(array_ops.identity(x))
    restored_trackables.append(obj)

  return tensor_names, shapes_and_slices, tensors, restored_trackables


def save_stacks_and_parts(trackables, file_prefix):
  """Save stack and part objects to a checkpoint shard."""
  tensor_names, shapes_and_slices, tensors, _ = get_tensor_slices(trackables)
  io_ops.save_v2(file_prefix, tensor_names, shapes_and_slices, tensors)
  # Return the shard filenames as a list, as described in the `save_fn` spec.
  return [file_prefix]


def restore_stacks_and_parts(trackables, merged_prefix):
  """Restore each stack's value and assign the slices back to its parts."""
  tensor_names, shapes_and_slices, tensors, restored_trackables = (
      get_tensor_slices(trackables))
  dtypes = [t.dtype for t in tensors]
  restored_tensors = io_ops.restore_v2(merged_prefix, tensor_names,
                                       shapes_and_slices, dtypes)
  for trackable, restored_tensor in zip(restored_trackables, restored_tensors):
    expected_shape = trackable.value().get_shape()
    restored_tensor = array_ops.reshape(restored_tensor, expected_shape)
    parts = array_ops.unstack(restored_tensor)
    for part, restored_part in zip(trackable.parts, parts):
      part.assign(restored_part)


registration.register_checkpoint_saver(
    name="stacks",
    predicate=lambda x: isinstance(x, (Stack, Part)),
    save_fn=save_stacks_and_parts,
    restore_fn=restore_stacks_and_parts)
```