# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""TPU-specific utilities for DTensor."""
16
17import functools
18import time
19from typing import List, Optional, Dict
20
21from absl import flags
22import numpy as np
23
24from tensorflow.dtensor.python import api
25from tensorflow.dtensor.python import dtensor_device
26from tensorflow.dtensor.python import gen_dtensor_ops
27from tensorflow.dtensor.python import heartbeat
28from tensorflow.dtensor.python import layout as layout_lib
29from tensorflow.dtensor.python import multi_client_util
30from tensorflow.python.eager import context
31from tensorflow.python.eager import def_function
32from tensorflow.python.eager import function
33from tensorflow.python.framework import constant_op
34from tensorflow.python.framework import device as tf_device
35from tensorflow.python.framework import errors
36from tensorflow.python.framework import ops
37from tensorflow.python.framework import tfrt_utils
38from tensorflow.python.ops import array_ops
39from tensorflow.python.ops import math_ops
40from tensorflow.python.platform import tf_logging as logging
41from tensorflow.python.tpu import topology
42from tensorflow.python.util.tf_export import tf_export
43
_INITIALIZED_TPU_SYSTEMS = {}
_MESH_DIM_X = "x"
_TPU_DEVICE_TYPE = "TPU"

# A dedicated, hidden device used to make C++ API calls.
_dtensor_device = None

# `_tpu_topology.mesh_shape` contains the TPU hardware slice size.
# `_tpu_topology.device_coordinates` maps TF task-device ordinals to TPU core
# IDs.
_tpu_topology = None

# Cache core ID <-> location mappings so we need not make repeated C++ calls.
# Both are indexed by TF task-device ordinals.
_all_core_ids = None
_all_core_locations = None


class _CoreLocation:
  """Represents a TPU core's location in the mesh."""

  def __init__(self, x: int = 0, y: int = 0, z: int = 0, core: int = 0):
    self.x = x
    self.y = y
    self.z = z
    self.core = core

  def __eq__(self, other):
    if not isinstance(other, _CoreLocation):
      return False
    return self.x == other.x and self.y == other.y and self.z == other.z and self.core == other.core

  def __ne__(self, other):
    if not isinstance(other, _CoreLocation):
      return True
    return not self == other

  def __hash__(self):
    return hash((self.x, self.y, self.z, self.core))

  def __repr__(self):
    return f"{type(self).__name__}(x={self.x}, y={self.y}, z={self.z}, core={self.core})"

  def to_list(self):
    return [self.x, self.y, self.z, self.core]

def _create_device_array(shape, device_type, host_id, local_device_ids=None):
  """Returns ID and device lists that can be used to create a mesh."""
  num_global_devices = api.num_global_devices(device_type)
  global_device_ids = np.arange(num_global_devices).reshape(shape)
  local_device_list = api.local_devices(device_type)

  # Users can specify `local_device_ids`; otherwise, build the default list for
  # the multi-host case.
  num_local_devices = len(local_device_list)
  local_device_ids = [
      x + host_id * num_local_devices for x in range(num_local_devices)
  ] if not local_device_ids else local_device_ids

  return global_device_ids, local_device_ids, local_device_list

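# Illustrative example (not part of the original module), assuming a single
# client (host_id=0) whose 8 local TPU devices are also all of the global TPU
# devices:
#
#   _create_device_array((8,), _TPU_DEVICE_TYPE, host_id=0)
#
# would return global device IDs [0, ..., 7] arranged in shape (8,), local
# device IDs [0, ..., 7], and the local TPU device list reported by
# `api.local_devices`.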
def _create_tpu_topology(core_locations: List[_CoreLocation], num_tasks: int,
                         num_devices_per_task: int) -> topology.Topology:
  """Returns a Topology object built from a _CoreLocation list.

  Args:
    core_locations: A list of _CoreLocation objects sorted first by TF task ID
      and then by per-task device ordinals.
    num_tasks: The number of TF tasks in the cluster.
    num_devices_per_task: The number of TPU devices local to each task.
  """

  assert min([l.x for l in core_locations]) == 0
  assert min([l.y for l in core_locations]) == 0
  assert min([l.z for l in core_locations]) == 0
  assert min([l.core for l in core_locations]) == 0
  x_max = max([l.x for l in core_locations])
  y_max = max([l.y for l in core_locations])
  z_max = max([l.z for l in core_locations])
  core_max = max([l.core for l in core_locations])
  mesh_shape = [x_max + 1, y_max + 1, z_max + 1, core_max + 1]

  device_coordinates = [[l.x, l.y, l.z, l.core] for l in core_locations]
  device_coordinates = np.asarray(device_coordinates).reshape(
      num_tasks, num_devices_per_task, 4)

  return topology.Topology(
      mesh_shape=mesh_shape, device_coordinates=device_coordinates)

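# Illustrative example (not part of the original module): for a hypothetical
# single-task 2x2 TPU slice with 2 cores per chip, the 8 core locations span
# x in {0, 1}, y in {0, 1}, z = 0, and core in {0, 1}, so
#
#   _create_tpu_topology(core_locations, num_tasks=1, num_devices_per_task=8)
#
# would build a Topology with mesh_shape [2, 2, 1, 2] and device_coordinates
# of shape (1, 8, 4).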
@tf_export("experimental.dtensor.shutdown_tpu_system", v1=[])
def dtensor_shutdown_tpu_system():
  """Shuts down the TPU system."""

  @def_function.function
  def _shutdown_tpu_system():
    return gen_dtensor_ops.shutdown_tpu_system()

  success = _shutdown_tpu_system() if context.is_tfrt_enabled() else True
  if success:
    logging.info("TPU system shut down.")
  else:
    logging.warning("TPU system failed to shut down.")

@tf_export("experimental.dtensor.initialize_tpu_system", v1=[])
def dtensor_initialize_tpu_system(enable_coordination_service=False):
  """Initializes the TPU devices.

  This function performs additional TPU-related initialization after calling
  `dtensor.initialize_multi_client` to initialize multi-client DTensor. Refer
  to `dtensor.initialize_multi_client` for the relevant environment variables
  that control the initialization of multi-client DTensor.

  Args:
    enable_coordination_service: If true, enable the distributed coordination
      service to make sure that workers know each other's devices, a
      prerequisite for data transfer through cross-worker rendezvous.

  Raises:
    RuntimeError: If running inside a tf.function.
    NotFoundError: If no TPU devices are found in eager mode.
  """

  assert context.executing_eagerly()

  # Reconfigure TensorFlow to use the TFRT TPU runtime if requested.
  _configure_tpu_runtime()

  # Collective GRPC servers are only necessary in a multi-client setup.
  # Single clients can use the local mode of collectives.
  if api.num_clients() > 1 and not multi_client_util.is_initialized():
    multi_client_util.initialize_multi_client_cluster(
        job_name=api.job_name(),
        dtensor_jobs=api.jobs(),
        client_id=api.client_id(),
        collective_leader=api.full_job_name(task_id=0),
        enable_coordination_service=enable_coordination_service)

  # Make sure the server change is fully propagated before attempting to run
  # the core ID merging logic below.
  context.ensure_initialized()
  context.async_wait()
  context.context()._clear_caches()  # pylint: disable=protected-access

  @function.defun
  def _tpu_init_fn():
    return gen_dtensor_ops.configure_and_initialize_global_tpu()

  @def_function.function
  def _set_global_tpu_array_fn(topology_proto):
    gen_dtensor_ops.d_tensor_set_global_tpu_array(topology_proto)

  try:
    with ops.device("/job:" + api.full_job_name() + "/device:TPU_SYSTEM:0"):  # pylint: disable=protected-access
      my_core_ids = _tpu_init_fn()
    logging.info("TPU core IDs: %s", my_core_ids)
    context.initialize_logical_devices()

    # Configure virtual CPUs that are 1:1 mapped to TPU cores.
    context.context().set_logical_cpu_devices(
        len(api.local_devices(_TPU_DEVICE_TYPE)),
        tf_device.DeviceSpec(
            job=api.job_name(), replica=0, task=api.client_id()).to_string())

    # `my_core_ids` contains the IDs of TPU cores attached to this host.
    #
    # To generate correct and efficient XLA AllReduce group assignment, we must
    # merge these arrays from all hosts and broadcast the result back to all
    # hosts, so all hosts can use these mappings in their MLIR passes.
    #
    # This is essentially doing what WaitForDistributedTpuOp and
    # SetGlobalTPUArrayOp do, in our multi-client environment.
    task_id = api.client_id()
    num_tasks = api.num_clients()
    num_devices = api.num_global_devices(_TPU_DEVICE_TYPE)
    num_devices_per_task = int(num_devices / num_tasks)

    # Create a one-time-use mesh and layout just for merging core IDs.
    mesh = layout_lib.Mesh([_MESH_DIM_X],
                           *_create_device_array((num_devices,),
                                                 _TPU_DEVICE_TYPE,
                                                 api.client_id()))
    layout = layout_lib.Layout([_MESH_DIM_X, layout_lib.UNSHARDED], mesh)
    device = dtensor_device.DTensorDevice(meshes=[mesh])
    logging.info("TPU core locations: %s",
                 device.tpu_core_ids_to_locations(my_core_ids))

    # At this point, we don't know which cores are attached to other hosts.
    # The core ID mappings in the runtime haven't been set yet.
    #
    # The core ID merging AllReduce below is carefully written so it works
    # without needing correct core mappings to be set in the runtime. We will
    # use this AllReduce's result to set the core ID mappings, and all future
    # user-initiated AllReduces will use the mappings.
    #
    # The runtime is hard-coded to ignore core ID mappings on this AllReduce.
    all_core_ids = np.zeros([num_devices], dtype=np.int32)
    for i in range(len(my_core_ids)):
      all_core_ids[task_id * num_devices_per_task + i] = my_core_ids[i]

    # Only one local device gets valid input: 8 local core IDs among
    # (num_tasks - 1) * 8 zeros. The 8 core IDs are set using the task ID as an
    # offset. The other 7 local devices get zero inputs. All devices on all
    # hosts participate in one AllReduce, whose result will be core IDs
    # arranged by task-device ordinals.
    all_core_ids = constant_op.constant([all_core_ids])
    zeros = array_ops.zeros_like(all_core_ids)
    all_core_ids = [all_core_ids] + [zeros] * (num_devices_per_task - 1)

    with ops.device(device.name):
      all_core_ids = device.pack(all_core_ids, layout)
      all_core_ids = math_ops.reduce_sum(all_core_ids, axis=[0])
      unpacked_all_tpu_ids = device.unpack(all_core_ids)

    all_core_ids = list(unpacked_all_tpu_ids[0].numpy())
    logging.info("All TPU core IDs: %s", all_core_ids)

    # Set the default core ID mappings in the runtime for legacy code and tests.
    #
    # Legacy code and tests create TPU meshes directly without using the
    # `create_tpu_mesh` function below. Those meshes have global device IDs
    # equal to TF task-device ordinals. The `all_core_ids` array happens to
    # arrange core IDs by TF task-device ordinals. Using this array on those
    # meshes guarantees correct, although inefficient, results.
    device.set_tpu_core_ids("", all_core_ids)

    # Remember enough global, immutable information to be able to build any
    # ring prescribed by `create_tpu_mesh` in the future.
    global _all_core_ids
    _all_core_ids = all_core_ids

    all_core_locations = device.tpu_core_ids_to_locations(all_core_ids)
    all_core_locations = [
        _CoreLocation(l[0], l[1], l[2], l[3]) for l in all_core_locations
    ]
    global _all_core_locations
    _all_core_locations = all_core_locations
    logging.info("All TPU core locations: %s", all_core_locations)

    tpu_topology = _create_tpu_topology(all_core_locations, num_tasks,
                                        num_devices_per_task)

    _set_global_tpu_array_fn(tpu_topology.serialized())
    global _tpu_topology
    _tpu_topology = tpu_topology
    logging.vlog(1, "TPU Topology: %s, %s", tpu_topology.mesh_shape,
                 tpu_topology.device_coordinates)

    global _dtensor_device
    _dtensor_device = device

    context.async_wait()

  except errors.InvalidArgumentError as e:
    raise errors.NotFoundError(
        None, None,
        "Initialization failed, no valid TPUs found. " + str(e)) from e

  except errors.InternalError as e:
    logging.error("Hit internal error during TPU system initialization. "
                  + "It is likely a hardware failure. \nPlease check the error "
                  + "messages above to see whether that's the case. \nIf so, "
                  + "consider restarting the job or trying another machine.")
    raise e

  # Optionally exchange heartbeats between workers every 3 minutes.
  if api.num_clients() > 1 and api.heartbeat_enabled():
    logging.info(
        "Starting DTensor heartbeat service exchanging signals every 3 minutes"
    )
    heartbeat.start(period=180)

  # Clear out the eager context caches since the memory is invalid now.
  logging.info("Clearing out eager caches")
  context.context()._clear_caches()  # pylint: disable=protected-access

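# A minimal usage sketch of the two exported entry points above (illustrative
# only, assuming an eager program running on a TPU-attached host):
#
#   dtensor_initialize_tpu_system()
#   # ... create TPU meshes (see `create_tpu_mesh` below) and run DTensor code.
#   dtensor_shutdown_tpu_system()
#
# These functions are exported as
# `tf.experimental.dtensor.initialize_tpu_system` and
# `tf.experimental.dtensor.shutdown_tpu_system`.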

def _enumerate_cores(bounds: List[int], ring_bounds: List[int],
                     ring_sizes: List[int], host_bounds: List[int],
                     host_sizes: List[int]) -> List[List[int]]:
  """Enumerates cores within `bounds` from fastest to slowest varying axes.

  Args:
    bounds: Upper bounds of axes, from fastest to slowest varying.
    ring_bounds: Upper bounds of ring size per axis in the same axis order.
    ring_sizes: Number of consecutive cores in the ring built so far,
      cumulatively.
    host_bounds: Number of axis values per host in the same axis order.
    host_sizes: Number of consecutive cores on one host, cumulatively.

  Returns:
    Cores represented as lists of 4 integers in the same axis order.
  """
  if not bounds:
    return [[]]

  # Recursively enumerate cores under all but the slowest varying axis.
  partials = _enumerate_cores(bounds[:-1], ring_bounds[:-1], ring_sizes[:-1],
                              host_bounds[:-1], host_sizes[:-1])

  # Append the slowest varying axis to the end of all partial results.
  # From ring_i|j to host_i|j to core_i|j, use progressively smaller or equal
  # iteration groupings until every one of the bounds[-1] * len(partials)
  # combinations is iterated on.
  # Despite the six levels of nested loops below, the total time complexity for
  # this invocation is O(N), where N is the number of cores in the topology.
  results = []
  for ring_i in range(0, bounds[-1], ring_bounds[-1]):
    for ring_j in range(0, len(partials), ring_sizes[-1]):
      for host_i in range(ring_i, ring_i + ring_bounds[-1], host_bounds[-1]):
        for host_j in range(ring_j, ring_j + ring_sizes[-1], host_sizes[-1]):
          for i in range(host_i, host_i + host_bounds[-1]):
            for j in range(host_j, host_j + host_sizes[-1]):
              results.append(partials[j] + [i])
  return results

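# Illustrative note (not part of the original module): `ring_sizes` and
# `host_sizes` are the cumulative products of `ring_bounds` and `host_bounds`
# with a leading 1, as computed by `_enumerate_core_locations` below. For
# example, assuming a 2x2 donut with 2 cores per chip, the default axis order
# ["core", "x", "y", "z"], default ring bounds equal to the topology shape, and
# can_split_host_across_rings=False, the call would be
#
#   _enumerate_cores(bounds=[2, 2, 2, 1], ring_bounds=[2, 2, 2, 1],
#                    ring_sizes=[1, 2, 4, 8], host_bounds=[2, 2, 2, 1],
#                    host_sizes=[1, 2, 4, 8])
#
# and would return the 8 cores as [core, x, y, z] offset lists, starting with
# [0, 0, 0, 0], [1, 0, 0, 0], ...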
def _enumerate_core_locations(bounds: List[int], ring_bounds: List[int],
                              axes: List[str],
                              can_split_host_across_rings: bool,
                              ring_size: int) -> List[_CoreLocation]:
  """Enumerates all possible core locations under the axis iteration order.

  Args:
    bounds: A list of 4 positive integers, upper bound values for x, y, z, core.
    ring_bounds: A list of 4 positive integers, upper bound values for ring size
      in x, y, z, core axes.
    axes: A permutation of ["x", "y", "z", "core"], the axis iteration order.
    can_split_host_across_rings: If true, devices attached to the same host may
      get assigned to different rings.
    ring_size: Number of devices in a ring, only for argument validation.

  Returns:
    A list of all _CoreLocation objects defined in a TPU slice of shape
    `bounds`, sorted by the axis iteration order specified by `axes`.

    For example, given bounds=[2, 2, 1, 2] and axes=["core", "z", "y", "x"],
    return 8 core locations expressed in (x, y, z, core) format but iterated in
    core -> z -> y -> x order (fastest to slowest varying):

    [_CoreLocation(0, 0, 0, 0),
     _CoreLocation(0, 0, 0, 1),
     _CoreLocation(0, 1, 0, 0),
     _CoreLocation(0, 1, 0, 1),
     _CoreLocation(1, 0, 0, 0),
     _CoreLocation(1, 0, 0, 1),
     _CoreLocation(1, 1, 0, 0),
     _CoreLocation(1, 1, 0, 1)]

  Raises:
    ValueError: If ring_size cannot be fulfilled without splitting hosts.
  """

  num_cores_per_chip = bounds[3]
  if num_cores_per_chip != 1 and num_cores_per_chip != 2:
    raise ValueError("Unsupported TPU slice size: %s" % bounds)

  # Translate `axes` from string to integer format.
  axes = [{"x": 0, "y": 1, "z": 2, "core": 3}[axis] for axis in axes]
  # Reorder bounds from fastest to slowest varying axes.
  bounds = [bounds[i] for i in axes]

  # Set and validate host_bounds.
  if can_split_host_across_rings:
    # If we can split hosts, shrink every host to effectively contain 1 device.
    host_bounds = [1, 1, 1, 1]
  elif np.prod(bounds) <= 2:
    # We must be running on 1x1 or 1x1x1 Forge.
    host_bounds = [[1, 1, 1, num_cores_per_chip][i] for i in axes]
  else:
    # Other cases including 2x2 Forge and Borg must use a full donut.
    host_bounds = [[2, 2, 1, num_cores_per_chip][i] for i in axes]
  # host_sizes holds the cumulative products of host_bounds.
  host_sizes = [1]
  for host_bound in host_bounds:
    host_sizes.append(host_sizes[-1] * host_bound)
  host_size = host_sizes.pop()
  # When can_split_host_across_rings is false, a ring must contain at least as
  # many devices as a host has.
  if ring_size < host_size:
    assert not can_split_host_across_rings
    raise ValueError(
        "Rings too small for can_split_host_across_rings = False: %d" %
        ring_size)

  # Reorder ring_bounds and validate it's element-wise >= host_bounds.
  ring_bounds = [ring_bounds[i] for i in axes]
  if ring_bounds < host_bounds:
    raise ValueError("ring_bounds %s should be >= host_bounds %s" %
                     (ring_bounds, host_bounds))
  ring_sizes = [1]
  # ring_sizes holds the cumulative products of ring_bounds.
  for ring_bound in ring_bounds:
    ring_sizes.append(ring_sizes[-1] * ring_bound)
  ring_sizes.pop()

  # Enumerate cores in the given iteration order. Each core is represented as a
  # list of int, which are offsets from fastest to slowest varying axes.
  cores = _enumerate_cores(bounds, ring_bounds, ring_sizes, host_bounds,
                           host_sizes)
  # Reorder offsets of each core back to the x, y, z, core order.
  core_locations = []
  for core in cores:
    core = [core[axes.index(i)] for i in range(4)]
    core_locations.append(_CoreLocation(core[0], core[1], core[2], core[3]))
  return core_locations


def _build_all_reduce_ring(core_locations: List[_CoreLocation],
                           rotate: bool = False) -> List[int]:
  """Reorders a list of TPU cores to optimize for AllReduce performance.

  This is ported from the C++ tensorflow::BuildAllReduceRing function,
  mixed with some logic from TF TPU's device_assignment._ring_3d.

  Args:
    core_locations: A list of core locations expressed as [x, y, z, core].
    rotate: If true, scan the cores in a column-major order. False by default.

  Returns:
    A permutation of the input list such that neighbors in the sequence are
    nearby in the TPU topology.
  """

  permutation = list(range(len(core_locations)))
  if not permutation:
    return permutation
  logging.vlog(2, "Core locations in: %s", core_locations)

  first_column = min([l.x for l in core_locations])
  first_row = min([l.y for l in core_locations])
  same_z = (len(set([l.z for l in core_locations])) == 1)
  logging.vlog(2, "first_column: %d", first_column)
  logging.vlog(2, "first_row: %d", first_row)
  logging.vlog(2, "same_z: %s", same_z)

  def _cmp_2d(ia: int, ib: int) -> int:
    if not rotate:
      a = core_locations[ia]
      b = core_locations[ib]

      # Order the first column last in the sequence, except for the first row.
      a_first = (a.x == first_column and a.y != first_row)
      b_first = (b.x == first_column and b.y != first_row)
      if a_first != b_first:
        return -1 if b_first else 1

      # Order rows in increasing order, unless in the first column.
      if a.y != b.y:
        return b.y - a.y if a_first else a.y - b.y

      # Order even rows left to right, odd rows right to left.
      if a.x != b.x:
        return a.x - b.x if a.y % 2 == 0 else b.x - a.x

      # Order cores in increasing order.
      return a.core - b.core
    else:
      a = core_locations[ia]
      b = core_locations[ib]

      # Order the first row last in the sequence, except for the first column.
      a_first = (a.y == first_row and a.x != first_column)
      b_first = (b.y == first_row and b.x != first_column)
      if a_first != b_first:
        return -1 if b_first else 1

      # Order columns in increasing order, unless in the first row.
      if a.x != b.x:
        return b.x - a.x if a_first else a.x - b.x

      # Order even columns top down, odd columns bottom up.
      if a.y != b.y:
        return a.y - b.y if a.x % 2 == 0 else b.y - a.y

      # Order cores in increasing order.
      return a.core - b.core

  def _cmp_3d(ia: int, ib: int) -> int:
    a = core_locations[ia]
    b = core_locations[ib]

    a_corner = (a.x == first_column and a.y == first_row)
    b_corner = (b.x == first_column and b.y == first_row)

    # If both are in the corner, order in reverse z then core order.
    if a_corner and b_corner:
      return b.z - a.z if a.z != b.z else a.core - b.core

    # Corner cores always go after non-corner cores.
    if a_corner != b_corner:
      return -1 if b_corner else 1

    # Both non-corner cores are on the same z-plane. Reverse odd z-planes.
    if a.z == b.z:
      return _cmp_2d(ia, ib) if a.z % 2 == 0 else -_cmp_2d(ia, ib)

    # Both non-corner cores are on different z-planes. Smaller z goes first.
    return a.z - b.z

  # If all cores are on the same z-plane, order as usual. Otherwise, order
  # neighbor z-planes in opposite orders. Stack all z-planes along the z axis
  # and connect them in one corner.
  if same_z:
    permutation.sort(key=functools.cmp_to_key(_cmp_2d))
  else:
    permutation.sort(key=functools.cmp_to_key(_cmp_3d))
  logging.vlog(2, "Permutation out: %s", permutation)
  return permutation

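# Illustrative example (not part of the original module): for a hypothetical
# 2x2 plane with one core per chip,
#
#   _build_all_reduce_ring([_CoreLocation(0, 0, 0, 0), _CoreLocation(0, 1, 0, 0),
#                           _CoreLocation(1, 0, 0, 0), _CoreLocation(1, 1, 0, 0)])
#
# returns [0, 2, 3, 1], i.e. the snake order (0,0) -> (1,0) -> (1,1) -> (0,1),
# whose last element is adjacent to the first, closing the ring.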
def _build_orthogonal_rings(
    core_locations: List[_CoreLocation], ring_size: int,
    rotate_ring_across_rings: bool) -> List[_CoreLocation]:
  """Builds two all-reduce rings orthogonal to each other.

  One ring includes every `ring_size` consecutive core locations. It is usually
  applied to the model-parallel dimension of a mesh to achieve the best 1D
  all-reduce performance. The other ring includes core locations separated by
  a stride of `ring_size`. It is usually applied to the data-parallel dimension
  of a mesh to get predictable strided all-reduce performance.

  Args:
    core_locations: A list of core locations expressed as [x, y, z, core].
    ring_size: The number of core locations in the consecutive ring.
    rotate_ring_across_rings: Build column-major secondary rings.

  Returns:
    A permutation of the input list forming the described rings.
  """
  # Build a ring for the first `ring_size` cores, and apply that permutation to
  # every group of `ring_size` cores.
  num_cores = len(core_locations)
  permutation = _build_all_reduce_ring(core_locations[:ring_size])
  for r in range(0, num_cores, ring_size):
    core_locations[r:r + ring_size] = [
        core_locations[r + permutation[i]] for i in range(ring_size)
    ]
  logging.vlog(1, "Permuted core locations: %s", core_locations)

  # Build a "ring" for the collection of devices consisting of the 0th device
  # from every group, and apply that permutation to every i-th device group.
  # This is achieved by transposing the list and transposing it back afterward.
  transposed = []
  for i in range(ring_size):
    transposed += [
        core_locations[g + i] for g in range(0, num_cores, ring_size)
    ]

  num_rings = int(num_cores / ring_size)
  permutation = _build_all_reduce_ring(
      transposed[:num_rings], rotate=rotate_ring_across_rings)
  for r in range(0, num_cores, num_rings):
    transposed[r:r + num_rings] = [
        transposed[r + permutation[i]] for i in range(num_rings)
    ]

  untransposed = []
  for i in range(num_rings):
    untransposed += [transposed[g + i] for g in range(0, num_cores, num_rings)]
  logging.vlog(1, "Stride-permuted core locations: %s", untransposed)

  return untransposed

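# Illustrative sketch (not part of the original module): with 8 cores c0..c7
# after the primary pass of `_build_orthogonal_rings` and ring_size=4, the
# transposition step produces [c0, c4, c1, c5, c2, c6, c3, c7]; the secondary
# (possibly rotated) ring is then built over [c0, c4] and that permutation is
# applied to each (ci, c(i+4)) pair before transposing back.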
def create_tpu_mesh(mesh_dim_names: List[str],
                    mesh_shape: List[int],
                    mesh_name: str,
                    ring_dims: Optional[int] = None,
                    ring_axes: Optional[List[str]] = None,
                    ring_bounds: Optional[List[int]] = None,
                    can_split_host_across_rings: bool = True,
                    build_ring_across_rings: bool = False,
                    rotate_ring_across_rings: bool = False) -> layout_lib.Mesh:
  """Returns a TPU mesh optimized for AllReduce ring reductions.

  Only as many leading axes specified by `ring_axes` as necessary will be used
  to build rings, as long as the subslice formed by these axes has enough cores
  to contain a ring of the required size. The leftover axes in `ring_axes`
  don't affect results.

  Args:
    mesh_dim_names: List of mesh dimension names.
    mesh_shape: Shape of the mesh.
    mesh_name: A unique name for the mesh. If empty, internally generate one.
    ring_dims: Optional; The number of leading (ring_dims > 0) or trailing
      (ring_dims < 0) mesh dimensions to build rings for. If unspecified, build
      rings for all but the first dimension.
    ring_axes: Optional; A permutation of ["x", "y", "z", "core"], specifying
      the order of TPU topology axes to build rings in. If unspecified, default
      to ["core", "x", "y", "z"].
    ring_bounds: Optional; The maximum number of devices on each axis, in the x,
      y, z, core order. If unspecified, default to physical topology limits.
    can_split_host_across_rings: Optional; If true, devices attached to the same
      host (i.e., DTensor client) may get assigned to different rings. Setting
      it to false may cause some combinations of arguments to be infeasible; see
      DeviceAssignmentTest.testCreateMesh[No]SplittingHosts* for examples.
    build_ring_across_rings: Optional; If true, also build a data-parallel ring
      across model-parallel rings. This ring could be strided.
    rotate_ring_across_rings: Optional; If true, build the data-parallel ring in
      column-major instead of row-major order.
  """

  logging.info("Building a TPU mesh %s of shape %s", mesh_name, mesh_shape)
  logging.info("Requested ring_dims: %s", ring_dims)
  logging.info("Requested ring_axes: %s", ring_axes)
  logging.info("Requested ring_bounds: %s", ring_bounds)
  logging.info("Requested can_split_host_across_rings: %s",
               can_split_host_across_rings)
  if not mesh_name:
    mesh_name = "mesh_%f" % time.time()
  logging.info("Requested mesh_name: %s", mesh_name)

  # By default, build rings for all but the first (usually batch) dimension.
  if ring_dims is None:
    ring_dims = 1 - len(mesh_shape)
  elif ring_dims < -len(mesh_shape) or ring_dims > len(mesh_shape):
    raise ValueError("Invalid ring_dims value: %d" % ring_dims)
  logging.info("Actual ring_dims: %s", ring_dims)

  # By default, vary axes in the core -> x -> y -> z order.
  if ring_axes is None:
    ring_axes = ["core", "x", "y", "z"]
  elif len(ring_axes) != 4:
    raise ValueError("Expected 4 elements in ring_axes, got %s" % ring_axes)
  elif sorted(ring_axes) != ["core", "x", "y", "z"]:
    raise ValueError("Invalid ring_axes value: %s" % ring_axes)
  logging.info("Actual ring_axes: %s", ring_axes)

  # Validate ring_bounds values.
  if _tpu_topology is None:
    raise ValueError(
        "Invalid TPU topology, run dtensor.initialize_tpu_system() first")
  topology_shape = list(_tpu_topology.mesh_shape)
  if ring_bounds is None:
    ring_bounds = topology_shape
  elif len(ring_bounds) != 4:
    raise ValueError("Expected 4 elements in ring_bounds, got %s" % ring_bounds)
  elif ring_bounds > topology_shape:
    raise ValueError("ring_bounds %s should be <= topology sizes %s" %
                     (ring_bounds, topology_shape))
  logging.info("Actual ring_bounds: %s", ring_bounds)

  # Compute ring_size, the number of cores in a ring.
  if ring_dims > 0:
    ring_size = np.prod(mesh_shape[:ring_dims])
  elif ring_dims < 0:
    ring_size = np.prod(mesh_shape[ring_dims:])
  else:
    ring_size = 1  # single-core rings
  logging.info("Actual ring_size: %d", ring_size)

  # Rearrange all cores according to the axis iteration order.
  global_core_locations = _enumerate_core_locations(
      topology_shape, ring_bounds, ring_axes, can_split_host_across_rings,
      ring_size)
  logging.vlog(1, "Enumerated core locations: %s", global_core_locations)
  num_cores = len(global_core_locations)

  # The mesh to be created must use all TPU cores in the system.
  mesh_size = np.prod(mesh_shape)
  if mesh_size != num_cores:
    raise ValueError(
        "Invalid mesh size: mesh shape %s cannot 1:1 map to %d TPU cores" %
        (mesh_shape, num_cores))

  # Build a ring for the `ring_size` dimension and, if required, a strided ring
  # for the orthogonal dimension.
  if build_ring_across_rings:
    global_core_locations = _build_orthogonal_rings(global_core_locations,
                                                    ring_size,
                                                    rotate_ring_across_rings)
  else:
    permutation = _build_all_reduce_ring(global_core_locations[:ring_size])
    for r in range(0, num_cores, ring_size):
      global_core_locations[r:r + ring_size] = [
          global_core_locations[r + permutation[i]] for i in range(ring_size)
      ]
    logging.vlog(1, "Permuted core locations: %s", global_core_locations)

  # From this point on, change from List[_CoreLocation] to List[List[int]] for
  # easier interaction with the C++ API.
  global_core_locations = [l.to_list() for l in global_core_locations]
  if _dtensor_device is None:
    raise ValueError(
        "Invalid system device, run dtensor.initialize_tpu_system() first")
  global_core_ids = _dtensor_device.tpu_core_locations_to_ids(
      global_core_locations)

  # Store a per-mesh mapping in the runtime.
  _dtensor_device.set_tpu_core_ids(mesh_name, global_core_ids)

  # Create the mesh by manually specifying local_device_ids.
  local_core_locations = _tpu_topology.device_coordinates[api.client_id()]
  indexes = [
      global_core_locations.index(list(local_core_location))
      for local_core_location in local_core_locations
  ]
  global_device_ids, local_device_ids, local_device_list = _create_device_array(
      mesh_shape, _TPU_DEVICE_TYPE, None, local_device_ids=indexes)
  return layout_lib.Mesh(mesh_dim_names, global_device_ids, local_device_ids,
                         local_device_list, mesh_name)

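# A minimal usage sketch for `create_tpu_mesh` (illustrative only; the
# dimension names and shape below are arbitrary examples that assume an
# already-initialized 8-core TPU slice):
#
#   mesh = create_tpu_mesh(mesh_dim_names=["batch", "model"],
#                          mesh_shape=[4, 2],
#                          mesh_name="example_mesh")
#
# The product of `mesh_shape` must equal the total number of TPU cores, and by
# default rings are built for every mesh dimension except the first one.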
def get_device_ids(mesh: layout_lib.Mesh,
                   client_id: Optional[int] = None) -> List[int]:
  """Returns the device IDs of all TPU cores local to the given client.

  A device ID is a non-negative integer that uniquely identifies a device in
  the mesh. For example, for a 2x2 mesh ('x', 'y'), this function returns a
  permutation of [0, 1, 2, 3].

  Note that device IDs and device locations are equivalent. The former is a
  linearization of the latter along mesh dimensions.

  Args:
    mesh: A TPU mesh.
    client_id: Optional; A DTensor client ID. If None, query this client.
  """

  if mesh.device_type() != _TPU_DEVICE_TYPE:
    raise ValueError("The mesh must be a TPU mesh")

  if client_id is None or client_id == api.client_id():
    return mesh.local_device_ids()

  # It's not clear we should ever allow a client to query other clients for
  # their device IDs.
  raise NotImplementedError(
      "Looking up other clients' device IDs is not supported")

def get_device_locations(
    mesh: layout_lib.Mesh,
    client_id: Optional[int] = None) -> List[Dict[str, int]]:
  """Returns the device locations of all TPU cores local to the given client.

  A device location is a dictionary from dimension names to indices on those
  dimensions. For example, for a 2x2 mesh ('x', 'y'), this function returns a
  permutation of this list:

    [{'x': 0, 'y': 0},
     {'x': 0, 'y': 1},
     {'x': 1, 'y': 0},
     {'x': 1, 'y': 1}].

  Note that device IDs and device locations are equivalent. The former is a
  linearization of the latter along mesh dimensions.

  Args:
    mesh: A TPU mesh.
    client_id: Optional; A DTensor client ID. If None, query this client.
  """

  if mesh.device_type() != _TPU_DEVICE_TYPE:
    raise ValueError("The mesh must be a TPU mesh")

  if client_id is None or client_id == api.client_id():
    return mesh.local_device_locations()

  # It's not clear we should ever allow a client to query other clients for
  # their device locations.
  raise NotImplementedError(
      "Looking up other clients' device locations is not supported")

810def _configure_tpu_runtime():
811  was_enabled = context.is_tfrt_enabled()
812  if ("tpu_use_tfrt" in flags.FLAGS and flags.FLAGS["tpu_use_tfrt"].value):
813    tfrt_utils.set_tfrt_enabled(True)
814  if not was_enabled:
815    context._reset_context()  # pylint:disable=protected-access
816