1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""APIs to deal with input datasets efficiently in DTensor. 16 17When using tf.data with DTensor, the `DTensorDataset` API can be used to 18efficiently handle loading the input data and correctly packing it to the 19corresponding devices. This API is intended to work with unbatched data and can 20be used for both data and model parallel setups. 21 22Example usage: 23 24>>> # 1-D mesh with 4 devices 25>>> mesh = dtensor.Mesh(dim_names=['batch'], ...) 26>>> layout = dtensor.Layout.batch_sharded(mesh, 'batch', rank=1) 27>>> dataset = tf.data.Dataset.range(256) 28>>> d_dataset = dtensor.DTensorDataset( 29... dataset=dataset, 30... global_batch_size=16, 31... mesh=mesh, 32... layouts=layout, 33... batch_dim='batch') 34>>> d_iter = iter(d_dataset) 35>>> # Each batch is a length 16 tensor sharded across 4 devices 36>>> batch_0_dtensor = next(d_iter) 37>>> batch_0_dtensor 38<tf.Tensor: shape=(16,), 39 dtype=int64, 40 value={"CPU:0": [ 0 1 2 4], 41 "CPU:1": [ 5 6 7 8], 42 "CPU:2": [ 9 10 11 12], 43 "CPU:3": [13 14 15 16]}> 44>>> batch_1_dtensor = next(d_iter) 45>>> batch_1_dtensor 46<tf.Tensor: shape=(16,), 47 dtype=int64, 48 value={"CPU:0": [17 18 19 20], 49 "CPU:1": [21 22 23 24], 50 "CPU:2": [25 26 27 28], 51 "CPU:3": [29 30 31 32]}> 52 53For multi-client setups, `DTensorDataset` interacts with tf.data service to 54correctly distribute the dataset among the participating clients. DTensor works 55with tf.data service in co-located mode where each worker is running alongside 56the DTensor client (the Tensorflow Python process). The `TFDataServiceConfig` 57dataclass can be filled with information about the tf.data service cluster, and 58passed to `DTensorDataset` to enable distribution. 59""" 60 61import dataclasses 62 63from typing import Any, List, Optional, Sequence, Tuple 64 65from tensorflow.dtensor.python import api 66from tensorflow.dtensor.python import layout as layout_lib 67from tensorflow.python.data.experimental.ops import data_service_ops 68from tensorflow.python.data.ops import dataset_ops 69from tensorflow.python.data.ops import iterator_ops 70from tensorflow.python.framework import constant_op 71from tensorflow.python.framework import dtypes 72from tensorflow.python.framework import errors 73from tensorflow.python.framework import ops 74from tensorflow.python.framework import tensor_spec 75from tensorflow.python.ops import array_ops 76from tensorflow.python.ops import math_ops 77from tensorflow.python.util import nest 78from tensorflow.python.util.tf_export import tf_export 79 80 81@dataclasses.dataclass 82class TFDataServiceConfig: 83 """Specifies the tf.data service configuration to use. 84 85 Attributes: 86 dispatcher_address: a string specifying the address of the tf.data service 87 dispatcher server. 88 job_name: a non-empty string identifying the shared job that will be created 89 on tf.data service to process this dataset. 90 """ 91 dispatcher_address: str 92 job_name: str 93 94 95# TODO(b/223275517): Add support for get_next_as_optional(). 96class _DTensorIterator(iterator_ops.IteratorBase): 97 """An iterator for a tf.data.Dataset distributed using DTensor. 98 99 DTensorIterator encapsulates multiple underlying dataset iterators. It handles 100 retrieving the tensors to be placed on each underlying device and then uses 101 the 'pack' operation to create and return a DTensor. Thus users need only 102 interact with a single DTensorIterator to automatically distribute dataset 103 tensors onto devices. 104 """ 105 106 def __init__(self, datasets: Sequence[Tuple[int, dataset_ops.DatasetV2]], 107 element_spec: tensor_spec.TensorSpec, layouts: Any, 108 num_local_devices_per_replica: int): 109 """Initializes a distributed iterator for DTensor datasets. 110 111 The DTensorIterator uses 'replica IDs' to identify shards of a dataset. Here 112 the term 'replica' is used in the data-parallel context where each replica 113 receives a partition of the global batch. Depending on the model parallelism 114 in the layouts supplied, each device within that replica may receive the 115 same partition of the global batch (no model parallelism), or specific 116 slices of that partition. 117 118 Args: 119 datasets: a dictionary mapping each unique local replica ID to the dataset 120 object whose elements will be placed on the devices corresponding to 121 that replica. 122 element_spec: the underlying dataset's element spec. 123 layouts: a structure of DTensor layouts to be applied to the dataset 124 values. This can be a single layout or (possibly nested) tuples or 125 dictionaries of layouts, and the structure must match the structure of 126 the dataset. 127 num_local_devices_per_replica: the number of local devices for each 128 replica. 129 """ 130 self._iterators = [ 131 (replica_id, iter(dataset)) for replica_id, dataset in datasets 132 ] 133 self._element_spec = element_spec 134 self._layouts = layouts 135 self._num_local_devices_per_replica = num_local_devices_per_replica 136 self._flattened_layouts = nest.flatten(self._layouts) 137 138 def __next__(self): 139 try: 140 return self.get_next() 141 except errors.OutOfRangeError as e: 142 raise StopIteration from e 143 144 def __iter__(self): 145 return self 146 147 @property 148 def element_spec(self): 149 """The type specification of an element of this iterator. 150 151 A possibly nested structure of `tf.TypeSpec` objects matching the structure 152 of an element of this iterator. 153 """ 154 return self._element_spec 155 156 def get_next(self): 157 """Returns the next element. 158 159 Returns: 160 A possibly nested structure of values matching 161 `tf.data.Iterator.element_spec`. 162 163 Raises: 164 `tf.errors.OutOfRangeError`: if the end of the underlying iterators has 165 been reached. 166 RuntimeError: if any of the underlying iterators do not return the 167 expected number of items. 168 """ 169 # Create the data structure to store the individual elements of the current 170 # batch. We store a list per element in the flattened dataset batch, and 171 # each list should contain as many tensors as there local devices. 172 curr_batch_elems = [[] for _ in range(len(self._flattened_layouts))] 173 174 for _, iterator in self._iterators: 175 for _ in range(self._num_local_devices_per_replica): 176 element = iterator.get_next() 177 178 # Separate the dataset elements based on the structure of the dataset. 179 flattened_element = nest.flatten(element) 180 for idx, batch in enumerate(flattened_element): 181 curr_batch_elems[idx].append(batch) 182 183 flattened_output = [] 184 for batch_elems, layout in zip(curr_batch_elems, self._flattened_layouts): 185 expected_num_elems = layout.mesh.num_local_devices() 186 actual_num_elems = len(batch_elems) 187 if actual_num_elems != expected_num_elems: 188 raise RuntimeError('Expected to pack %d elements in batch but got %d' % 189 (expected_num_elems, actual_num_elems)) 190 flattened_output.append(api.pack(batch_elems, layout)) 191 return nest.pack_sequence_as(self._layouts, flattened_output) 192 193 def get_next_as_optional(self): 194 """Returns the next element wrapped in `tf.experimental.Optional`. 195 196 If the iterator has reached the end of the sequence, the returned 197 `tf.experimental.Optional` will have no value. 198 199 Returns: 200 A `tf.experimental.Optional` object representing the next element. 201 """ 202 raise NotImplementedError( 203 'get_next_as_optional not yet supported: b/223275517') 204 205 @property 206 def _type_spec(self): 207 return iterator_ops.IteratorSpec(self._element_spec) 208 209 210def _validate_input(flattened_layouts: Sequence[layout_lib.Layout], 211 flattened_elem_spec: Sequence[tensor_spec.TensorSpec], 212 dataset_already_batched: bool): 213 """Checks that the dataset's layouts and element specs are compatible. 214 215 Args: 216 flattened_layouts: the flattened list of layouts used to distribute the 217 dataset. 218 flattened_elem_spec: the flattened list of element specs used in the 219 dataset's components. 220 dataset_already_batched: whether the dataset to be validated is already 221 batched. 222 223 Raises: 224 ValueError: if the dataset's inputs are incompatible. 225 """ 226 if not flattened_elem_spec: 227 raise ValueError( 228 'Expected input element spec of at least one element, was empty.') 229 230 first_elem_shape = flattened_elem_spec[0].shape 231 232 for layout, elem_spec in zip(flattened_layouts, flattened_elem_spec): 233 if elem_spec.shape.rank is None: 234 raise ValueError( 235 'Dataset element shape must have a valid rank, got spec %s.' % 236 elem_spec) 237 238 # Check that layout's rank matches the element's rank. If dataset is not yet 239 # batched, then the layout's rank must be one greater than the element's 240 # rank. 241 expected_rank = elem_spec.shape.rank 242 if not dataset_already_batched: 243 expected_rank += 1 244 if layout.rank != expected_rank: 245 raise ValueError( 246 ('Expected layout with rank %d for element spec %s, got layout %s. ' 247 'Check that the dataset is not batched before passing to ' 248 'DTensorDataset.') % 249 (expected_rank, elem_spec, layout.sharding_specs)) 250 251 if dataset_already_batched: 252 # Check that the batch dimension size of all dataset elements match. 253 batch_dim_size = first_elem_shape.as_list()[0] 254 if batch_dim_size is None: 255 raise ValueError( 256 ('Size of batch dimension of element spec %s is None. Ensure ' 257 'drop_remainder=True when batching the dataset.') % elem_spec) 258 259 if elem_spec.shape.as_list()[0] != batch_dim_size: 260 raise ValueError( 261 ('Size of batch dimension of element spec %s does not match ' 262 'expected size %d.') % (elem_spec, batch_dim_size)) 263 264 265def _shard_counts(layout: layout_lib.Layout, 266 batch_dim: Optional[str] = None) -> List[int]: 267 """Computes a list of the number of shards in each dimension of the layout. 268 269 The shard counts are used to slice each dataset element. The batch dimension's 270 count is overridden to 1 since we only consider how many shards to make 271 locally (within each local replica). Sharding across clients is handled by 272 either tf.data.Dataset's shard transformation (in the single-client case) or 273 tf.data service's distribute function (in the multi-client case). 274 275 Args: 276 layout: the layout to compute the shard counts for. 277 batch_dim: the name of the batch dimension of the layout, if present. 278 279 Returns: 280 A list of shard counts, one element per dimension of the layout. 281 """ 282 shard_counts = [] 283 for spec in layout.sharding_specs: 284 if spec in (batch_dim, layout_lib.UNSHARDED): 285 shard_counts.append(1) 286 else: 287 shard_counts.append(layout.mesh.dim_size(spec)) 288 return shard_counts 289 290 291def _index_matrix(layout: layout_lib.Layout, 292 elem_spec: tensor_spec.TensorSpec) -> ops.Tensor: 293 """Computes a utility matrix to derive device-based slice offsets. 294 295 This function builds a matrix of shape `[mesh.rank, layout.rank]` for each 296 dataset element. This matrix can be used to slice the DTensor components 297 returned by the iterator according to the local device that component is to be 298 placed on. This can be done by multiplying the device offsets of shape 299 `[1, mesh.rank]` with this index matrix to get a `[1, layout.rank]` shape 300 tensor containing the slice offsets. 301 302 Note: the index on the batch dim is always 0 since sharding on the batch 303 dimension is handled by either tf.data.Dataset's shard transformation (in the 304 single-client case) or tf.data service's distribute function (in the 305 multi-client case). If there is no sharding on the batch dimension (or any 306 other dimension), the slice index remains 0. 307 308 Args: 309 layout: the layout of the dataset element. 310 elem_spec: the spec of the dataset element. 311 312 Returns: 313 The index matrix as a tensor. 314 """ 315 matrix = [] 316 for dim in layout.mesh.dim_names: 317 row = [0] 318 for layout_idx, spec in enumerate(layout.sharding_specs[1:]): 319 if spec == layout_lib.UNSHARDED or spec != dim: 320 row.append(0) 321 else: 322 row.append(elem_spec.shape[layout_idx] // layout.mesh.dim_size(dim)) 323 matrix.append(row) 324 325 return constant_op.constant(matrix, dtype=dtypes.int32) 326 327 328@tf_export('experimental.dtensor.DTensorDataset', v1=[]) 329class DTensorDataset(dataset_ops.UnaryUnchangedStructureDataset): 330 """A dataset of DTensors. 331 332 DTensorDataset encapsulates a `tf.data.Dataset` whose elements are 333 automatically packed and returned as DTensors based on a given mesh and 334 layouts. 335 """ 336 337 def __init__(self, 338 dataset: dataset_ops.DatasetV2, 339 *, 340 mesh: layout_lib.Mesh, 341 layouts: Any, 342 global_batch_size: int, 343 dataset_already_batched: bool = False, 344 batch_dim: Optional[str] = None, 345 prefetch: Optional[int] = None, 346 tf_data_service_config: Optional[TFDataServiceConfig] = None): 347 """Creates a DTensorDataset. 348 349 DTensorDataset automatically handles distribution of the dataset elements to 350 each client's devices. It can be used to create an iterator that returns 351 DTensors of the input data on each iteration. 352 353 DTensorDataset works best with unbatched datasets. It takes the mesh and the 354 provided layouts to automatically calculate how to batch the input locally 355 for each replica. 356 357 If the provided dataset is already batched according to the per-replica 358 batch size, then `dataset_already_batched` must be set and DTensorDataset 359 will check that the batch size is consistent with the intended 360 `global_batch_size` using the layout information. Each replica receives a 361 separate slice of the global batch, thus the per-replica batch size can be 362 computed as the global batch size divided by the number of model replicas. 363 For a DTensor mesh, the number of replicas is equal to the size of the 364 mesh's batch dimension. 365 366 TODO(b/223275517): add support for input datasets that are already batched 367 to the global batch size. 368 369 Args: 370 dataset: a `tf.data.Dataset` object. 371 mesh: the DTensor mesh to place the dataset batches on. 372 layouts: a structure of DTensor layouts to be applied to the input dataset 373 values. This can be a single layout or (possibly nested) tuples or 374 dictionaries of layouts, and the structure must match the structure of 375 the dataset. Either all or none of the layouts should be sharded on the 376 batch dimension; having only a subset of layouts batch sharded will not 377 work and raises a ValueError. 378 global_batch_size: the desired global batch size. 379 dataset_already_batched: must be set only if the dataset is already 380 batched to the per-replica batch size. The batched dataset must have 381 `drop_remainder=True` set since DTensor requires static shapes for 382 slicing the input tensors. 383 batch_dim: the mesh dimension on which the input's batch dimension is 384 sharded. Set to None if the input layouts do not shard on the batch 385 dimension. 386 prefetch: number of batches to prefetch using Dataset.prefetch. 387 tf_data_service_config: if operating in multi-client mode, this config 388 specifies the tf.data service configuration to use. 389 390 Raises: 391 ValueError: on any of the following situations, 392 1. if the structures and ranks of layouts and the dataset do not match. 393 2. if the shapes in the dataset's spec are not fully defined. 394 3. if batch_dim is specified and all layouts are not batch-sharded. 395 4. if per_replica_batch_size is specified for an already batched Dataset 396 but it does not match the expected per-replica size based on the 397 provided mesh. 398 TypeError: if type of structures of layouts and the dataset do not match. 399 """ 400 super().__init__(dataset, dataset_ops.to_variant(dataset)) 401 402 self._mesh = mesh 403 self._layouts = layouts 404 self._batch_dim = batch_dim 405 self._prefetch = prefetch 406 self._tf_data_service_config = tf_data_service_config 407 408 self._element_spec = dataset.element_spec 409 410 nest.assert_same_structure(self._element_spec, self._layouts) 411 flattened_layouts = nest.flatten(self._layouts) 412 flattened_elem_spec = nest.flatten(self._element_spec) 413 414 if batch_dim: 415 num_global_replicas = mesh.dim_size(batch_dim) 416 self._local_replica_ids = list( 417 dict.fromkeys( 418 [loc[batch_dim] for loc in mesh.local_device_locations()])) 419 420 for layout in flattened_layouts: 421 if batch_dim != layout.sharding_specs[0]: 422 raise ValueError( 423 ('batch_dim %s was specified but at least one layout did not ' 424 'contain it: %s') % (batch_dim, layout)) 425 else: 426 # Only one replica since there is no sharding on the batch dimension. 427 num_global_replicas = 1 428 self._local_replica_ids = [0] 429 430 # Validate layout and element spec compatibility, and raise ValueError if 431 # invalid. 432 _validate_input( 433 flattened_layouts, 434 flattened_elem_spec, 435 dataset_already_batched=dataset_already_batched) 436 437 expected_batch_size = global_batch_size // num_global_replicas 438 if not dataset_already_batched: 439 self._batched_dataset = dataset.batch( 440 expected_batch_size, drop_remainder=True) 441 else: 442 per_replica_batch_size = flattened_elem_spec[0].shape.as_list()[0] 443 if per_replica_batch_size != expected_batch_size: 444 raise ValueError( 445 ('per_replica_batch_size does not matched expected size based on ' 446 'the mesh, got %d but expected %d.') % 447 (per_replica_batch_size, expected_batch_size)) 448 self._batched_dataset = dataset 449 450 num_global_devices_per_replica = api.num_global_devices( 451 mesh.device_type()) // num_global_replicas 452 self._num_local_replicas = len(self._local_replica_ids) 453 self._num_local_devices_per_replica = mesh.num_local_devices( 454 ) // self._num_local_replicas 455 # The number of clients each replica is split over. 456 self._num_clients_per_replica = ( 457 num_global_devices_per_replica // 458 self._num_local_devices_per_replica) 459 # In the case where a replica is split across multiple clients, an offset 460 # needs to be added to the index used by the partitioning logic such that 461 # the local devices on that client can be correctly matched to slices of the 462 # input tensor(s). If replicas are wholly contained within a client, then 463 # this offset is always 0. 464 self._partition_offset = ( 465 api.client_id() % 466 self._num_clients_per_replica) * self._num_local_devices_per_replica 467 468 # Helper data structures used in partitioning the dataset tensors. 469 self._all_shard_counts = [ 470 _shard_counts(layout, batch_dim) for layout in flattened_layouts 471 ] 472 self._index_matrices = [ 473 _index_matrix(layout, elem_spec) for layout, elem_spec in zip( 474 flattened_layouts, flattened_elem_spec) 475 ] 476 477 def __iter__(self): 478 datasets: List[Tuple[int, dataset_ops.DatasetV2]] = [] 479 480 # Start with the batched the dataset. 481 local_dataset = self._batched_dataset 482 483 # If a replica is split over multiple clients then each batch needs to be 484 # repeated before distribution as many times as there are clients 485 # corresponding to that replica. 486 if self._batch_dim is not None: 487 local_dataset = self._repeat_batch(local_dataset, 488 self._num_clients_per_replica) 489 490 # Apply distribution here (if specified) so all remaining transformations 491 # are executed locally. 492 if self._tf_data_service_config is not None: 493 if self._batch_dim is None: 494 sharding_policy = data_service_ops.ShardingPolicy.OFF 495 else: 496 sharding_policy = data_service_ops.ShardingPolicy.FILE_OR_DATA 497 498 local_dataset = local_dataset.apply( 499 data_service_ops.distribute( 500 processing_mode=sharding_policy, 501 service=self._tf_data_service_config.dispatcher_address, 502 job_name=f'{self._tf_data_service_config.job_name}_{api.client_id()}', 503 target_workers='LOCAL')) 504 505 for local_replica_idx, replica_id in enumerate(self._local_replica_ids): 506 # Select the shard for the corresponding replica. 507 dataset = local_dataset.shard(self._num_local_replicas, local_replica_idx) 508 509 # Repeat each batch for each local device in the replica. 510 dataset = self._repeat_batch(dataset, self._num_local_devices_per_replica) 511 512 # Slice each shard further for all non-batch dim shards. If there is no 513 # non-batch dim sharding, this slice is essentially a no-op. 514 dataset = self._partition(dataset) 515 516 # Apply prefetch as the last step. Since each batch is repeated, the 517 # number of elements to prefetch has to be scaled by the same size. 518 if self._prefetch is not None: 519 dataset = dataset.prefetch( 520 self._prefetch * self._num_local_devices_per_replica) 521 522 datasets.append((replica_id, dataset)) 523 524 return _DTensorIterator(datasets, self._element_spec, self._layouts, 525 self._num_local_devices_per_replica) 526 527 def _repeat_batch(self, dataset, repeats): 528 def repeat(*x): 529 return dataset_ops.DatasetV2.from_tensors(x).repeat(repeats) 530 531 return dataset.flat_map(repeat) 532 533 def _partition(self, dataset): 534 """Slices each dataset element on any sharded non-batch dimension.""" 535 536 # TODO(b/223275517): decouple from self and make testable. 537 def slice_batch(index, batch): 538 flattened_batch = nest.flatten(batch) 539 flattened_output = [] 540 541 norm_index = math_ops.cast( 542 index % self._num_local_devices_per_replica, dtype=dtypes.int32) 543 norm_index += self._partition_offset 544 coords = self._mesh.coords(norm_index) 545 coords = array_ops.reshape(coords, (1, -1)) 546 547 for element, shard_counts, idx_matrix in zip(flattened_batch, 548 self._all_shard_counts, 549 self._index_matrices): 550 indexes = math_ops.matmul(coords, idx_matrix) 551 start = array_ops.reshape(indexes, (-1,)) 552 size = array_ops.shape_v2( 553 element, out_type=dtypes.int32) // shard_counts 554 flattened_output.append( 555 array_ops.slice(element, begin=start, size=size)) 556 557 return nest.pack_sequence_as(batch, flattened_output) 558 559 enumerated_dataset = dataset.enumerate() 560 partitioned_dataset = enumerated_dataset.map(slice_batch) 561 return partitioned_dataset 562