# Registrations

To configure SavedModel or checkpointing beyond the basic saving and loading
steps [documentation TBD], registration is required.

Currently, only TensorFlow-internal registrations are allowed, and each must be
added to the corresponding allowlist:

*   `tensorflow.python.saved_model.registration.register_tf_serializable`
    *   Allowlist: `tf_serializable_allowlist.txt`
*   `tensorflow.python.saved_model.registration.register_tf_checkpoint_saver`
    *   Allowlist: `tf_checkpoint_saver_allowlist.txt`

[TOC]

## SavedModel serializable registration

Custom objects must be registered in order to get the correct deserialization
method when loading. The registered name of the class is saved to the proto.

Keras already has a similar mechanism for registering serializables:
[`tf.keras.utils.register_keras_serializable(package, name)`](https://www.tensorflow.org/api_docs/python/tf/keras/utils/register_keras_serializable).
A similar API has been added to core TensorFlow:

```python
registration.register_serializable(package, name)
registration.register_tf_serializable(name)  # If TensorFlow-internal.
```

*   `package`: The package that this class belongs to.
*   `name`: The name of this class. The registered name that is saved in the
    proto is `"{package}.{name}"` (for TensorFlow-internal registration, the
    package name is `tf`).

## Checkpoint saver registration

If `Trackables` share state or require complicated coordination between
multiple `Trackables` (e.g. `DTensor`), then users may register save and
restore functions for these objects:

```python
tf.saved_model.register_checkpoint_saver(
    predicate, save_fn=None, restore_fn=None)
```

*   `predicate`: A function that returns `True` if a `Trackable` object should
    be saved and restored using the registered `save_fn` and `restore_fn`.
*   `save_fn`: A Python function, a `tf.function`, or `None`. If `None`, the
    default saving process runs, which calls `Trackable._serialize_to_tensors`.
*   `restore_fn`: A `tf.function` or `None`. If `None`, the default restoring
    process runs, which calls `Trackable._restore_from_tensors`.

**`save_fn` details**

```python
@tf.function  # Optional decorator.
def save_fn(trackables, file_prefix) -> List[str]:  # Returns shard filenames.
  ...
```

*   `trackables`: A dictionary of `{object_prefix: Trackable}`. The
    `object_prefix` can be used as the object name, and uniquely identifies
    each `Trackable`. `trackables` is the filtered set of trackables that pass
    the predicate.
*   `file_prefix`: A string or string tensor of the checkpoint prefix.
*   shard filenames: A list of filenames written using `io_ops.save_v2`, which
    will be merged into the checkpoint data files. These should be prefixed by
    `file_prefix`.

This function can be a Python function, in which case the list of shard
filenames can be empty (if the values are written without the `SaveV2` op).

If this function is a `tf.function`, then the shards must be written using the
`SaveV2` op. This guarantees that the checkpoint format is compatible with
existing checkpoint readers and managers.

**`restore_fn` details**

```python
@tf.function  # Required decorator.
def restore_fn(trackables, file_prefix) -> None:
  ...
```

A `tf.function` with the spec:

*   `trackables`: A dictionary of `{object_prefix: Trackable}`. The
    `object_prefix` can be used as the object name, and uniquely identifies
    each `Trackable`. The `Trackable` objects are the filtered results of the
    registered predicate.
*   `file_prefix`: A string or string tensor of the checkpoint prefix.

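To make the predicate/`save_fn` contract concrete, here is a small,
self-contained sketch of how a registered saver might be dispatched over a set
of trackables. It is purely illustrative: the toy `Trackable` class, the
`_SAVERS` list, and the `save_all`/`save_values` helpers are invented for this
example and are not TensorFlow APIs.

```python
# Illustrative-only dispatch sketch; none of these names are TensorFlow APIs.


class Trackable:
  """Toy stand-in for a checkpointable object."""

  def __init__(self, name, value):
    self.name = name
    self.value = value


# Each registered saver pairs a predicate with save/restore callables.
_SAVERS = []


def register_checkpoint_saver(predicate, save_fn, restore_fn):
  _SAVERS.append((predicate, save_fn, restore_fn))


def save_all(trackables, file_prefix):
  """Routes trackables to matching registered savers; collects shard names."""
  shard_filenames = []
  for predicate, save_fn, _ in _SAVERS:
    # Each save_fn receives only the trackables that pass its predicate,
    # keyed by a unique object prefix (here, simply the object's name).
    matched = {t.name: t for t in trackables if predicate(t)}
    if matched:
      shard_filenames.extend(save_fn(matched, file_prefix))
  return shard_filenames


def save_values(trackables, file_prefix):
  # Toy save_fn: pretend each matched trackable is written to one shard file,
  # and return the shard filenames, prefixed by file_prefix as the spec asks.
  return [f"{file_prefix}_{prefix}.shard" for prefix in sorted(trackables)]
```

A registered saver then sees only its matching objects, keyed by prefix:

```python
register_checkpoint_saver(
    predicate=lambda t: isinstance(t, Trackable),
    save_fn=save_values,
    restore_fn=None)
shards = save_all([Trackable("a", 1), Trackable("b", 2)], "/tmp/ckpt")
# shards == ["/tmp/ckpt_a.shard", "/tmp/ckpt_b.shard"]
```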
**Why are restore functions required to be `tf.function`s?** The short answer
is that the SavedModel format must maintain the invariant that SavedModel
packages can be used for inference on any platform and language. SavedModel
inference needs to be able to restore checkpointed values, so the restore
function must be directly encoded into the SavedModel as part of the Graph. We
also have security measures over `FunctionDef` and `GraphDef`, so users can
check that the SavedModel will not run arbitrary code (a feature of
`saved_model_cli`).

## Example

Below shows a `Stack` module that contains multiple `Parts` (a subclass of
`tf.Variable`). When a `Stack` is saved to a checkpoint, the `Parts` are
stacked together and a single entry is created in the checkpoint. The
checkpoint value is restored to all of the `Parts` in the `Stack`.

```python
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.saved_model import registration
from tensorflow.python.trackable import autotrackable as tracking


@registration.register_serializable()
class Part(resource_variable_ops.ResourceVariable):

  def __init__(self, value):
    self._init_from_args(value)

  @classmethod
  def _deserialize_from_proto(cls, **kwargs):
    return cls([0, 0])


@registration.register_serializable()
class Stack(tracking.AutoTrackable):

  def __init__(self, parts=None):
    self.parts = parts

  @def_function.function(input_signature=[])
  def value(self):
    return array_ops.stack(self.parts)


def get_tensor_slices(trackables):
  tensor_names = []
  shapes_and_slices = []
  tensors = []
  restored_trackables = []
  for obj_prefix, obj in trackables.items():
    if isinstance(obj, Part):
      continue  # Only save stacks.
    tensor_names.append(obj_prefix + "/value")
    shapes_and_slices.append("")
    x = obj.value()
    with ops.device("/device:CPU:0"):
      tensors.append(array_ops.identity(x))
    restored_trackables.append(obj)

  return tensor_names, shapes_and_slices, tensors, restored_trackables


def save_stacks_and_parts(trackables, file_prefix):
  """Saves stack and part objects to a checkpoint shard."""
  tensor_names, shapes_and_slices, tensors, _ = get_tensor_slices(trackables)
  io_ops.save_v2(file_prefix, tensor_names, shapes_and_slices, tensors)
  return file_prefix


def restore_stacks_and_parts(trackables, merged_prefix):
  tensor_names, shapes_and_slices, tensors, restored_trackables = (
      get_tensor_slices(trackables))
  dtypes = [t.dtype for t in tensors]
  restored_tensors = io_ops.restore_v2(merged_prefix, tensor_names,
                                       shapes_and_slices, dtypes)
  for trackable, restored_tensor in zip(restored_trackables, restored_tensors):
    expected_shape = trackable.value().get_shape()
    restored_tensor = array_ops.reshape(restored_tensor, expected_shape)
    parts = array_ops.unstack(restored_tensor)
    for part, restored_part in zip(trackable.parts, parts):
      part.assign(restored_part)


registration.register_checkpoint_saver(
    name="stacks",
    predicate=lambda x: isinstance(x, (Stack, Part)),
    save_fn=save_stacks_and_parts,
    restore_fn=restore_stacks_and_parts)
```