# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""APIs to deal with input datasets efficiently in DTensor.
16
17When using tf.data with DTensor, the `DTensorDataset` API can be used to
18efficiently handle loading the input data and correctly packing it to the
19corresponding devices. This API is intended to work with unbatched data and can
20be used for both data and model parallel setups.
21
22Example usage:
23
24>>> # 1-D mesh with 4 devices
25>>> mesh = dtensor.Mesh(dim_names=['batch'], ...)
26>>> layout = dtensor.Layout.batch_sharded(mesh, 'batch', rank=1)
27>>> dataset = tf.data.Dataset.range(256)
28>>> d_dataset = dtensor.DTensorDataset(
29...     dataset=dataset,
30...     global_batch_size=16,
31...     mesh=mesh,
32...     layouts=layout,
33...     batch_dim='batch')
34>>> d_iter = iter(d_dataset)
35>>> # Each batch is a length 16 tensor sharded across 4 devices
36>>> batch_0_dtensor = next(d_iter)
37>>> batch_0_dtensor
<tf.Tensor: shape=(16,),
            dtype=int64,
            value={"CPU:0": [ 0  1  2  3],
                   "CPU:1": [ 4  5  6  7],
                   "CPU:2": [ 8  9 10 11],
                   "CPU:3": [12 13 14 15]}>
>>> batch_1_dtensor = next(d_iter)
>>> batch_1_dtensor
<tf.Tensor: shape=(16,),
            dtype=int64,
            value={"CPU:0": [16 17 18 19],
                   "CPU:1": [20 21 22 23],
                   "CPU:2": [24 25 26 27],
                   "CPU:3": [28 29 30 31]}>

For multi-client setups, `DTensorDataset` interacts with tf.data service to
correctly distribute the dataset among the participating clients. DTensor works
with tf.data service in co-located mode where each worker is running alongside
the DTensor client (the TensorFlow Python process). The `TFDataServiceConfig`
dataclass can be filled with information about the tf.data service cluster, and
passed to `DTensorDataset` to enable distribution.
"""

import dataclasses

from typing import Any, List, Optional, Sequence, Tuple

from tensorflow.dtensor.python import api
from tensorflow.dtensor.python import layout as layout_lib
from tensorflow.python.data.experimental.ops import data_service_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@dataclasses.dataclass
class TFDataServiceConfig:
  """Specifies the tf.data service configuration to use.

  Attributes:
    dispatcher_address: a string specifying the address of the tf.data service
      dispatcher server.
    job_name: a non-empty string identifying the shared job that will be created
      on tf.data service to process this dataset.
  """
  dispatcher_address: str
  job_name: str
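  # Example (illustrative sketch; the dispatcher address and job name below
  # are placeholder values, not defaults):
  #
  #   config = TFDataServiceConfig(
  #       dispatcher_address='localhost:5050',
  #       job_name='dtensor_dataset_job')
  #   d_dataset = DTensorDataset(..., tf_data_service_config=config)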


# TODO(b/223275517): Add support for get_next_as_optional().
class _DTensorIterator(iterator_ops.IteratorBase):
  """An iterator for a tf.data.Dataset distributed using DTensor.

  DTensorIterator encapsulates multiple underlying dataset iterators. It handles
  retrieving the tensors to be placed on each underlying device and then uses
  the 'pack' operation to create and return a DTensor. Thus users need only
  interact with a single DTensorIterator to automatically distribute dataset
  tensors onto devices.
  """

  def __init__(self, datasets: Sequence[Tuple[int, dataset_ops.DatasetV2]],
               element_spec: tensor_spec.TensorSpec, layouts: Any,
               num_local_devices_per_replica: int):
    """Initializes a distributed iterator for DTensor datasets.

    The DTensorIterator uses 'replica IDs' to identify shards of a dataset. Here
    the term 'replica' is used in the data-parallel context where each replica
    receives a partition of the global batch. Depending on the model parallelism
    in the layouts supplied, each device within that replica may receive the
    same partition of the global batch (no model parallelism), or specific
    slices of that partition.

    Args:
      datasets: a sequence of (replica ID, dataset) pairs, where each dataset's
        elements will be placed on the devices corresponding to that replica.
      element_spec: the underlying dataset's element spec.
      layouts: a structure of DTensor layouts to be applied to the dataset
        values. This can be a single layout or (possibly nested) tuples or
        dictionaries of layouts, and the structure must match the structure of
        the dataset.
      num_local_devices_per_replica: the number of local devices for each
        replica.
    """
    self._iterators = [
        (replica_id, iter(dataset)) for replica_id, dataset in datasets
    ]
    self._element_spec = element_spec
    self._layouts = layouts
    self._num_local_devices_per_replica = num_local_devices_per_replica
    self._flattened_layouts = nest.flatten(self._layouts)

  def __next__(self):
    try:
      return self.get_next()
    except errors.OutOfRangeError as e:
      raise StopIteration from e

  def __iter__(self):
    return self

  @property
  def element_spec(self):
    """The type specification of an element of this iterator.

    A possibly nested structure of `tf.TypeSpec` objects matching the structure
    of an element of this iterator.
    """
    return self._element_spec

  def get_next(self):
    """Returns the next element.

    Returns:
      A possibly nested structure of values matching
      `tf.data.Iterator.element_spec`.

    Raises:
      `tf.errors.OutOfRangeError`: if the end of the underlying iterators has
        been reached.
      RuntimeError: if any of the underlying iterators do not return the
        expected number of items.
    """
    # Create the data structure to store the individual elements of the current
    # batch. We store a list per element in the flattened dataset batch, and
    # each list should contain as many tensors as there are local devices.
    curr_batch_elems = [[] for _ in range(len(self._flattened_layouts))]
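    # For example (illustrative numbers only): with 2 local replicas and 2
    # local devices per replica, each inner list below accumulates 4 tensors,
    # one per local device, before being packed into a single DTensor.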

    for _, iterator in self._iterators:
      for _ in range(self._num_local_devices_per_replica):
        element = iterator.get_next()

        # Separate the dataset elements based on the structure of the dataset.
        flattened_element = nest.flatten(element)
        for idx, batch in enumerate(flattened_element):
          curr_batch_elems[idx].append(batch)

    flattened_output = []
    for batch_elems, layout in zip(curr_batch_elems, self._flattened_layouts):
      expected_num_elems = layout.mesh.num_local_devices()
      actual_num_elems = len(batch_elems)
      if actual_num_elems != expected_num_elems:
        raise RuntimeError('Expected to pack %d elements in batch but got %d' %
                           (expected_num_elems, actual_num_elems))
      flattened_output.append(api.pack(batch_elems, layout))
    return nest.pack_sequence_as(self._layouts, flattened_output)

  def get_next_as_optional(self):
    """Returns the next element wrapped in `tf.experimental.Optional`.

    If the iterator has reached the end of the sequence, the returned
    `tf.experimental.Optional` will have no value.

    Returns:
      A `tf.experimental.Optional` object representing the next element.
    """
    raise NotImplementedError(
        'get_next_as_optional not yet supported: b/223275517')

  @property
  def _type_spec(self):
    return iterator_ops.IteratorSpec(self._element_spec)


def _validate_input(flattened_layouts: Sequence[layout_lib.Layout],
                    flattened_elem_spec: Sequence[tensor_spec.TensorSpec],
                    dataset_already_batched: bool):
  """Checks that the dataset's layouts and element specs are compatible.

  Args:
    flattened_layouts: the flattened list of layouts used to distribute the
      dataset.
    flattened_elem_spec: the flattened list of element specs used in the
      dataset's components.
    dataset_already_batched: whether the dataset to be validated is already
      batched.

  Raises:
    ValueError: if the dataset's inputs are incompatible.
  """
  if not flattened_elem_spec:
    raise ValueError(
        'Expected input element spec of at least one element, was empty.')

  first_elem_shape = flattened_elem_spec[0].shape

  for layout, elem_spec in zip(flattened_layouts, flattened_elem_spec):
    if elem_spec.shape.rank is None:
      raise ValueError(
          'Dataset element shape must have a valid rank, got spec %s.' %
          elem_spec)

    # Check that layout's rank matches the element's rank. If dataset is not yet
    # batched, then the layout's rank must be one greater than the element's
    # rank.
    expected_rank = elem_spec.shape.rank
    if not dataset_already_batched:
      expected_rank += 1
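    # For example, an unbatched dataset of rank-1 elements requires rank-2
    # layouts: one batch dimension plus one feature dimension.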
    if layout.rank != expected_rank:
      raise ValueError(
          ('Expected layout with rank %d for element spec %s, got layout %s. '
           'Check that the dataset is not batched before passing to '
           'DTensorDataset.') %
          (expected_rank, elem_spec, layout.sharding_specs))

    if dataset_already_batched:
      # Check that the batch dimension size of all dataset elements match.
      batch_dim_size = first_elem_shape.as_list()[0]
      if batch_dim_size is None:
        raise ValueError(
            ('Size of batch dimension of element spec %s is None. Ensure '
             'drop_remainder=True when batching the dataset.') % elem_spec)

      if elem_spec.shape.as_list()[0] != batch_dim_size:
        raise ValueError(
            ('Size of batch dimension of element spec %s does not match '
             'expected size %d.') % (elem_spec, batch_dim_size))


def _shard_counts(layout: layout_lib.Layout,
                  batch_dim: Optional[str] = None) -> List[int]:
  """Computes a list of the number of shards in each dimension of the layout.

  The shard counts are used to slice each dataset element. The batch dimension's
  count is overridden to 1 since we only consider how many shards to make
  locally (within each local replica). Sharding across clients is handled by
  either tf.data.Dataset's shard transformation (in the single-client case) or
  tf.data service's distribute function (in the multi-client case).

  Args:
    layout: the layout to compute the shard counts for.
    batch_dim: the name of the batch dimension of the layout, if present.

  Returns:
    A list of shard counts, one element per dimension of the layout.
  """
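  # Illustrative example (hypothetical mesh): for a mesh with dimensions
  # 'batch'=2 and 'model'=4, a layout with sharding_specs=['batch', 'model']
  # and batch_dim='batch' yields [1, 4]: the batch dimension is already split
  # by the dataset shard/distribute step, while the 'model'-sharded dimension
  # is split into 4 local shards.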
  shard_counts = []
  for spec in layout.sharding_specs:
    if spec in (batch_dim, layout_lib.UNSHARDED):
      shard_counts.append(1)
    else:
      shard_counts.append(layout.mesh.dim_size(spec))
  return shard_counts


def _index_matrix(layout: layout_lib.Layout,
                  elem_spec: tensor_spec.TensorSpec) -> ops.Tensor:
  """Computes a utility matrix to derive device-based slice offsets.

  This function builds a matrix of shape `[mesh.rank, layout.rank]` for each
  dataset element. This matrix can be used to slice the DTensor components
  returned by the iterator according to the local device that component is to be
  placed on. This can be done by multiplying the device offsets of shape
  `[1, mesh.rank]` with this index matrix to get a `[1, layout.rank]` shape
  tensor containing the slice offsets.

  Note: the index on the batch dim is always 0 since sharding on the batch
  dimension is handled by either tf.data.Dataset's shard transformation (in the
  single-client case) or tf.data service's distribute function (in the
  multi-client case). If there is no sharding on the batch dimension (or any
  other dimension), the slice index remains 0.

  Args:
    layout: the layout of the dataset element.
    elem_spec: the spec of the dataset element.

  Returns:
    The index matrix as a tensor.
  """
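  # Worked example (hypothetical shapes): for a mesh with dimensions 'batch'=2
  # and 'model'=4, a layout with sharding_specs=['batch', 'model'], and an
  # unbatched element spec of shape [8], the matrix is [[0, 0], [0, 2]]. A
  # device at mesh coordinates [0, 3] then gets slice offsets
  # [0, 3] @ [[0, 0], [0, 2]] = [0, 6] into the batched element.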
  matrix = []
  for dim in layout.mesh.dim_names:
    row = [0]
    for layout_idx, spec in enumerate(layout.sharding_specs[1:]):
      if spec == layout_lib.UNSHARDED or spec != dim:
        row.append(0)
      else:
        row.append(elem_spec.shape[layout_idx] // layout.mesh.dim_size(dim))
    matrix.append(row)

  return constant_op.constant(matrix, dtype=dtypes.int32)


@tf_export('experimental.dtensor.DTensorDataset', v1=[])
class DTensorDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A dataset of DTensors.

  DTensorDataset encapsulates a `tf.data.Dataset` whose elements are
  automatically packed and returned as DTensors based on a given mesh and
  layouts.
  """

  def __init__(self,
               dataset: dataset_ops.DatasetV2,
               *,
               mesh: layout_lib.Mesh,
               layouts: Any,
               global_batch_size: int,
               dataset_already_batched: bool = False,
               batch_dim: Optional[str] = None,
               prefetch: Optional[int] = None,
               tf_data_service_config: Optional[TFDataServiceConfig] = None):
    """Creates a DTensorDataset.

    DTensorDataset automatically handles distribution of the dataset elements to
    each client's devices. It can be used to create an iterator that returns
    DTensors of the input data on each iteration.

    DTensorDataset works best with unbatched datasets. It takes the mesh and the
    provided layouts to automatically calculate how to batch the input locally
    for each replica.

    If the provided dataset is already batched according to the per-replica
    batch size, then `dataset_already_batched` must be set and DTensorDataset
    will check that the batch size is consistent with the intended
    `global_batch_size` using the layout information. Each replica receives a
    separate slice of the global batch, thus the per-replica batch size can be
    computed as the global batch size divided by the number of model replicas.
    For a DTensor mesh, the number of replicas is equal to the size of the
    mesh's batch dimension.
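    For example, a global batch size of 16 on a mesh whose batch dimension has
    size 4 implies a per-replica batch size of 4.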

    TODO(b/223275517): add support for input datasets that are already batched
    to the global batch size.

    Args:
      dataset: a `tf.data.Dataset` object.
      mesh: the DTensor mesh to place the dataset batches on.
      layouts: a structure of DTensor layouts to be applied to the input dataset
        values. This can be a single layout or (possibly nested) tuples or
        dictionaries of layouts, and the structure must match the structure of
        the dataset. Either all or none of the layouts should be sharded on the
        batch dimension; having only a subset of layouts batch sharded will not
        work and raises a ValueError.
      global_batch_size: the desired global batch size.
      dataset_already_batched: must be set only if the dataset is already
        batched to the per-replica batch size. The batched dataset must have
        `drop_remainder=True` set since DTensor requires static shapes for
        slicing the input tensors.
      batch_dim: the mesh dimension on which the input's batch dimension is
        sharded. Set to None if the input layouts do not shard on the batch
        dimension.
      prefetch: number of batches to prefetch using Dataset.prefetch.
      tf_data_service_config: if operating in multi-client mode, this config
        specifies the tf.data service configuration to use.

    Raises:
      ValueError: in any of the following situations,
        1. if the structures and ranks of layouts and the dataset do not match.
        2. if the shapes in the dataset's spec are not fully defined.
        3. if batch_dim is specified and all layouts are not batch-sharded.
        4. if the dataset is already batched but its per-replica batch size
           does not match the expected size derived from the global batch size
           and the provided mesh.
      TypeError: if the types of the layouts structure and the dataset
        structure do not match.
    """
    super().__init__(dataset, dataset_ops.to_variant(dataset))

    self._mesh = mesh
    self._layouts = layouts
    self._batch_dim = batch_dim
    self._prefetch = prefetch
    self._tf_data_service_config = tf_data_service_config

    self._element_spec = dataset.element_spec

    nest.assert_same_structure(self._element_spec, self._layouts)
    flattened_layouts = nest.flatten(self._layouts)
    flattened_elem_spec = nest.flatten(self._element_spec)

    if batch_dim:
      num_global_replicas = mesh.dim_size(batch_dim)
      self._local_replica_ids = list(
          dict.fromkeys(
              [loc[batch_dim] for loc in mesh.local_device_locations()]))

      for layout in flattened_layouts:
        if batch_dim != layout.sharding_specs[0]:
          raise ValueError(
              ('batch_dim %s was specified but at least one layout did not '
               'contain it: %s') % (batch_dim, layout))
    else:
      # Only one replica since there is no sharding on the batch dimension.
      num_global_replicas = 1
      self._local_replica_ids = [0]

    # Validate layout and element spec compatibility, and raise ValueError if
    # invalid.
    _validate_input(
        flattened_layouts,
        flattened_elem_spec,
        dataset_already_batched=dataset_already_batched)

    expected_batch_size = global_batch_size // num_global_replicas
    if not dataset_already_batched:
      self._batched_dataset = dataset.batch(
          expected_batch_size, drop_remainder=True)
    else:
      per_replica_batch_size = flattened_elem_spec[0].shape.as_list()[0]
      if per_replica_batch_size != expected_batch_size:
        raise ValueError(
            ('per_replica_batch_size does not match the expected size based '
             'on the mesh, got %d but expected %d.') %
            (per_replica_batch_size, expected_batch_size))
      self._batched_dataset = dataset

    num_global_devices_per_replica = api.num_global_devices(
        mesh.device_type()) // num_global_replicas
    self._num_local_replicas = len(self._local_replica_ids)
    self._num_local_devices_per_replica = mesh.num_local_devices(
    ) // self._num_local_replicas
    # The number of clients each replica is split over.
    self._num_clients_per_replica = (
        num_global_devices_per_replica //
        self._num_local_devices_per_replica)
    # In the case where a replica is split across multiple clients, an offset
    # needs to be added to the index used by the partitioning logic such that
    # the local devices on that client can be correctly matched to slices of the
    # input tensor(s). If replicas are wholly contained within a client, then
    # this offset is always 0.
    self._partition_offset = (
        api.client_id() %
        self._num_clients_per_replica) * self._num_local_devices_per_replica
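    # For example (hypothetical topology): with 8 global devices split evenly
    # across 2 clients and a replica of 4 devices spanning both clients, each
    # client holds 2 local devices of that replica, so client 0 gets offset 0
    # and client 1 gets offset 2.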

    # Helper data structures used in partitioning the dataset tensors.
    self._all_shard_counts = [
        _shard_counts(layout, batch_dim) for layout in flattened_layouts
    ]
    self._index_matrices = [
        _index_matrix(layout, elem_spec) for layout, elem_spec in zip(
            flattened_layouts, flattened_elem_spec)
    ]

  def __iter__(self):
    datasets: List[Tuple[int, dataset_ops.DatasetV2]] = []

    # Start with the batched dataset.
    local_dataset = self._batched_dataset

    # If a replica is split over multiple clients then each batch needs to be
    # repeated before distribution as many times as there are clients
    # corresponding to that replica.
    if self._batch_dim is not None:
      local_dataset = self._repeat_batch(local_dataset,
                                         self._num_clients_per_replica)

    # Apply distribution here (if specified) so all remaining transformations
    # are executed locally.
    if self._tf_data_service_config is not None:
      if self._batch_dim is None:
        sharding_policy = data_service_ops.ShardingPolicy.OFF
      else:
        sharding_policy = data_service_ops.ShardingPolicy.FILE_OR_DATA

      local_dataset = local_dataset.apply(
          data_service_ops.distribute(
              processing_mode=sharding_policy,
              service=self._tf_data_service_config.dispatcher_address,
              job_name=f'{self._tf_data_service_config.job_name}_{api.client_id()}',
              target_workers='LOCAL'))

    for local_replica_idx, replica_id in enumerate(self._local_replica_ids):
      # Select the shard for the corresponding replica.
      dataset = local_dataset.shard(self._num_local_replicas, local_replica_idx)

      # Repeat each batch for each local device in the replica.
      dataset = self._repeat_batch(dataset, self._num_local_devices_per_replica)

      # Slice each shard further for all non-batch dim shards. If there is no
      # non-batch dim sharding, this slice is essentially a no-op.
      dataset = self._partition(dataset)

      # Apply prefetch as the last step. Since each batch is repeated, the
      # number of elements to prefetch has to be scaled by the same factor.
      if self._prefetch is not None:
        dataset = dataset.prefetch(
            self._prefetch * self._num_local_devices_per_replica)

      datasets.append((replica_id, dataset))

    return _DTensorIterator(datasets, self._element_spec, self._layouts,
                            self._num_local_devices_per_replica)

  def _repeat_batch(self, dataset, repeats):
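    """Repeats each element of `dataset` `repeats` times, back to back.

    For example (illustrative), a dataset whose elements are [a, b] becomes,
    with repeats=2, a dataset yielding [a, a, b, b]. This is used to hand the
    same per-replica batch to each local device (or client) in a replica
    before it is partitioned.
    """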
    def repeat(*x):
      return dataset_ops.DatasetV2.from_tensors(x).repeat(repeats)

    return dataset.flat_map(repeat)

  def _partition(self, dataset):
    """Slices each dataset element on any sharded non-batch dimension."""

    # TODO(b/223275517): decouple from self and make testable.
    def slice_batch(index, batch):
      flattened_batch = nest.flatten(batch)
      flattened_output = []

      norm_index = math_ops.cast(
          index % self._num_local_devices_per_replica, dtype=dtypes.int32)
      norm_index += self._partition_offset
      coords = self._mesh.coords(norm_index)
      coords = array_ops.reshape(coords, (1, -1))
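      # Continuing the illustrative example in _index_matrix: a device at mesh
      # coordinates [0, 3] with index matrix [[0, 0], [0, 2]] gets slice
      # offsets [0, 6], and the slice size is the element shape divided by the
      # per-dimension shard counts.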

      for element, shard_counts, idx_matrix in zip(flattened_batch,
                                                   self._all_shard_counts,
                                                   self._index_matrices):
        indexes = math_ops.matmul(coords, idx_matrix)
        start = array_ops.reshape(indexes, (-1,))
        size = array_ops.shape_v2(
            element, out_type=dtypes.int32) // shard_counts
        flattened_output.append(
            array_ops.slice(element, begin=start, size=size))

      return nest.pack_sequence_as(batch, flattened_output)

    enumerated_dataset = dataset.enumerate()
    partitioned_dataset = enumerated_dataset.map(slice_batch)
    return partitioned_dataset