1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Contains functionality for Checkpoint/SavedModel in DTensor."""
16
17import collections
18from typing import Dict, List, Union
19
20from tensorflow.dtensor.python import api
21from tensorflow.dtensor.python import d_variable
22from tensorflow.dtensor.python import gen_dtensor_ops
23from tensorflow.dtensor.python import layout as layout_lib
24from tensorflow.dtensor.python import mesh_util
25from tensorflow.python.eager import context
26from tensorflow.python.framework import errors_impl
27from tensorflow.python.framework import ops
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import io_ops
30from tensorflow.python.ops import variables as tf_variables
31from tensorflow.python.util.tf_export import tf_export
32
33
def sharded_prefix(
    mesh: layout_lib.Mesh,
    prefix: List[str],
    tensor_names: List[str],
    shape_and_slices: List[str],
    tensors: List[ops.Tensor],
):
  """Generates all sharded prefixes used in a distributed Save.

  DTensor SaveV2 SPMD emits one SaveV2 op per saving device, and each of
  those ops must write to a distinct shard prefix so that no shard's content
  overwrites another's.

  The tuple (prefix, tensor_names, tensors-with-layouts) together with the
  saving mesh uniquely determines the full set of shard prefixes generated
  for all the Save ops. Usually (prefix, tensor_names, shape_and_slices,
  tensors) should match the arguments given to the save op itself.

  Args:
    mesh: The mesh that is used in save op. Usually a CPU mesh, and matches
      what is used in Save op.
    prefix: The prefix of saving files.
    tensor_names: a list of tensor names used in save op.
    shape_and_slices: a list of shape and slice specification used in save op.
      The only supported value is "" as we don't support distributed saving
      with slices yet.
    tensors: a list of tensors used in save op. The order should match
      tensor_names.

  Returns:
    A rank-1 string tensor representing all shard prefixes generated.
  """
  num_local = mesh.num_local_devices()

  # Serialize each tensor's layout, then replicate the resulting string
  # vector across every local device of the mesh.
  serialized_layouts = array_ops.stack(
      [api.fetch_layout(t).to_string() for t in tensors], axis=0)
  packed_layouts = api.pack([serialized_layouts] * num_local,
                            layout_lib.Layout.replicated(mesh, rank=1))

  # Replicate the serialized mesh as a scalar string on each local device.
  packed_mesh_str = api.pack([mesh.to_string()] * num_local,
                             layout_lib.Layout.replicated(mesh, rank=0))

  return gen_dtensor_ops.d_tensor_sharded_prefix(
      prefix,
      tensor_names,
      shape_and_slices,
      packed_mesh_str,
      layouts=packed_layouts,
      tensors=tensors)
80
81
@tf_export('experimental.dtensor.sharded_save', v1=[])
def sharded_save(
    mesh: layout_lib.Mesh,
    file_prefix: Union[str, ops.Tensor],
    tensor_names: Union[List[str], ops.Tensor],
    shape_and_slices: Union[List[str], ops.Tensor],
    tensors: List[Union[ops.Tensor, tf_variables.Variable]],
):
  """Saves given named tensor slices in a sharded, multi-client safe fashion.

  The method keeps the checkpoint directory state correct for a sharded
  multi-client save: a barrier is placed after SaveV2 so that every client
  has finished writing its files, and another after MergeV2Checkpoints so
  that all metadata is properly merged.

  Upon exiting, the checkpoint is complete and all directory operations are
  done.

  Args:
    mesh: The Mesh that contains the Tensors to save.
    file_prefix: The prefix of checkpoint.
    tensor_names: a list of tensor names used in save op.
    shape_and_slices: a list of shape and slice specification used in save op.
      The only supported value is "" as we don't support distributed saving
      with slices yet.
    tensors: a list of tensors used in save op. The order should match
      tensor_names.

  Returns:
    A MergeV2Checkpoints op that merged all Metadata.
  """
  host_mesh = mesh.host_mesh()

  # Run the (SPMD-expanded, sharded) SaveV2 on the DTensor device.
  with ops.device(api.device_name()):
    io_ops.save_v2(file_prefix, tensor_names, shape_and_slices, tensors)

  # Query the shard prefixes the SaveV2 ops produced; these feed the merge.
  generated_shards = sharded_prefix(host_mesh, [file_prefix], tensor_names,
                                    shape_and_slices, tensors)

  # Barrier: every client must finish writing its shard files before any
  # client begins merging metadata.
  mesh_util.barrier(host_mesh, 'SaveV2')

  with ops.device(api.device_name()):
    merged = io_ops.MergeV2Checkpoints(
        checkpoint_prefixes=generated_shards,
        destination_prefix=file_prefix,
        delete_old_dirs=True)

  # Barrier: wait for the first device on the first host to finish the merge.
  mesh_util.barrier(host_mesh, 'MergeV2Checkpoints')

  return merged
133
134
@tf_export('experimental.dtensor.enable_save_as_bf16', v1=[])
def enable_save_as_bf16(variables: List[tf_variables.Variable]):
  """Allows float32 DVariables to be checkpointed and restored as bfloat16.

  Only the DVariable instances inside the model are affected; non-DTensor
  Variables/Tensors are left untouched.

  Args:
    variables: A list of tf.Variable to be enabled with bfloat16 save/restore.
      Only has effect on DTensor Variables, since those go through
      d_variable's DTensor-specific save/restore logic.
  """
  for variable in variables:
    if isinstance(variable, d_variable.DVariable):
      variable.save_as_bf16 = True
150
151
@tf_export('experimental.dtensor.name_based_restore', v1=[])
def name_based_restore(
    mesh: layout_lib.Mesh,
    checkpoint_prefix: str,
    name_tensor_dict: Dict[str, Union[ops.Tensor, tf_variables.Variable]],
):
  """Restores from checkpoint_prefix to name based DTensors.

  It is required to have already-initialized DTensor variables that have same
  shape/dtype for the tensors being restored.

  Also, we currently only support a named based restore on a single mesh.

  Args:
    mesh: The single mesh that all Tensors would be restored to.
    checkpoint_prefix : The prefix of checkpoint to be restored.
    name_tensor_dict: A ordered dictionary of tensor_names to a DTensor. The
      DTensor shape/dtype must match the tensors being saved/restored for now.

  Returns:
    A dictionary of name to its restored DTensor value.
  """
  if not context.executing_eagerly():
    raise ValueError('name based restore must run eagerly.')

  if isinstance(name_tensor_dict, collections.OrderedDict):
    ordered_dict = name_tensor_dict
  else:
    ordered_dict = collections.OrderedDict(name_tensor_dict)

  # For now every tensor must live on a CPU mesh; this may be relaxed later.
  for name, tensor in ordered_dict.items():
    try:
      device_type = api.fetch_layout(tensor).mesh.device_type()
      if device_type.upper() != 'CPU':
        raise ValueError(
            'Restoring a non CPU Tensor is not supported currently. Offending '
            'tensor name : {tensor_name}'.format(tensor_name=name))
    except errors_impl.OpError as op_error:
      # fetch_layout fails with an OpError for non-DTensor inputs.
      raise ValueError(
          'Saving/Restoring tensor must be a DTensor') from op_error

  num_local = mesh.num_local_devices()
  host_mesh = mesh.host_mesh()

  # All tensors are on the CPU mesh, so issue a DTensorRestoreV2.
  # Pack explicitly to the mesh to avoid implicit small-constant extraction,
  # which does not work for larger restores with lots of names.
  checkpoint_prefix = api.pack(
      [checkpoint_prefix] * num_local,
      layout_lib.Layout.replicated(host_mesh, rank=0))
  tensor_names = api.pack(
      [list(ordered_dict.keys())] * num_local,
      layout_lib.Layout.replicated(host_mesh, rank=1))
  shape_and_slices = api.pack(
      [[''] * len(ordered_dict)] * num_local,
      layout_lib.Layout.replicated(host_mesh, rank=1))

  # Per-tensor restore metadata, in dictionary order.
  input_shapes = []
  input_layouts = []
  dtypes = []
  for tensor in ordered_dict.values():
    input_shapes.append(tensor.shape)
    input_layouts.append(api.fetch_layout(tensor).to_string())
    dtypes.append(tensor.dtype)

  with ops.device(api.device_name()):
    restored_cpu_tensors = gen_dtensor_ops.d_tensor_restore_v2(
        prefix=checkpoint_prefix,
        tensor_names=tensor_names,
        shape_and_slices=shape_and_slices,
        input_shapes=input_shapes,
        input_layouts=input_layouts,
        dtypes=dtypes)

  return collections.OrderedDict(
      zip(ordered_dict.keys(), restored_cpu_tensors))
223
224
@tf_export('experimental.dtensor.name_based_save', v1=[])
def name_based_save(
    mesh: layout_lib.Mesh,
    checkpoint_prefix: Union[str, ops.Tensor],
    name_tensor_dict: Dict[str, Union[ops.Tensor, tf_variables.Variable]],
):
  """Saves name based Tensor into a Checkpoint.

  The function prepares the input dictionary to the format of a
  `sharded_save`, so that it can take advantage of DTensor SPMD based
  distributed save.

  Same as restore, the function only supports saving on the single mesh.

  Args:
    mesh: The single mesh that all Tensors would be saved on.
    checkpoint_prefix : The prefix of the checkpoint to write.
    name_tensor_dict: A ordered dictionary of tensor_names to a DTensor. The
      DTensor shape/dtype must match the tensors being saved/restored for now.
  """
  if not context.executing_eagerly():
    raise ValueError('name based save must run eagerly.')

  if isinstance(name_tensor_dict, collections.OrderedDict):
    ordered_dict = name_tensor_dict
  else:
    ordered_dict = collections.OrderedDict(name_tensor_dict)

  # The current _dtensor_device() in api.py is the correct way of specifying
  # DTensor device singletons; the API will eventually move to a public one
  # providing a global singleton in DTensor context, at which point this
  # migrates in one shot.
  # TODO(hthu): Provide _dtensor_device() singleton as a public API.
  num_local = mesh.num_local_devices()
  host_mesh = mesh.host_mesh()

  # Replicate the prefix and the name list across all local devices so the
  # SPMD save sees consistent inputs everywhere.
  packed_prefix = api.pack(
      [checkpoint_prefix] * num_local,
      layout_lib.Layout.replicated(host_mesh, rank=0))
  packed_names = api.pack(
      [list(ordered_dict.keys())] * num_local,
      layout_lib.Layout.replicated(host_mesh, rank=1))

  sharded_save(
      mesh,
      file_prefix=packed_prefix,
      tensor_names=packed_names,
      shape_and_slices=[''] * len(ordered_dict),
      tensors=list(ordered_dict.values()))
270