# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DTensor variable and saveable."""

import contextlib
import functools

from tensorflow.dtensor.python import api
from tensorflow.dtensor.python import layout as layout_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.trackable import base as trackable
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.util.tf_export import tf_export


class DSaveSpec(saveable_object.SaveSpec):
  """DTensor SaveSpec that additionally captures global_shape and layout."""

  def __init__(self,
               tensor,
               slice_spec,
               name,
               global_shape,
               layout,
               dtype=None,
               device=None):
    super().__init__(
        tensor=tensor,
        slice_spec=slice_spec,
        name=name,
        dtype=dtype,
        device=device)
    self.global_shape = global_shape
    self.layout = layout


class _DVariableSaveable(saveable_object.SaveableObject):
  """Class for defining how to save/restore a DTensor variable."""

  def __init__(self, dvariable, name):
    with ops.device(dvariable.device):
      original_layout = api.fetch_layout(dvariable)
    # Record the original layout to allow restore.
    self._original_layout = original_layout
    self._dvariable = dvariable

    def pack(tensors, layout):
      with ops.device(dvariable.device):
        return api.pack(tensors, layout)

    host_layout = layout_lib.Layout(original_layout.sharding_specs,
                                    original_layout.mesh.host_mesh())

    def get_host_dvariable():
      # Copy to the host mesh if needed.
      if original_layout.mesh.device_type().upper() != 'CPU':
        with ops.device(dvariable.device):
          host_dvariable = DVariable(
              api.pack(api.unpack(dvariable.read_value()), host_layout))
      else:
        host_dvariable = dvariable
      return (math_ops.cast(host_dvariable, dtypes.bfloat16)
              if self.should_cast(host_dvariable) else host_dvariable)

    num_local_devices = original_layout.mesh.num_local_devices()
    super(_DVariableSaveable, self).__init__(
        None,
        [
            DSaveSpec(
                tensor=get_host_dvariable,
                slice_spec=pack([''] * num_local_devices,
                                layout_lib.Layout.replicated(
                                    original_layout.mesh.host_mesh(), rank=0)),
                name=pack([name] * num_local_devices,
                          layout_lib.Layout.replicated(
                              original_layout.mesh.host_mesh(), rank=0)),
                global_shape=dvariable.shape,
                # Layout is attached as an attribute; no need to put it as a
                # Tensor on the DTensorDevice.
                layout=host_layout.to_string(),
                dtype=dtypes.bfloat16
                if self.should_cast(dvariable) else dvariable.dtype,
                device=dvariable.device)
        ],
        name)

  def should_cast(self, v):
    """Returns True if v has float32 dtype and is instructed to save as bf16.

    Args:
      v: The variable that determines whether to cast.

    Returns:
      True if the current saveable DVariable is instructed to save as bfloat16
        and the variable has dtype float32.
    """
    return self._dvariable.save_as_bf16 and v.dtype == dtypes.float32

  def restore(self, restored_tensors, restored_shapes):
    """Restores the same value into all variables."""
    tensor, = restored_tensors

    @def_function.function
    def _restore(t):
      with ops.device(self._dvariable.device):
        return api.copy_to_mesh(t, self._original_layout)

    # This assign establishes a connection from the restored tensor to the
    # tensor being restored to, so that restore in SPMD can backtrack to the
    # DVariable and its layout, given that we're using tf.function-style
    # restore.
    # Note that the restored dvariable is always on the CPU, as the RestoreV2
    # op must run on the CPU.
    # TODO(b/159035705): Allow restore for Tensor objects as well?
    # Restore the dvariable back to its original layout.
    if self._original_layout.mesh.device_type().upper() != 'CPU':
      tensor = _restore(tensor)
    return self._dvariable.assign(
        math_ops.cast(tensor, dtype=self._dvariable.dtype)
        if self._dvariable.save_as_bf16 else tensor)


@tf_export('experimental.dtensor.DVariable', v1=[])
class DVariable(resource_variable_ops.ResourceVariable):
  """A replacement for tf.Variable which follows initial value placement.

  The class also handles restore/save operations in DTensor. Note that
  DVariable may fall back to a normal tf.Variable at this moment if
  `initial_value` is not a DTensor.
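
  Example (a minimal sketch; assumes `dtensor` is `tf.experimental.dtensor`
  and that a DTensor `Mesh` named `mesh` has been created beforehand, e.g.
  with `dtensor.create_mesh`):

  ```python
  layout = dtensor.Layout.replicated(mesh, rank=1)
  value = dtensor.call_with_layout(tf.ones, layout, shape=[4])
  # The variable follows `value`'s placement, not the current device scope.
  v = dtensor.DVariable(value)
  ```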
  """

  def __init__(self, initial_value, *args, dtype=None, **kwargs):
    """Overrides tf.Variable to fix VarHandleOp placements."""
    # Variables by default use the current device scope for placement. This
    # wrapper has them follow the initial value's placement instead (which will
    # be the DTensor device if the initial value has a layout).

    # Pop layout from kwargs since Keras make_variable may pass a 'layout'
    # keyword argument. We need to pop it because we are passing kwargs to the
    # superclass constructor.
    layout = kwargs.pop('layout', None)
    shape = kwargs.get('shape', None)

    if callable(initial_value):
      unwrapped = initial_value
      if issubclass(type(initial_value), functools.partial):
        unwrapped = initial_value.func

      # If the unwrapped callable is a CheckpointInitialValueCallable, we are
      # creating a Variable during a checkpoint restore. The restore happens
      # now through this callable, and we create the DVariable with the
      # restored DTensor.
      if issubclass(type(unwrapped), trackable.CheckpointInitialValueCallable):
        if not shape or not layout:
          raise ValueError('Expected shape and layout to be not None.')

        # CheckpointInitialValueCallable will call an eager tf.RestoreV2,
        # which does not have any shape or layout information attached.
        # Thus we do two things to have them correctly specified:
        #
        # The default layout scope allows us to correctly specify the output
        # layout of the tf.RestoreV2 that will be called.
        #
        # Passing shard_info with the correct shape allows the tf.RestoreV2
        # ShapeInference to extract the shape.
        initial_value = api.call_with_layout(
            initial_value,
            layout,
            shard_info=trackable.ShardInfo(
                shape=shape, offset=[0] * len(shape)))
      else:
        initial_value = initial_value()

    # When the initial value came from a checkpoint restoration, fetch the
    # wrapped tensor.
    if isinstance(initial_value, trackable.CheckpointInitialValue):
      initial_value = initial_value.wrapped_value

    initial_value = ops.convert_to_tensor(initial_value, dtype=dtype)
    variable_device = initial_value.device
    self._save_as_bf16 = False
    # TODO(b/159035705): The following code enables variable creation inside
    # a tf.function. However, it requires a global dtensor device.
    # if not variable_device and not tf.executing_eagerly():
    #   try:
    #     initial_value.op.get_attr("_layout")
    #   except ValueError:
    #     pass
    #   else:
    #     # The initial value is a DTensor, but because the DTensor device is
    #     # only active during eager execution at the moment we need to
    #     # translate that into a placement for the eager VarHandleOp.
    #     variable_device = _dtensor_device().name
    with ops.device(variable_device):
      # If the initial tensor assigned to the DVariable is a DTensor, record
      # the layout of the resource so that it can be queried later.
      self.layout = None
      if context.executing_eagerly():
        try:
          self.layout = api.fetch_layout(initial_value)
        except (errors.InvalidArgumentError, errors.NotFoundError):
          # For non-DTensor tensors, fetching the layout raises an expected
          # InvalidArgumentError or NotFoundError, depending on whether the
          # API is called within a DTensor device scope or not.
          self.layout = None
      mesh = self.layout.mesh if self.layout else None
      with api.run_on(mesh) if mesh else contextlib.nullcontext():
        super(DVariable, self).__init__(
            initial_value, *args, dtype=dtype, **kwargs)

  @property
  def save_as_bf16(self):
    return self._save_as_bf16

  @save_as_bf16.setter
  def save_as_bf16(self, save_as_bf16):
    """Enables saving float32 as bfloat16."""
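    # Minimal usage sketch (assumes `v` is a float32 DVariable that is about
    # to be checkpointed):
    #   v.save_as_bf16 = True  # Later checkpoint saves store bfloat16 values.
    # The flag has no effect when the variable's dtype is not float32.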
    self._save_as_bf16 = save_as_bf16 and self.dtype == dtypes.float32

  def _gather_saveables_for_checkpoint(self):
    return {
        trackable.VARIABLE_VALUE_KEY:
            functools.partial(_DVariableSaveable, self)
    }