1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""DTensor variable and saveable.""" 16 17import contextlib 18import functools 19 20from tensorflow.dtensor.python import api 21from tensorflow.dtensor.python import layout as layout_lib 22from tensorflow.python.eager import context 23from tensorflow.python.eager import def_function 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import errors 26from tensorflow.python.framework import ops 27from tensorflow.python.ops import math_ops 28from tensorflow.python.ops import resource_variable_ops 29from tensorflow.python.trackable import base as trackable 30from tensorflow.python.training.saving import saveable_object 31from tensorflow.python.util.tf_export import tf_export 32 33 34class DSaveSpec(saveable_object.SaveSpec): 35 """DTensor SaveSpec that additionaly captures global_shape and layout.""" 36 37 def __init__(self, 38 tensor, 39 slice_spec, 40 name, 41 global_shape, 42 layout, 43 dtype=None, 44 device=None): 45 super().__init__( 46 tensor=tensor, 47 slice_spec=slice_spec, 48 name=name, 49 dtype=dtype, 50 device=device) 51 self.global_shape = global_shape 52 self.layout = layout 53 54 55class _DVariableSaveable(saveable_object.SaveableObject): 56 """Class for defining how to save/restore DTensor variable.""" 57 58 def __init__(self, dvariable, name): 59 with ops.device(dvariable.device): 60 original_layout = api.fetch_layout(dvariable) 61 # Record original layout to allow restore. 62 self._original_layout = original_layout 63 self._dvariable = dvariable 64 65 def pack(tensors, layout): 66 with ops.device(dvariable.device): 67 return api.pack(tensors, layout) 68 69 host_layout = layout_lib.Layout(original_layout.sharding_specs, 70 original_layout.mesh.host_mesh()) 71 72 def get_host_dvariable(): 73 # Copy to host mesh if needed. 74 if original_layout.mesh.device_type().upper() != 'CPU': 75 with ops.device(dvariable.device): 76 host_dvariable = DVariable( 77 api.pack(api.unpack(dvariable.read_value()), host_layout)) 78 else: 79 host_dvariable = dvariable 80 return (math_ops.cast(host_dvariable, dtypes.bfloat16) 81 if self.should_cast(host_dvariable) else host_dvariable) 82 83 num_local_devices = original_layout.mesh.num_local_devices() 84 super(_DVariableSaveable, self).__init__( 85 None, 86 [ 87 DSaveSpec( 88 tensor=get_host_dvariable, 89 slice_spec=pack([''] * num_local_devices, 90 layout_lib.Layout.replicated( 91 original_layout.mesh.host_mesh(), rank=0)), 92 name=pack([name] * num_local_devices, 93 layout_lib.Layout.replicated( 94 original_layout.mesh.host_mesh(), rank=0)), 95 global_shape=dvariable.shape, 96 # Layout is attached as attribute, no need to put it as a 97 # Tensor on DTensorDevice. 98 layout=host_layout.to_string(), 99 dtype=dtypes.bfloat16 100 if self.should_cast(dvariable) else dvariable.dtype, 101 device=dvariable.device) 102 ], 103 name) 104 105 def should_cast(self, v): 106 """Returns True if v has float32 dtype and is intructed to save as bf16. 107 108 Args: 109 v : The variable that determines whether to cast. 110 111 Returns: 112 True if current savable DVariable is instructed to save as bfloat16 and 113 the variable has dtype float32. 114 """ 115 return self._dvariable.save_as_bf16 and v.dtype == dtypes.float32 116 117 def restore(self, restored_tensors, restored_shapes): 118 """Restores the same value into all variables.""" 119 tensor, = restored_tensors 120 121 @def_function.function 122 def _restore(t): 123 with ops.device(self._dvariable.device): 124 return api.copy_to_mesh(t, self._original_layout) 125 126 # This assign establishes connections from restored tensor and tensors 127 # being restored to -- so that restore in SPMD can backtrack the DVariable 128 # and its layout, given that we're using tf.function style restore. 129 # Note that the restored dvaraible is on CPU no matter what as the restoreV2 130 # op must run on CPU. 131 # TODO(b/159035705): Allow restore for Tensor objects as well? 132 # Restore the dvariable back to original layout. 133 if self._original_layout.mesh.device_type().upper() != 'CPU': 134 tensor = _restore(tensor) 135 return self._dvariable.assign( 136 math_ops.cast(tensor, dtype=self._dvariable.dtype) if self._dvariable 137 .save_as_bf16 else tensor) 138 139 140@tf_export('experimental.dtensor.DVariable', v1=[]) 141class DVariable(resource_variable_ops.ResourceVariable): 142 """A replacement for tf.Variable which follows initial value placement. 143 144 The class also handles restore/save operations in DTensor. Note that, 145 DVariable may fall back to normal tf.Variable at this moment if 146 `initial_value` is not a DTensor. 147 """ 148 149 def __init__(self, initial_value, *args, dtype=None, **kwargs): 150 """Overrides tf.Variable to fix VarHandleOp placements.""" 151 # Variables by default use the current device scope for placement. This 152 # wrapper has them follow the initial value's placement instead (which will 153 # be the DTensor device if the initial value has a layout). 154 155 # Pop layout from kwargs since keras make_variable may pass a 'layout' 156 # keyword argument. We need to pop it because we are passing kwargs to 157 # super class constructor. 158 layout = kwargs.pop('layout', None) 159 shape = kwargs.get('shape', None) 160 161 if callable(initial_value): 162 unwrapped = initial_value 163 if issubclass(type(initial_value), functools.partial): 164 unwrapped = initial_value.func 165 166 # If wrapped is a CheckpointInitialValueCallable, this means that 167 # we are creating a Variable during a checkpoint restore. 168 # Thus the restore will happen now through this callable 169 # and we will create the DVariable with the restored dtensor. 170 if issubclass(type(unwrapped), trackable.CheckpointInitialValueCallable): 171 if not shape or not layout: 172 raise ValueError('Expected shape and layout to be not None.') 173 174 # CheckpointInitialValueCallable will call an eager tf.RestoreV2, 175 # which does not have any shape information or layout information 176 # attached. Thus we will do two things to have them correctly specified: 177 # 178 # The default layout scope allows us to correctly specify the output 179 # layout of the tf.RestoreV2 that will be called 180 # 181 # Passing shard_info with the correct shape allows the tf.RestoreV2 182 # ShapeInference to extract the shape. 183 initial_value = api.call_with_layout( 184 initial_value, 185 layout, 186 shard_info=trackable.ShardInfo( 187 shape=shape, offset=[0] * len(shape))) 188 else: 189 initial_value = initial_value() 190 191 # When the initial value came from a Checkpoint restoration, fetch tensor. 192 if isinstance(initial_value, trackable.CheckpointInitialValue): 193 initial_value = initial_value.wrapped_value 194 195 initial_value = ops.convert_to_tensor(initial_value, dtype=dtype) 196 variable_device = initial_value.device 197 self._save_as_bf16 = False 198 # TODO(b/159035705): The following code enables variable creation inside 199 # a tf.function. However, it requires a global dtensor device. 200 # if not variable_device and not tf.executing_eagerly(): 201 # try: 202 # initial_value.op.get_attr("_layout") 203 # except ValueError: 204 # pass 205 # else: 206 # # The initial value is a DTensor, but because the DTensor device is 207 # # only active during eager execution at the moment we need to 208 # # translate that into a placement for the eager VarHandleOp. 209 # variable_device = _dtensor_device().name 210 with ops.device(variable_device): 211 # If initial tensor assigned to DVariable is DTensor, record the layout of 212 # the resource so that this can be queried. 213 self.layout = None 214 if context.executing_eagerly(): 215 try: 216 self.layout = api.fetch_layout(initial_value) 217 except (errors.InvalidArgumentError, errors.NotFoundError): 218 # For Non-DTensor tensors, fetch layout results in expected 219 # InvalidArgument or NotFoundError depending on whether the API 220 # is called within DTensor device scope or not. 221 self.layout = None 222 pass 223 mesh = self.layout.mesh if self.layout else None 224 with api.run_on(mesh) if mesh else contextlib.nullcontext(): 225 super(DVariable, self).__init__( 226 initial_value, *args, dtype=dtype, **kwargs) 227 228 @property 229 def save_as_bf16(self): 230 return self._save_as_bf16 231 232 @save_as_bf16.setter 233 def save_as_bf16(self, save_as_bf16): 234 """Enables saving float32 as bfloat16.""" 235 self._save_as_bf16 = save_as_bf16 and self.dtype == dtypes.float32 236 237 def _gather_saveables_for_checkpoint(self): 238 return { 239 trackable.VARIABLE_VALUE_KEY: 240 functools.partial(_DVariableSaveable, self) 241 } 242