# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains functionality for Checkpoint/SavedModel in DTensor."""

import collections
from typing import Dict, List, Union

from tensorflow.dtensor.python import api
from tensorflow.dtensor.python import d_variable
from tensorflow.dtensor.python import gen_dtensor_ops
from tensorflow.dtensor.python import layout as layout_lib
from tensorflow.dtensor.python import mesh_util
from tensorflow.python.eager import context
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.util.tf_export import tf_export


def sharded_prefix(
    mesh: layout_lib.Mesh,
    prefix: List[str],
    tensor_names: List[str],
    shape_and_slices: List[str],
    tensors: List[ops.Tensor],
):
  """Generates all sharded prefixes for a distributed save.

  DTensor SaveV2 SPMD generates multiple SaveV2 ops on the saving devices,
  and each op must save under a distinct shard prefix so that no content is
  overwritten.

  (prefix, tensor_names, tensors (with layouts)) and the saving mesh
  collectively define the unique set of shard prefixes generated for all the
  Save ops. Usually, (prefix, tensor_names, shape_and_slices, tensors) should
  match what is used in the save op.

  Args:
    mesh: The mesh used in the save op. Usually a CPU mesh, matching the mesh
      used in the Save op.
    prefix: The prefix of the files being saved.
    tensor_names: A list of tensor names used in the save op.
    shape_and_slices: A list of shape and slice specifications used in the
      save op. The only supported value is "" as distributed saving with
      slices is not supported yet.
    tensors: A list of tensors used in the save op. The order should match
      tensor_names.

  Returns:
    A 1-D string tensor containing all the generated shard prefixes.
  """
  layout_str = array_ops.stack(
      [api.fetch_layout(tensor).to_string() for tensor in tensors], axis=0)
  layouts = api.pack([layout_str] * mesh.num_local_devices(),
                     layout_lib.Layout.replicated(mesh, rank=1))

  mesh_str_tensor = api.pack([mesh.to_string()] * mesh.num_local_devices(),
                             layout_lib.Layout.replicated(mesh, rank=0))
  return gen_dtensor_ops.d_tensor_sharded_prefix(
      prefix,
      tensor_names,
      shape_and_slices,
      mesh_str_tensor,
      layouts=layouts,
      tensors=tensors)
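

# A minimal usage sketch for `sharded_prefix` (illustrative only; the mesh
# shape, the tensor name 'w', and the '/tmp/ckpt/w' prefix are assumptions,
# and DTensor is assumed to be initialized with enough local CPU devices):
#
#   mesh = mesh_util.create_mesh([('batch', 2)], device_type='CPU')
#   w = api.pack([array_ops.zeros([4])] * mesh.num_local_devices(),
#                layout_lib.Layout.replicated(mesh, rank=1))
#   prefixes = sharded_prefix(mesh.host_mesh(), ['/tmp/ckpt/w'], ['w'],
#                             [''], [w])
#   # `prefixes` is a 1-D string tensor with one entry per generated shard.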


@tf_export('experimental.dtensor.sharded_save', v1=[])
def sharded_save(
    mesh: layout_lib.Mesh,
    file_prefix: Union[str, ops.Tensor],
    tensor_names: Union[List[str], ops.Tensor],
    shape_and_slices: Union[List[str], ops.Tensor],
    tensors: List[Union[ops.Tensor, tf_variables.Variable]],
):
  """Saves given named tensor slices in a sharded, multi-client safe fashion.

  The method makes sure the checkpoint directory state is correct in a
  sharded, multi-client save. Namely, a barrier is placed after SaveV2 to make
  sure every client has finished writing its files, and another after
  MergeV2Checkpoints to make sure all metadata is properly merged.

  Upon return, the checkpoint is complete and all directory operations are
  done.

  Args:
    mesh: The Mesh that contains the Tensors to save.
    file_prefix: The prefix of the checkpoint.
    tensor_names: A list of tensor names used in the save op.
    shape_and_slices: A list of shape and slice specifications used in the
      save op. The only supported value is "" as distributed saving with
      slices is not supported yet.
    tensors: A list of tensors used in the save op. The order should match
      tensor_names.

  Returns:
    A MergeV2Checkpoints op that merges all metadata.
  """
  with ops.device(api.device_name()):
    io_ops.save_v2(file_prefix, tensor_names, shape_and_slices, tensors)

  # Query generated shards and generate MergeV2.
  generated_shards = sharded_prefix(mesh.host_mesh(), [file_prefix],
                                    tensor_names, shape_and_slices, tensors)

  # Make sure all clients have written the files.
  mesh_util.barrier(mesh.host_mesh(), 'SaveV2')  # pylint: disable=protected-access

  with ops.device(api.device_name()):
    merge_op = io_ops.MergeV2Checkpoints(
        checkpoint_prefixes=generated_shards,
        destination_prefix=file_prefix,
        delete_old_dirs=True)

  # Make sure the first device in the first host has finished the merge.
  mesh_util.barrier(mesh.host_mesh(), 'MergeV2Checkpoints')

  return merge_op
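

# A minimal usage sketch for `sharded_save` (illustrative only; the mesh
# shape, the tensor name 'w', and the '/tmp/ckpt/w' prefix are assumptions).
# The prefix and tensor names are packed to the host mesh first, mirroring
# what `name_based_save` below does:
#
#   mesh = mesh_util.create_mesh([('batch', 2)], device_type='CPU')
#   w = api.pack([array_ops.ones([4])] * mesh.num_local_devices(),
#                layout_lib.Layout.replicated(mesh, rank=1))
#   prefix = api.pack(['/tmp/ckpt/w'] * mesh.num_local_devices(),
#                     layout_lib.Layout.replicated(mesh.host_mesh(), rank=0))
#   names = api.pack([['w']] * mesh.num_local_devices(),
#                    layout_lib.Layout.replicated(mesh.host_mesh(), rank=1))
#   sharded_save(mesh, file_prefix=prefix, tensor_names=names,
#                shape_and_slices=[''], tensors=[w])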


@tf_export('experimental.dtensor.enable_save_as_bf16', v1=[])
def enable_save_as_bf16(variables: List[tf_variables.Variable]):
  """Allows float32 DVariables to be checkpointed and restored as bfloat16.

  The method only affects the DVariable part of the model and leaves
  non-DTensor Variables/Tensors untouched.

  Args:
    variables: A list of tf.Variable to enable bfloat16 save/restore for.
      Only has an effect on DTensor Variables (DVariables), as they go
      through DTensor-specific save/restore logic.
  """
  for v in variables:
    if isinstance(v, d_variable.DVariable):
      v.save_as_bf16 = True


@tf_export('experimental.dtensor.name_based_restore', v1=[])
def name_based_restore(
    mesh: layout_lib.Mesh,
    checkpoint_prefix: str,
    name_tensor_dict: Dict[str, Union[ops.Tensor, tf_variables.Variable]],
):
  """Restores from checkpoint_prefix to name based DTensors.

  The DTensor variables being restored into must already be initialized and
  must have the same shape/dtype as the tensors being restored.

  Also, name based restore is currently only supported on a single mesh.

  Args:
    mesh: The single mesh that all Tensors are restored to.
    checkpoint_prefix: The prefix of the checkpoint to be restored.
    name_tensor_dict: An ordered dictionary of tensor names to DTensors. For
      now, the DTensor shape/dtype must match the tensors being restored.

  Returns:
    A dictionary mapping each name to its restored DTensor value.
  """
  if not context.executing_eagerly():
    raise ValueError('name based restore must run eagerly.')

  ordered_name_tensor_dict = name_tensor_dict
  if not isinstance(name_tensor_dict, collections.OrderedDict):
    ordered_name_tensor_dict = collections.OrderedDict(name_tensor_dict)

  # Make sure that all tensors are on CPU mesh for now.
  # This might not be a hard limitation in the future.
  for name, tensor in ordered_name_tensor_dict.items():
    try:
      if api.fetch_layout(tensor).mesh.device_type().upper() != 'CPU':
        raise ValueError(
            'Restoring a non CPU Tensor is not supported currently. Offending '
            'tensor name : {tensor_name}'.format(tensor_name=name))
    except errors_impl.OpError as op_error:
      raise ValueError(
          'Saving/Restoring tensor must be a DTensor') from op_error

  # Now that we have all tensors on CPU mesh, do a DTensorRestoreV2.
  checkpoint_prefix = api.pack(
      [checkpoint_prefix] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=0))
  # Explicitly pack to mesh to avoid implicit small constant extraction, which
  # does not work for larger restores that have lots of names.
  tensor_names = api.pack(
      [list(ordered_name_tensor_dict.keys())] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=1))
  shape_and_slices = api.pack(
      [[''] * len(ordered_name_tensor_dict)] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=1))
  # A list of TensorShape representing all shapes for the input tensors.
  input_shapes = [tensor.shape for tensor in ordered_name_tensor_dict.values()]
  input_layouts = [
      api.fetch_layout(tensor).to_string()
      for tensor in ordered_name_tensor_dict.values()
  ]

  with ops.device(api.device_name()):
    restored_cpu_tensors = gen_dtensor_ops.d_tensor_restore_v2(
        prefix=checkpoint_prefix,
        tensor_names=tensor_names,
        shape_and_slices=shape_and_slices,
        input_shapes=input_shapes,
        input_layouts=input_layouts,
        dtypes=[tensor.dtype for tensor in ordered_name_tensor_dict.values()])

  return collections.OrderedDict(
      zip(ordered_name_tensor_dict.keys(), restored_cpu_tensors))
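

# A minimal usage sketch for `name_based_restore` (illustrative only; the CPU
# mesh shape, the name 'w', and the '/tmp/ckpt/w' prefix are assumptions; the
# checkpoint is assumed to already contain a tensor named 'w'):
#
#   mesh = mesh_util.create_mesh([('batch', 2)], device_type='CPU')
#   zeros = api.pack([array_ops.zeros([4])] * mesh.num_local_devices(),
#                    layout_lib.Layout.replicated(mesh, rank=1))
#   restored = name_based_restore(mesh, '/tmp/ckpt/w', {'w': zeros})
#   # `restored['w']` holds the checkpoint value with the layout of `zeros`.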


@tf_export('experimental.dtensor.name_based_save', v1=[])
def name_based_save(
    mesh: layout_lib.Mesh,
    checkpoint_prefix: Union[str, ops.Tensor],
    name_tensor_dict: Dict[str, Union[ops.Tensor, tf_variables.Variable]],
):
  """Saves name based Tensors into a checkpoint.

  The function prepares the input dictionary into the format expected by
  `sharded_save`, so that it can take advantage of the DTensor SPMD based
  distributed save.

  As with restore, the function only supports saving on a single mesh.

  Args:
    mesh: The single mesh that all Tensors to be saved are on.
    checkpoint_prefix: The prefix of the checkpoint to be saved.
    name_tensor_dict: An ordered dictionary of tensor names to DTensors. For
      now, the DTensor shape/dtype must match the tensors being saved.
  """
  if not context.executing_eagerly():
    raise ValueError('name based save must run eagerly.')

  ordered_name_tensor_dict = name_tensor_dict
  if not isinstance(name_tensor_dict, collections.OrderedDict):
    ordered_name_tensor_dict = collections.OrderedDict(name_tensor_dict)

  # The current _dtensor_device() in api.py is the correct way of specifying
  # the DTensor device singleton. The API itself will eventually be moved to
  # a public API that provides a global singleton in the DTensor context.
  # For now, we just use the current `internal` API and aim at migrating in
  # one shot later.
  # TODO(hthu): Provide _dtensor_device() singleton as a public API.
  # pylint: disable=protected-access
  checkpoint_prefix = api.pack([checkpoint_prefix] * mesh.num_local_devices(),
                               layout_lib.Layout.replicated(
                                   mesh.host_mesh(), rank=0))
  tensor_names = api.pack(
      [list(ordered_name_tensor_dict.keys())] * mesh.num_local_devices(),
      layout_lib.Layout.replicated(mesh.host_mesh(), rank=1))

  sharded_save(
      mesh,
      file_prefix=checkpoint_prefix,
      tensor_names=tensor_names,
      shape_and_slices=[''] * len(ordered_name_tensor_dict),
      tensors=list(ordered_name_tensor_dict.values()))
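

# A minimal usage sketch for `name_based_save` (illustrative only; the mesh
# shape, the name 'w', and the '/tmp/ckpt/w' prefix are assumptions). The
# result can be read back with `name_based_restore` above:
#
#   mesh = mesh_util.create_mesh([('batch', 2)], device_type='CPU')
#   w = api.pack([array_ops.ones([4])] * mesh.num_local_devices(),
#                layout_lib.Layout.replicated(mesh, rank=1))
#   name_based_save(mesh, '/tmp/ckpt/w', {'w': w})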