1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""TPU-specific utilities for DTensor.""" 16 17import functools 18import time 19from typing import List, Optional, Dict 20 21from absl import flags 22import numpy as np 23 24from tensorflow.dtensor.python import api 25from tensorflow.dtensor.python import dtensor_device 26from tensorflow.dtensor.python import gen_dtensor_ops 27from tensorflow.dtensor.python import heartbeat 28from tensorflow.dtensor.python import layout as layout_lib 29from tensorflow.dtensor.python import multi_client_util 30from tensorflow.python.eager import context 31from tensorflow.python.eager import def_function 32from tensorflow.python.eager import function 33from tensorflow.python.framework import constant_op 34from tensorflow.python.framework import device as tf_device 35from tensorflow.python.framework import errors 36from tensorflow.python.framework import ops 37from tensorflow.python.framework import tfrt_utils 38from tensorflow.python.ops import array_ops 39from tensorflow.python.ops import math_ops 40from tensorflow.python.platform import tf_logging as logging 41from tensorflow.python.tpu import topology 42from tensorflow.python.util.tf_export import tf_export 43 44_INITIALIZED_TPU_SYSTEMS = {} 45_MESH_DIM_X = "x" 46_TPU_DEVICE_TYPE = "TPU" 47 48# A dedicated, hidden device used to make C++ API calls. 49_dtensor_device = None 50 51# `_topology._mesh_shape` contains the TPU hardware slice size. 52# `_topology.device_coordinates` maps TF task-device ordinals to TPU core IDs. 53_tpu_topology = None 54 55# Cache core ID <-> location mappings so we need not make repeated C++ calls. 56# Both are indexed by TF task-device ordinals. 57_all_core_ids = None 58_all_core_locations = None 59 60 61class _CoreLocation: 62 """Represents a TPU core's location in the mesh.""" 63 64 def __init__(self, x: int = 0, y: int = 0, z: int = 0, core: int = 0): 65 self.x = x 66 self.y = y 67 self.z = z 68 self.core = core 69 70 def __eq__(self, other): 71 if not isinstance(other, _CoreLocation): 72 return False 73 return self.x == other.x and self.y == other.y and self.z == other.z and self.core == other.core 74 75 def __ne__(self, other): 76 if not isinstance(other, _CoreLocation): 77 return True 78 return not self == other 79 80 def __hash__(self): 81 return hash((self.x, self.y, self.z, self.core)) 82 83 def __repr__(self): 84 return f"{type(self).__name__}(x={self.x}, y={self.y}, z={self.z}, core={self.core})" 85 86 def to_list(self): 87 return [self.x, self.y, self.z, self.core] 88 89 90def _create_device_array(shape, device_type, host_id, local_device_ids=None): 91 """Returns ID and device lists that can be used to create a mesh.""" 92 num_global_devices = api.num_global_devices(device_type) 93 global_device_ids = np.arange(num_global_devices).reshape(shape) 94 local_device_list = api.local_devices(device_type) 95 96 # User can specify local_device_ids or use default list for multi host. 97 num_local_devices = len(local_device_list) 98 local_device_ids = [ 99 x + host_id * num_local_devices for x in range(num_local_devices) 100 ] if not local_device_ids else local_device_ids 101 102 return global_device_ids, local_device_ids, local_device_list 103 104 105def _create_tpu_topology(core_locations: List[_CoreLocation], num_tasks: int, 106 num_devices_per_task: int) -> topology.Topology: 107 """Returns a Topology object build from a _CoreLocation list. 108 109 Args: 110 core_locations: A list of _CoreLocation objects sorted first by TF task ID 111 and then by per-task device ordinals. 112 num_tasks: The number of TF tasks in the cluster. 113 num_devices_per_task: The number of TPU devices local to each task. 114 """ 115 116 assert min([l.x for l in core_locations]) == 0 117 assert min([l.y for l in core_locations]) == 0 118 assert min([l.z for l in core_locations]) == 0 119 assert min([l.core for l in core_locations]) == 0 120 x_max = max([l.x for l in core_locations]) 121 y_max = max([l.y for l in core_locations]) 122 z_max = max([l.z for l in core_locations]) 123 core_max = max([l.core for l in core_locations]) 124 mesh_shape = [x_max + 1, y_max + 1, z_max + 1, core_max + 1] 125 126 device_coordinates = [[l.x, l.y, l.z, l.core] for l in core_locations] 127 device_coordinates = np.asarray(device_coordinates).reshape( 128 num_tasks, num_devices_per_task, 4) 129 130 return topology.Topology( 131 mesh_shape=mesh_shape, device_coordinates=device_coordinates) 132 133 134@tf_export("experimental.dtensor.shutdown_tpu_system", v1=[]) 135def dtensor_shutdown_tpu_system(): 136 """Shutdown TPU system.""" 137 138 @def_function.function 139 def _shutdown_tpu_system(): 140 return gen_dtensor_ops.shutdown_tpu_system() 141 142 success = _shutdown_tpu_system() if context.is_tfrt_enabled() else True 143 if success: 144 logging.info("TPU system shut down.") 145 else: 146 logging.warning("TPU system fails to shut down.") 147 148 149@tf_export("experimental.dtensor.initialize_tpu_system", v1=[]) 150def dtensor_initialize_tpu_system(enable_coordination_service=False): 151 """Initialize the TPU devices. 152 153 This functions performs additional TPU related initialization after 154 calling `dtensor.initialize_multi_client` to initialize multi-client DTensor. 155 Refer to `dtensor.initialize_multi_client` for relevant environment 156 variables that controls the initialization of multi-client DTensor. 157 158 Args: 159 enable_coordination_service: If true, enable distributed coordination 160 service to make sure that workers know the devices on each other, a 161 prerequisite for data transfer through cross-worker rendezvous. 162 163 Raises: 164 RuntimeError: If running inside a tf.function. 165 NotFoundError: If no TPU devices found in eager mode. 166 """ 167 168 assert context.executing_eagerly() 169 170 # Reconfigure TensorFlow to use TFRT TPU runtime if requested. 171 _configure_tpu_runtime() 172 173 # Collective GRPC servers are only necessary in mutli-client setup. 174 # Single clients can use local mode of collectives. 175 if api.num_clients() > 1 and not multi_client_util.is_initialized(): 176 multi_client_util.initialize_multi_client_cluster( 177 job_name=api.job_name(), 178 dtensor_jobs=api.jobs(), 179 client_id=api.client_id(), 180 collective_leader=api.full_job_name(task_id=0), 181 enable_coordination_service=enable_coordination_service) 182 183 # Make sure the server change is fully propagated before attempting to run 184 # the core ID merging logic below. 185 context.ensure_initialized() 186 context.async_wait() 187 context.context()._clear_caches() # pylint: disable=protected-access 188 189 @function.defun 190 def _tpu_init_fn(): 191 return gen_dtensor_ops.configure_and_initialize_global_tpu() 192 193 @def_function.function 194 def _set_global_tpu_array_fn(topology_proto): 195 gen_dtensor_ops.d_tensor_set_global_tpu_array(topology_proto) 196 197 try: 198 with ops.device("/job:" + api.full_job_name() + "/device:TPU_SYSTEM:0"): # pylint: disable=protected-access 199 my_core_ids = _tpu_init_fn() 200 logging.info("TPU core IDs: %s", my_core_ids) 201 context.initialize_logical_devices() 202 203 # Configure virtual CPUs that is 1:1 mapped to TPU cores. 204 context.context().set_logical_cpu_devices( 205 len(api.local_devices(_TPU_DEVICE_TYPE)), 206 tf_device.DeviceSpec( 207 job=api.job_name(), replica=0, task=api.client_id()).to_string()) 208 209 # `my_core_ids` contains the IDs of TPU cores attached to this host. 210 # 211 # To generate correct and efficient XLA AllReduce group assignment, we must 212 # merge these arrays from all hosts and broadcast the result back to all 213 # hosts, so all hosts can use these mappings in their MLIR passes. 214 # 215 # This is essentially doing what WaitForDistributedTpuOp and 216 # SetGlobalTPUArrayOp do, in our multi-client environment. 217 task_id = api.client_id() 218 num_tasks = api.num_clients() 219 num_devices = api.num_global_devices(_TPU_DEVICE_TYPE) 220 num_devices_per_task = int(num_devices / num_tasks) 221 222 # Create a one-time use mesh and layout just for merging core IDs. 223 mesh = layout_lib.Mesh([_MESH_DIM_X], 224 *_create_device_array((num_devices,), 225 _TPU_DEVICE_TYPE, 226 api.client_id())) 227 layout = layout_lib.Layout([_MESH_DIM_X, layout_lib.UNSHARDED], mesh) 228 device = dtensor_device.DTensorDevice(meshes=[mesh]) 229 logging.info("TPU core locations: %s", 230 device.tpu_core_ids_to_locations(my_core_ids)) 231 232 # At this point, we don't know which cores are attached to other hosts. 233 # The core ID mappings in the runtime haven't been set yet. 234 # 235 # The core ID merging AllReduce below is carefully written so it works 236 # without needing correct core mappings to be set in the runtime. We will 237 # use this AllReduce's result to set the core ID mappings, and all future 238 # user-initiated AllReduces will use the mappings. 239 # 240 # The runtime is hard-coded to ignore core ID mappings on this AllReduce. 241 all_core_ids = np.zeros([num_devices], dtype=np.int32) 242 for i in range(len(my_core_ids)): 243 all_core_ids[task_id * num_devices_per_task + i] = my_core_ids[i] 244 245 # Only one local device gets valid input: 8 local core IDs among 246 # (num_tasks - 1) * 8 zeros. The 8 core IDs are set using task ID as offset. 247 # The other 7 local devices get zero inputs. All devices on all host 248 # participate in one AllReduce, whose result will be core IDs arranged by 249 # task-device ordinals. 250 all_core_ids = constant_op.constant([all_core_ids]) 251 zeros = array_ops.zeros_like(all_core_ids) 252 all_core_ids = [all_core_ids] + [zeros] * (num_devices_per_task - 1) 253 254 with ops.device(device.name): 255 all_core_ids = device.pack(all_core_ids, layout) 256 all_core_ids = math_ops.reduce_sum(all_core_ids, axis=[0]) 257 unpacked_all_tpu_ids = device.unpack(all_core_ids) 258 259 all_core_ids = list(unpacked_all_tpu_ids[0].numpy()) 260 logging.info("All TPU core IDs: %s", all_core_ids) 261 262 # Set the default core ID mappings in the runtime for legacy code and tests. 263 # 264 # Legacy code and tests create TPU meshes directly without using the 265 # `create_tpu_mesh` function below. Those meshes have global device IDs 266 # equal to TF task-device ordinals. The `all_core_ids` array happens to 267 # arrange core IDs by TF task-device ordinals. Using this array on those 268 # meshes guarantee correct although inefficient results. 269 device.set_tpu_core_ids("", all_core_ids) 270 271 # Remember enough global, immutable information to be able to build any ring 272 # we want prescribed by `create_tpu_mesh` in the future. 273 global _all_core_ids 274 _all_core_ids = all_core_ids 275 276 all_core_locations = device.tpu_core_ids_to_locations(all_core_ids) 277 all_core_locations = [ 278 _CoreLocation(l[0], l[1], l[2], l[3]) for l in all_core_locations 279 ] 280 global _all_core_locations 281 _all_core_locations = all_core_locations 282 logging.info("All TPU core locations: %s", all_core_locations) 283 284 tpu_topology = _create_tpu_topology(all_core_locations, num_tasks, 285 num_devices_per_task) 286 287 _set_global_tpu_array_fn(tpu_topology.serialized()) 288 global _tpu_topology 289 _tpu_topology = tpu_topology 290 logging.vlog(1, "TPU Topology: %s, %s", tpu_topology.mesh_shape, 291 tpu_topology.device_coordinates) 292 293 global _dtensor_device 294 _dtensor_device = device 295 296 context.async_wait() 297 298 except errors.InvalidArgumentError as e: 299 raise errors.NotFoundError( 300 None, None, 301 "Initialization failed, no valid TPUs found. " + str(e)) from e 302 303 except errors.InternalError as e: 304 logging.error("Hit internal error during TPU system initialization. " 305 + "It is likely hareware failure. \nPlease check the error " 306 + "messages above to see whether that's the case. \nIf so, " 307 + "consider to restart the job or try another machine.") 308 raise e 309 310 # Optionally exchange heartbeats between workers every minute. 311 if api.num_clients() > 1 and api.heartbeat_enabled(): 312 logging.info( 313 "Starting DTensor heartbeat service exchanging signals every 10 minutes" 314 ) 315 heartbeat.start(period=180) 316 317 # Clear out the eager context caches since the memory is invalid now. 318 logging.info("Clearing out eager caches") 319 context.context()._clear_caches() # pylint: disable=protected-access 320 321 322def _enumerate_cores(bounds: List[int], ring_bounds: List[int], 323 ring_sizes: List[int], host_bounds: List[int], 324 host_sizes: List[int]) -> List[List[int]]: 325 """Enumerates cores within `bounds` from fatest to slowest varying axes. 326 327 Args: 328 bounds: Upper bounds of axes, from fastest to slowest varying. 329 ring_bounds: Upper bounds of ring size per axis in the same axis order. 330 ring_sizes: Number consecutive cores in the ring built so far, cumulatively. 331 host_bounds: Number of axis values per host in the same axis order. 332 host_sizes: Number consecutive cores on one host, cumulatively. 333 334 Returns: 335 Cores represented as a list of 4 integers in the same axis order. 336 """ 337 if not bounds: 338 return [[]] 339 340 # Recursively enumerate cores under all but the slowest varying axis. 341 partials = _enumerate_cores(bounds[:-1], ring_bounds[:-1], ring_sizes[:-1], 342 host_bounds[:-1], host_sizes[:-1]) 343 344 # Append the slowest varying axis to the end of all partial results. 345 # From ring_i|j to host_i|j to core_i|j, use progressively smaller or equal 346 # iteration groupings until every one of the bounds[-1] * len(partials) 347 # combinations is iterated on. 348 # Despite the six levels of nested loops below, the total time complexity for 349 # this invocation is O(N), where N is the number of cores in the topology. 350 results = [] 351 for ring_i in range(0, bounds[-1], ring_bounds[-1]): 352 for ring_j in range(0, len(partials), ring_sizes[-1]): 353 for host_i in range(ring_i, ring_i + ring_bounds[-1], host_bounds[-1]): 354 for host_j in range(ring_j, ring_j + ring_sizes[-1], host_sizes[-1]): 355 for i in range(host_i, host_i + host_bounds[-1]): 356 for j in range(host_j, host_j + host_sizes[-1]): 357 results.append(partials[j] + [i]) 358 return results 359 360 361def _enumerate_core_locations(bounds: List[int], ring_bounds: List[int], 362 axes: List[str], 363 can_split_host_across_rings: bool, 364 ring_size: int) -> List[_CoreLocation]: 365 """Enumerates all possible core locations under the axis iteration order. 366 367 Args: 368 bounds: A list of 4 positive integers, upper bound values for x, y, z, core. 369 ring_bounds: A list of 4 positive integers, upper bound values for ring size 370 in x, y, z, core axes. 371 axes: A permutation of ["x", "y", "z", "core"], the axis iteration order. 372 can_split_host_across_rings: If true, devices attached to the same host may 373 get assigned to different rings. 374 ring_size: Number of devices in a ring, only for argument validation. 375 376 Returns: 377 A list of all CoreLocation objects defined in a TPU slice of shape `bounds`, 378 sorted by axis iteration order specified by `axes`. 379 380 For example, given bounds=[2, 2, 1, 2] and axes=["core", "z", "y", "x"], 381 return 8 core locations expressed in (x, y, z, core) format but iterated in 382 core -> z -> y -> x order (fatest to slowest varying): 383 384 [_CoreLocation(0, 0, 0, 0), 385 _CoreLocation(0, 0, 0, 1), 386 _CoreLocation(0, 1, 0, 0), 387 _CoreLocation(0, 1, 0, 1), 388 _CoreLocation(1, 0, 0, 0), 389 _CoreLocation(1, 0, 0, 1), 390 _CoreLocation(1, 1, 0, 0), 391 _CoreLocation(1, 1, 0, 1)] 392 393 Raises: 394 ValueError: If ring_size cannot be fulfilled without splitting hosts. 395 """ 396 397 num_cores_per_chip = bounds[3] 398 if num_cores_per_chip != 1 and num_cores_per_chip != 2: 399 raise ValueError("Unsupported TPU slice size: %s" % bounds) 400 401 # Translate `axes` from string to integer format. 402 axes = [{"x": 0, "y": 1, "z": 2, "core": 3}[axis] for axis in axes] 403 # Reorder bounds from fastest to slowest varying axes. 404 bounds = [bounds[i] for i in axes] 405 406 # Set and validate host_bounds. 407 if can_split_host_across_rings: 408 # If we can split hosts, shrink every host to effectively contain 1 device. 409 host_bounds = [1, 1, 1, 1] 410 elif np.prod(bounds) <= 2: 411 # We must be running on 1x1 or 1x1x1 Forge. 412 host_bounds = [[1, 1, 1, num_cores_per_chip][i] for i in axes] 413 else: 414 # Other cases including 2x2 Forge and Borg must use a full donut. 415 host_bounds = [[2, 2, 1, num_cores_per_chip][i] for i in axes] 416 # host_sizes is the cumulative products of host_bounts. 417 host_sizes = [1] 418 for host_bound in host_bounds: 419 host_sizes.append(host_sizes[-1] * host_bound) 420 host_size = host_sizes.pop() 421 # When can_split_host_across_rings is false, a ring must contain at least as 422 # many devices as a host has. 423 if ring_size < host_size: 424 assert not can_split_host_across_rings 425 raise ValueError( 426 "Rings too small for can_split_host_across_rings = False: %d" % 427 ring_size) 428 429 # Reorder ring_bounds and validate it's element-wise >= host_bounds. 430 ring_bounds = [ring_bounds[i] for i in axes] 431 if ring_bounds < host_bounds: 432 raise ValueError("ring_bounds %s should be >= host_bounds %s" % 433 (ring_bounds, host_bounds)) 434 ring_sizes = [1] 435 # ring_sizes is the cumulative products of ring_bounds. 436 for ring_bound in ring_bounds: 437 ring_sizes.append(ring_sizes[-1] * ring_bound) 438 ring_sizes.pop() 439 440 # Enumerate cores in the given iteration order. Each core is represented as a 441 # list of int, which are offsets from fatest to slowest varying axes. 442 cores = _enumerate_cores(bounds, ring_bounds, ring_sizes, host_bounds, 443 host_sizes) 444 # Reorder offsets of each core back to the x, y, z, core order. 445 core_locations = [] 446 for core in cores: 447 core = [core[axes.index(i)] for i in range(4)] 448 core_locations.append(_CoreLocation(core[0], core[1], core[2], core[3])) 449 return core_locations 450 451 452def _build_all_reduce_ring(core_locations: List[_CoreLocation], 453 rotate: bool = False) -> List[int]: 454 """Reorders a list of TPU cores to optimize for AllReduce performance. 455 456 This is ported from the C++ tensorflow::BuildAllReduceRing function, 457 mixed with some logic from TF TPU's device_assignment._ring_3d. 458 459 Args: 460 core_locations: A list of core locations expressed as [x, y, z, core]. 461 rotate: If true, scan the cores in a column-major order. False by default. 462 463 Returns: 464 A permutation of the input list such that neighbors in the sequence are 465 nearby in the TPU topology. 466 """ 467 468 permutation = list(range(len(core_locations))) 469 if not permutation: 470 return permutation 471 logging.vlog(2, "Core locations in: %s", core_locations) 472 473 first_column = min([l.x for l in core_locations]) 474 first_row = min([l.y for l in core_locations]) 475 same_z = (len(set([l.z for l in core_locations])) == 1) 476 logging.vlog(2, "first_column: %d", first_column) 477 logging.vlog(2, "first_row: %d", first_row) 478 logging.vlog(2, "same_z: %s", same_z) 479 480 def _cmp_2d(ia: int, ib: int) -> int: 481 if not rotate: 482 a = core_locations[ia] 483 b = core_locations[ib] 484 485 # Order the first column last in the sequence, except for the first row. 486 a_first = (a.x == first_column and a.y != first_row) 487 b_first = (b.x == first_column and b.y != first_row) 488 if a_first != b_first: 489 return -1 if b_first else 1 490 491 # Order rows in increasing order, unless in the first column. 492 if a.y != b.y: 493 return b.y - a.y if a_first else a.y - b.y 494 495 # Order even rows left to right, odd rows right to left. 496 if a.x != b.x: 497 return a.x - b.x if a.y % 2 == 0 else b.x - a.x 498 499 # Order cores in increasing order. 500 return a.core - b.core 501 else: 502 a = core_locations[ia] 503 b = core_locations[ib] 504 505 # Order the first row last in the sequence, except for the first column. 506 a_first = (a.y == first_row and a.x != first_column) 507 b_first = (b.y == first_row and b.x != first_column) 508 if a_first != b_first: 509 return -1 if b_first else 1 510 511 # Order columns in increasing order, unless in the first row. 512 if a.x != b.x: 513 return b.x - a.x if a_first else a.x - b.x 514 515 # Order even columns top down, odd columns bottom up. 516 if a.y != b.y: 517 return a.y - b.y if a.x % 2 == 0 else b.y - a.y 518 519 # Order cores in increasing order. 520 return a.core - b.core 521 522 def _cmp_3d(ia: int, ib: int) -> int: 523 a = core_locations[ia] 524 b = core_locations[ib] 525 526 a_corner = (a.x == first_column and a.y == first_row) 527 b_corner = (b.x == first_column and b.y == first_row) 528 529 # If both are in the corner, order in reverse z then core order. 530 if a_corner and b_corner: 531 return b.z - a.z if a.z != b.z else a.core - b.core 532 533 # Corner cores always go after non-corner cores. 534 if a_corner != b_corner: 535 return -1 if b_corner else 1 536 537 # Both non-corner cores are on the same z-plane. Reverse odd z-planes. 538 if a.z == b.z: 539 return _cmp_2d(ia, ib) if a.z % 2 == 0 else -_cmp_2d(ia, ib) 540 541 # Both non-corner cores are on different z-planes. Smaller z goes first. 542 return a.z - b.z 543 544 # If all cores are on the same z-plane, order as usual. Otherwise, order 545 # neighbor z-planes in opposite orders. Stack all z-planes along the z axis 546 # and connect them in one corner. 547 if same_z: 548 permutation.sort(key=functools.cmp_to_key(_cmp_2d)) 549 else: 550 permutation.sort(key=functools.cmp_to_key(_cmp_3d)) 551 logging.vlog(2, "Permutation out: %s", permutation) 552 return permutation 553 554 555def _build_orthogonal_rings( 556 core_locations: List[_CoreLocation], ring_size: int, 557 rotate_ring_across_rings: bool) -> List[_CoreLocation]: 558 """Build two all-reduce rings orthogonal to each other. 559 560 One ring includes every `ring_size` consecutive core locations. It is usually 561 applied to the model-parallel dimension of a mesh to achieve best 1D 562 all-reduce performance. The other ring includes core locations separated by 563 a stride of `ring_size`. It is usually applied to the data-parallel dimension 564 of a mesh to get predictable strided all-reduce performance. 565 566 Args: 567 core_locations: A list of core locations expressed as [x, y, z, core]. 568 ring_size: The number of core locations in the consecutive ring. 569 rotate_ring_across_rings: Build column-major secondary rings. 570 571 Returns: 572 A permutation of the input list forming the described rings. 573 """ 574 # Build a ring for the first `ring_size` cores, and apply that permutation to 575 # every group of `ring_size` cores. 576 num_cores = len(core_locations) 577 permutation = _build_all_reduce_ring(core_locations[:ring_size]) 578 for r in range(0, num_cores, ring_size): 579 core_locations[r:r + ring_size] = [ 580 core_locations[r + permutation[i]] for i in range(ring_size) 581 ] 582 logging.vlog(1, "Permutated core locations: %s", core_locations) 583 584 # Build a "ring" for the collection of devices consisting of the 0th device 585 # from every group, and apply that permutation to every i-th device group. 586 # This is achieved by transposing the list and back. 587 transposed = [] 588 for i in range(ring_size): 589 transposed += [ 590 core_locations[g + i] for g in range(0, num_cores, ring_size) 591 ] 592 593 num_rings = int(num_cores / ring_size) 594 permutation = _build_all_reduce_ring( 595 transposed[:num_rings], rotate=rotate_ring_across_rings) 596 for r in range(0, num_cores, num_rings): 597 transposed[r:r + num_rings] = [ 598 transposed[r + permutation[i]] for i in range(num_rings) 599 ] 600 601 untransposed = [] 602 for i in range(num_rings): 603 untransposed += [transposed[g + i] for g in range(0, num_cores, num_rings)] 604 logging.vlog(1, "Stride-permutated core locations: %s", untransposed) 605 606 return untransposed 607 608 609def create_tpu_mesh(mesh_dim_names: List[str], 610 mesh_shape: List[int], 611 mesh_name: str, 612 ring_dims: Optional[int] = None, 613 ring_axes: Optional[List[str]] = None, 614 ring_bounds: Optional[List[int]] = None, 615 can_split_host_across_rings: bool = True, 616 build_ring_across_rings: bool = False, 617 rotate_ring_across_rings: bool = False) -> layout_lib.Mesh: 618 """Returns a TPU mesh optimized for AllReduce ring reductions. 619 620 Only as many as leading axes specified by `ring_axes` as necessary will be 621 used to build rings, as long as the subslice formed by these axes have enough 622 cores to contain a ring of the required size. The leftover axes in `ring_axes` 623 won't affect results. 624 625 Args: 626 mesh_dim_names: List of mesh dimension names. 627 mesh_shape: Shape of the mesh. 628 mesh_name: A unique name for the mesh. If empty, internally generate one. 629 ring_dims: Optional; The number of leading (ring_dims > 0) or trailing 630 (ring_dims < 0) mesh dimensions to build rings for. If unspecified, build 631 rings for all but the first dimension. 632 ring_axes: Optional; A permutation of ["x", "y", "z", "core"], specifying 633 the order of TPU topology axes to build rings in. If unspecified, default 634 to ["core", "x", "y", "z"]. 635 ring_bounds: Optional; The maximum number of devices on each axis, in the x, 636 y, z, core order. If unspecified, default to physical topology limits. 637 can_split_host_across_rings: Optional; If true, devices attached to the same 638 host (i.e., DTensor client) may get assigned to different rings. Setting 639 it to false may cause some combinations of arguments to be infeasible; see 640 DeviceAssignmentTest.testCreateMesh[No]SplittingHosts* for examples. 641 build_ring_across_rings: Optional; If true, also build a data-parallel ring 642 across model-parallel rings. This ring could be strided. 643 rotate_ring_across_rings: Optional; If true, build the data-parallel ring in 644 column-major instead of row-major order. 645 """ 646 647 logging.info("Building a TPU mesh %s of shape %s", mesh_name, mesh_shape) 648 logging.info("Requested ring_dims: %s", ring_dims) 649 logging.info("Requested ring_axes: %s", ring_axes) 650 logging.info("Requested ring_bounds: %s", ring_bounds) 651 logging.info("Requested can_split_host_across_rings: %s", 652 can_split_host_across_rings) 653 if not mesh_name: 654 mesh_name = "mesh_%f" % time.time() 655 logging.info("Requested mesh_name: %s", mesh_name) 656 657 # By default, build rings for all but the first (usually batch) dimension. 658 if ring_dims is None: 659 ring_dims = 1 - len(mesh_shape) 660 elif ring_dims < -len(mesh_shape) or ring_dims > len(mesh_shape): 661 raise ValueError("Invalid ring_dims value: %d" % ring_dims) 662 logging.info("Actual ring_dims: %s", ring_dims) 663 664 # By default, vary axes in the core -> x -> y -> z order. 665 if ring_axes is None: 666 ring_axes = ["core", "x", "y", "z"] 667 elif len(ring_axes) != 4: 668 raise ValueError("Expected 4 elements in ring_axes, got %s" % ring_axes) 669 elif sorted(ring_axes) != ["core", "x", "y", "z"]: 670 raise ValueError("Invalid ring_axes value: %s" % ring_axes) 671 logging.info("Actual ring_axes: %s", ring_axes) 672 673 # Validate ring_bounds values. 674 if _tpu_topology is None: 675 raise ValueError( 676 "Invalid TPU topology, run dtensor.initialize_tpu_system() first") 677 topology_shape = list(_tpu_topology.mesh_shape) 678 if ring_bounds is None: 679 ring_bounds = topology_shape 680 elif len(ring_bounds) != 4: 681 raise ValueError("Expected 4 elements in ring_bounds, got %s" % ring_bounds) 682 elif ring_bounds > topology_shape: 683 raise ValueError("ring_bounds %s should be <= topology sizes %s" % 684 (ring_bounds, topology_shape)) 685 logging.info("Actual ring_bounds: %s", ring_bounds) 686 687 # Compute ring_size, the number of cores in a ring. 688 if ring_dims > 0: 689 ring_size = np.prod(mesh_shape[:ring_dims]) 690 elif ring_dims < 0: 691 ring_size = np.prod(mesh_shape[ring_dims:]) 692 else: 693 ring_size = 1 # single-core rings 694 logging.info("Actual ring_size: %d", ring_size) 695 696 # Rearrange all cores according to the axis iteration order. 697 global_core_locations = _enumerate_core_locations( 698 topology_shape, ring_bounds, ring_axes, can_split_host_across_rings, 699 ring_size) 700 logging.vlog(1, "Enumerated core locations: %s", global_core_locations) 701 num_cores = len(global_core_locations) 702 703 # The mesh to be created must use all TPU cores in the system. 704 mesh_size = np.prod(mesh_shape) 705 if mesh_size != num_cores: 706 raise ValueError( 707 "Invalid mesh size: mesh shape %s cannot 1:1 map to %d TPU cores" % 708 (mesh_shape, num_cores)) 709 710 # Build a ring for the `ring_size` dimension and, if required, a strided ring 711 # for the orthogonal dimension. 712 if build_ring_across_rings: 713 global_core_locations = _build_orthogonal_rings(global_core_locations, 714 ring_size, 715 rotate_ring_across_rings) 716 else: 717 permutation = _build_all_reduce_ring(global_core_locations[:ring_size]) 718 for r in range(0, num_cores, ring_size): 719 global_core_locations[r:r + ring_size] = [ 720 global_core_locations[r + permutation[i]] for i in range(ring_size) 721 ] 722 logging.vlog(1, "Permutated core locations: %s", global_core_locations) 723 724 # For this point on, change from List[CoreLocation] to List[List[int]] for 725 # easier interaction with the C++ API. 726 global_core_locations = [l.to_list() for l in global_core_locations] 727 if _dtensor_device is None: 728 raise ValueError( 729 "Invalid system device, run dtensor.initialize_tpu_system() first") 730 global_core_ids = _dtensor_device.tpu_core_locations_to_ids( 731 global_core_locations) 732 733 # Store a per-mesh mapping in the runtime. 734 _dtensor_device.set_tpu_core_ids(mesh_name, global_core_ids) 735 736 # Create the mesh by manually specifying local_device_ids. 737 local_core_locations = _tpu_topology.device_coordinates[api.client_id()] 738 indexes = [ 739 global_core_locations.index(list(local_core_location)) 740 for local_core_location in local_core_locations 741 ] 742 global_device_ids, local_device_ids, local_device_list = _create_device_array( 743 mesh_shape, _TPU_DEVICE_TYPE, None, local_device_ids=indexes) 744 return layout_lib.Mesh(mesh_dim_names, global_device_ids, local_device_ids, 745 local_device_list, mesh_name) 746 747 748def get_device_ids(mesh: layout_lib.Mesh, 749 client_id: Optional[int] = None) -> List[int]: 750 """Returns the device IDs of all TPU cores local to the given client. 751 752 A device ID is a non-negative integer that uniquely identifies a device in the 753 mesh. For example, for a 2x2 mesh ('x', 'y'), this function returns a 754 permutation of [0, 1, 2, 3]. 755 756 Note that device IDs and device locations are equivalent. The former is a 757 linearization of the latter along mesh dimensions. 758 759 Args: 760 mesh: A TPU mesh. 761 client_id: Optional; A DTensor client ID. If empty, query this client. 762 """ 763 764 if mesh.device_type() != _TPU_DEVICE_TYPE: 765 raise ValueError("The mesh must be a TPU mesh") 766 767 if client_id is None or client_id == api.client_id(): 768 return mesh.local_device_ids() 769 770 # It's not clear we should ever allow a client to query other clients for 771 # their device IDs. 772 raise NotImplementedError( 773 "Looking up other clients' device IDs is not supported") 774 775 776def get_device_locations( 777 mesh: layout_lib.Mesh, 778 client_id: Optional[int] = None) -> List[Dict[str, int]]: 779 """Returns the device locations of all TPU cores local to the given client. 780 781 A device location is a dictionary from dimension names to indices on those 782 dimensions. For example, for a 2x2 mesh ('x', 'y'), this function returns a 783 permutation of this list: 784 785 [{'x': 0, 'y': 0}, 786 {'x': 0, 'y': 1}, 787 {'x': 1, 'y': 0}, 788 {'x': 1, 'y': 1}]. 789 790 Note that device IDs and device locations are equivalent. The former is a 791 linearization of the latter along mesh dimensions. 792 793 Args: 794 mesh: A TPU mesh. 795 client_id: Optional; A DTensor client ID. If empty, query this client. 796 """ 797 798 if mesh.device_type() != _TPU_DEVICE_TYPE: 799 raise ValueError("The mesh must be a TPU mesh") 800 801 if client_id is None or client_id == api.client_id(): 802 return mesh.local_device_locations() 803 804 # It's not clear we should ever allow a client to query other clients for 805 # their device locations. 806 raise NotImplementedError( 807 "Looking up other clients' device locations is not supported") 808 809 810def _configure_tpu_runtime(): 811 was_enabled = context.is_tfrt_enabled() 812 if ("tpu_use_tfrt" in flags.FLAGS and flags.FLAGS["tpu_use_tfrt"].value): 813 tfrt_utils.set_tfrt_enabled(True) 814 if not was_enabled: 815 context._reset_context() # pylint:disable=protected-access 816