# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
"""Library of TPU helper functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import enum
import math
from typing import List, Optional, Text, Tuple

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu.topology import Topology
from tensorflow.python.util.tf_export import tf_export


SINGLE_CORE_ASSIGNMENT = [[[0, 0, 0, 0]]]


def _compute_task_and_cores_to_replicas(core_assignment, topology):
  """Computes a nested dict which maps task and logical core to replicas."""
  task_and_cores_to_replicas = {}
  for replica in xrange(core_assignment.shape[0]):
    for logical_core in xrange(core_assignment.shape[1]):
      coordinates = core_assignment[replica, logical_core, :]
      task_id = topology.task_ordinal_at_coordinates(coordinates)
      if task_id not in task_and_cores_to_replicas:
        task_and_cores_to_replicas[task_id] = {}
      if logical_core not in task_and_cores_to_replicas[task_id]:
        task_and_cores_to_replicas[task_id][logical_core] = set()

      task_and_cores_to_replicas[task_id][logical_core].add(replica)

  task_to_sorted_replica_id = {}

  for task, core_to_replicas in task_and_cores_to_replicas.items():
    core_to_sorted_replicas = {}
    for core, replicas in core_to_replicas.items():
      core_to_sorted_replicas[core] = sorted(replicas)

    task_to_sorted_replica_id[task] = core_to_sorted_replicas
  return task_to_sorted_replica_id

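# A minimal illustration of the helper above (a sketch, not part of the
# library API), assuming a hypothetical `topo` object whose
# `task_ordinal_at_coordinates(...)` returns 0 for every coordinate:
#
#   core_assignment = np.asarray(SINGLE_CORE_ASSIGNMENT)  # shape [1, 1, 4]
#   _compute_task_and_cores_to_replicas(core_assignment, topo)
#   # => {0: {0: [0]}}   # task 0, logical core 0 is used only by replica 0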

@tf_export("tpu.experimental.DeviceAssignment")
class DeviceAssignment(object):
  """Mapping from logical cores in a computation to the physical TPU topology.

  Prefer to use the `DeviceAssignment.build()` helper to construct a
  `DeviceAssignment`; it is easier, if less flexible, than constructing a
  `DeviceAssignment` directly.
  """

  def __init__(self, topology: Topology, core_assignment: np.ndarray):
    """Constructs a `DeviceAssignment` object.

    Args:
      topology: A `Topology` object that describes the physical TPU topology.
      core_assignment: A logical to physical core mapping, represented as a
        rank 3 numpy array. See the description of the `core_assignment`
        property for more details.

    Raises:
      ValueError: If `topology` is not a `Topology` object.
      ValueError: If `core_assignment` is not a rank 3 numpy array.
82    """
83    if not isinstance(topology, Topology):
84      raise ValueError("topology must be a Topology object, got {}".format(
85          type(topology)))
86    core_assignment = np.asarray(core_assignment, dtype=np.int32)
87
88    self._topology = topology
89
90    if core_assignment.ndim != 3:
91      raise ValueError("core_assignment must be a rank 3 numpy array, "
92                       "got shape {}".format(core_assignment.shape))
93
94    self._num_replicas = core_assignment.shape[0]
95    self._num_cores_per_replica = core_assignment.shape[1]
96
97    if core_assignment.shape[-1] != topology.mesh_rank:
98      raise ValueError(
99          "minor dimension of core_assignment must have size equal to topology "
100          "rank ({}), got shape {}".format(topology.mesh_rank,
101                                           core_assignment.shape))
102
103    self._core_assignment = core_assignment
104    self._task_and_cores_to_replicas = _compute_task_and_cores_to_replicas(
105        self._core_assignment, topology)
106
107  @property
108  def topology(self) -> Topology:
109    """A `Topology` that describes the TPU topology."""
110    return self._topology
111
112  @property
113  def num_cores_per_replica(self) -> int:
114    """The number of cores per replica."""
115    return self._num_cores_per_replica
116
117  @property
118  def num_replicas(self) -> int:
119    """The number of replicas of the computation."""
120    return self._num_replicas
121
122  @property
123  def core_assignment(self) -> np.ndarray:
124    """The logical to physical core mapping.
125
126    Returns:
127      An integer numpy array of rank 3, with shape
128      `[num_replicas, num_cores_per_replica, topology_rank]`. Maps
129      (replica, logical core) pairs to physical topology coordinates.
130    """
131    return self._core_assignment
132
133  def coordinates(self, replica: int, logical_core: int) -> Tuple:  # pylint:disable=g-bare-generic
134    """Returns the physical topology coordinates of a logical core."""
135    return tuple(self.core_assignment[replica, logical_core, :])
136
  def lookup_replicas(self, task_id: int, logical_core: int) -> List[int]:
    """Looks up replica ids by task number and logical core.

    Args:
      task_id: TensorFlow task number.
      logical_core: An integer, identifying a logical core.
    Returns:
      A sorted list of the replicas that are attached to that task and
      logical_core.
    Raises:
      ValueError: If no replica exists in the task which contains the logical
        core.
    """
    try:
      return self._task_and_cores_to_replicas[task_id][logical_core]
    except KeyError:
      raise ValueError(
          "Cannot find any replica in task {} that contains logical_core {}".
          format(task_id, logical_core))

  def tpu_ordinal(self, replica: int = 0, logical_core: int = 0) -> int:
    """Returns the ordinal of the TPU device assigned to a logical core."""
    coordinates = self.coordinates(replica, logical_core)
    return self._topology.tpu_device_ordinal_at_coordinates(coordinates)

  def host_device(self,
                  replica: int = 0,
                  logical_core: int = 0,
                  job: Optional[Text] = None) -> Text:
    """Returns the CPU device attached to a logical core."""
    coordinates = self.coordinates(replica, logical_core)
    return self._topology.cpu_device_name_at_coordinates(coordinates, job=job)

  def tpu_device(self,
                 replica: int = 0,
                 logical_core: int = 0,
                 job: Optional[Text] = None) -> Text:
    """Returns the name of the TPU device assigned to a logical core."""
    coordinates = self.coordinates(replica, logical_core)
    return self._topology.tpu_device_name_at_coordinates(coordinates, job=job)

  @staticmethod
  def build(topology: Topology,
            computation_shape: Optional[np.ndarray] = None,
            computation_stride: Optional[np.ndarray] = None,
            num_replicas: int = 1) -> "DeviceAssignment":
    return device_assignment(topology, computation_shape, computation_stride,
                             num_replicas)

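# Usage sketch (illustrative only, not executed here). Assumes a TPU system
# has already been initialized and `topology` is the `Topology` returned by
# `tf.tpu.experimental.initialize_tpu_system(resolver)`:
#
#   da = DeviceAssignment.build(
#       topology, computation_shape=[1, 1, 1, 2], num_replicas=2)
#   da.num_replicas              # 2
#   da.num_cores_per_replica     # 2
#   da.tpu_device(replica=0, logical_core=0)
#   # e.g. "/job:worker/replica:0/task:0/device:TPU:0" (depends on job name)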

def _open_ring_2d(x_size: int, y_size: int,
                  z_coord: int) -> List[Tuple[int, int, int]]:
189  """Ring-order of a X by Y mesh, with a fixed Z coordinate.
190
191  For example, in a 4x4 mesh, this returns the following order.
192    0 -- 1 -- 2 -- 3
193    |    |    |    |
194    15-- 6 -- 5 -- 4
195    |    |    |    |
196    14-- 7 -- 8 -- 9
197    |    |    |    |
198    13-- 12-- 11-- 10
199
200  Note that chip 0 is not included in the output.
201
202  Args:
203    x_size: An integer represents the mesh size in the x-dimension. Must be
204      larger than 1.
205    y_size: An integer represents the mesh size in the y-dimension. Must be
206      larger than 1.
207    z_coord: An integer represents the z-coordinate to use for the chips in the
208      ring.
209
210  Returns:
211    A list of (x,y,z) triples in ring order.
212  """
  ret = []
  for i in range(y_size // 2):
    for j in range(1, x_size):
      ret.append((j, 2 * i, z_coord))
    for j in range(x_size - 1, 0, -1):
      ret.append((j, 2 * i + 1, z_coord))
  for i in range(y_size - 1, 0, -1):
    ret.append((0, i, z_coord))
  return ret

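# Illustrative check of the ring order above (a sketch, not part of the
# library). For the 4x4 mesh in the docstring:
#
#   _open_ring_2d(4, 4, 0)
#   # [(1, 0, 0), (2, 0, 0), (3, 0, 0),    # chips 1..3 in the diagram
#   #  (3, 1, 0), (2, 1, 0), (1, 1, 0),    # chips 4..6
#   #  (1, 2, 0), (2, 2, 0), (3, 2, 0),    # chips 7..9
#   #  (3, 3, 0), (2, 3, 0), (1, 3, 0),    # chips 10..12
#   #  (0, 3, 0), (0, 2, 0), (0, 1, 0)]    # chips 13..15, closing the ring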

def _ring_3d(x_size: int, y_size: int,
             z_size: int) -> List[Tuple[int, int, int]]:
226  """Ring-order of a X by Y by Z mesh.

  Constructs the 3d ring from 2d rings that are stacked in the Z dimension and
  joined in one corner.

  z == 0:
    0 -- 1 -- 2 -- 3
    |    |    |    |
    15 - 6 -- 5 -- 4
    |    |    |    |
    14 - 7 -- 8 -- 9
    |    |    |    |
    13 - 12 - 11 - 10
  z == 1:
    63 - 30 - 29 - 28
    |    |    |    |
    16 - 25 - 26 - 27
    |    |    |    |
    17 - 24 - 23 - 22
    |    |    |    |
    18 - 19 - 20 - 21
  z == 2:
    62 - 31 - 32 - 33
    |    |    |    |
    45 - 36 - 35 - 34
    |    |    |    |
    44 - 37 - 38 - 39
    |    |    |    |
    43 - 42 - 41 - 40
  z == 3:
    61 - 60 - 59 - 58
    |    |    |    |
    46 - 55 - 56 - 57
    |    |    |    |
    47 - 54 - 53 - 52
    |    |    |    |
    48 - 49 - 50 - 51

  Args:
    x_size: An integer representing the mesh size in the x-dimension. Must be
      larger than 1.
    y_size: An integer representing the mesh size in the y-dimension. Must be
      larger than 1.
    z_size: An integer representing the mesh size in the z-dimension. Must be
      larger than 1. The diagrams above show the resulting order for a 4x4x4
      mesh.

  Returns:
    A list of (x,y,z) triples in ring order.
  """

  # Handle the case where 2 dimensions are size 1.
  if x_size == 1 and y_size == 1:
    return [(0, 0, i) for i in range(z_size)]
  if x_size == 1 and z_size == 1:
    return [(0, i, 0) for i in range(y_size)]
  if y_size == 1 and z_size == 1:
    return [(i, 0, 0) for i in range(x_size)]

  # Handle odd mesh dimensions.  This never happens in practice, so we don't
  # bother to try building something optimal.
  if (x_size > 1 and x_size % 2 != 0) or (y_size > 1 and
                                          y_size % 2 != 0) or (z_size > 1 and
                                                               z_size % 2 != 0):
    logging.warning("Odd dimension")
    ret = []
    for z in range(z_size):
      for y in range(y_size):
        ret.extend((x, y, z) for x in range(x_size))
    return ret

  # Always start with chip 0.
  ret = [(0, 0, 0)]
  # Handle the case where one dimension is size 1.  We just build a flat, 2d
  # ring.
  if z_size == 1:
    ret.extend(_open_ring_2d(x_size, y_size, 0))
    return ret
  if y_size == 1:
    ret = [(0, 0, 0)]
    ret.extend((x, y, z) for (x, z, y) in _open_ring_2d(x_size, z_size, 0))
    return ret
  if x_size == 1:
    ret = [(0, 0, 0)]
    ret.extend((x, y, z) for (y, z, x) in _open_ring_2d(y_size, z_size, 0))
    return ret

  # Handle the case where all dimensions have size > 1 and even.
  ret = [(0, 0, 0)]
  for i in range(0, z_size):
    r = _open_ring_2d(x_size, y_size, i)
    if i % 2 == 0:
      ret.extend(r)
    else:
      ret.extend(reversed(r))
  for i in range(z_size - 1, 0, -1):
    ret.append((0, 0, i))
  return ret

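# Illustrative check (a sketch, not part of the library): a 2x2x2 mesh yields
# a single closed ring that ends adjacent to chip (0, 0, 0):
#
#   _ring_3d(2, 2, 2)
#   # [(0, 0, 0), (1, 0, 0), (1, 1, 0), (0, 1, 0),
#   #  (0, 1, 1), (1, 1, 1), (1, 0, 1), (0, 0, 1)]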

class DeviceOrderMode(enum.IntEnum):
  """The way of determining device orders when computing device assignment."""
  # By default the mode is set to AUTO; the library will choose to form rings
  # when that is possible.
  AUTO = 0
  # Form rings for replicas and model-parallel cores.
  RING = 1
  # Form meshes for replicas and/or model-parallel cores.
  MESH = 2

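# Usage sketch (illustrative only): request ring ordering explicitly. Assumes
# a rank-4 `topology` (e.g. mesh shape [2, 2, 1, 2]) where the computation
# uses both cores of each chip and fully tiles the mesh; otherwise
# `device_assignment` raises a ValueError for `DeviceOrderMode.RING`:
#
#   da = device_assignment(topology,
#                          computation_shape=[1, 1, 1, 2],
#                          num_replicas=4,
#                          device_order_mode=DeviceOrderMode.RING)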

def device_assignment(
    topology: Topology,
    computation_shape: Optional[np.ndarray] = None,
    computation_stride: Optional[np.ndarray] = None,
    num_replicas: int = 1,
    device_order_mode: DeviceOrderMode = DeviceOrderMode.AUTO
) -> DeviceAssignment:
  """Computes a device_assignment of a computation across a TPU topology.

  Attempts to choose a compact grid of cores for locality.

  Returns a `DeviceAssignment` that describes the cores in the topology assigned
  to each core of each replica.

  `computation_shape` and `computation_stride` values should be powers of 2 for
  optimal packing.

  Args:
    topology: A `Topology` object that describes the TPU cluster topology. To
      obtain a TPU topology, evaluate the `Tensor` returned by
      `initialize_system` using `Session.run`. Either a serialized
      `TopologyProto` or a `Topology` object may be passed. Note: you must
        evaluate the `Tensor` first; you cannot pass an unevaluated `Tensor`
        here.
    computation_shape: A rank 1 int32 numpy array with size equal to the
      topology rank, describing the shape of the computation's block of cores.
      If None, the `computation_shape` is `[1] * topology_rank`.
    computation_stride: A rank 1 int32 numpy array of size `topology_rank`,
      describing the inter-core spacing of the `computation_shape` cores in the
      TPU topology. If None, the `computation_stride` is `[1] * topology_rank`.
    num_replicas: The number of computation replicas to run. The replicas will
      be packed into the free spaces of the topology.
    device_order_mode: An enum of the `DeviceOrderMode` class which indicates
      whether to assign devices so that they form rings or meshes, or to let
      the library choose.

  Returns:
    A DeviceAssignment object, which describes the mapping between the logical
    cores in each computation replica and the physical cores in the TPU
    topology.

  Raises:
    ValueError: If `topology` is not a valid `Topology` object.
    ValueError: If `computation_shape` or `computation_stride` are not rank 1
      int32 numpy arrays of size `topology_rank` where all values are positive.
    ValueError: If the computation's replicas cannot fit into the TPU topology.
383  """
384  # Deserialize the Topology proto, if it is a string.
385  if isinstance(topology, bytes):
386    topology = Topology(serialized=topology)
387
388  if not isinstance(topology, Topology):
389    raise ValueError("`topology` is not a Topology object; got {}".format(
390        type(topology)))
391
392  topology_rank = len(topology.mesh_shape)
393  mesh_shape = topology.mesh_shape
394  if computation_shape is None:
395    computation_shape = np.array([1] * topology_rank, dtype=np.int32)
396  else:
397    computation_shape = np.asarray(computation_shape, dtype=np.int32)
398
399  if computation_stride is None:
400    computation_stride = np.array([1] * topology_rank, dtype=np.int32)
401  else:
402    computation_stride = np.asarray(computation_stride, dtype=np.int32)
403
404  if computation_shape.shape != (topology_rank,):
405    raise ValueError("computation_shape must have shape [{}]; got {}".format(
406        topology_rank, computation_shape.shape))
407  if computation_stride.shape != (topology_rank,):
408    raise ValueError("computation_stride must have shape [{}]; got {}".format(
409        topology_rank, computation_stride.shape))
410
411  if any(computation_shape < 1):
412    raise ValueError(
413        "computation_shape must be positive; got computation_shape={}".format(
414            computation_shape))
415  if any(computation_stride < 1):
416    raise ValueError(
417        "computation_stride must be positive; got computation_stride={}".format(
418            computation_stride))
419
420  # Computes the physical size of one computation instance.
421  computation_footprint = computation_shape * computation_stride
422  if any(computation_footprint > mesh_shape):
423    raise ValueError(
424        "computation footprint {} does not fit in TPU topology shape {}".format(
425            computation_footprint, mesh_shape))
426
427  # Computes how many copies of the computation footprint fit in the mesh.
428  block_counts = mesh_shape // computation_footprint
429
430  replica_counts = block_counts * computation_stride
431  max_replicas = np.prod(replica_counts)
432  if num_replicas > max_replicas:
433    raise ValueError(
434        "requested {} replicas but only {} replicas with shape {} and "
435        "computation_stride {} fit in a TPU mesh of shape {}".format(
436            num_replicas, max_replicas, computation_shape, computation_stride,
437            mesh_shape))
438
439  def ceil_of_ratio(n, m):
440    return (n + m - 1) // m
441
442  if topology.missing_devices.size == 0:
443    replica_shape = [0] * topology_rank
444    if num_replicas > 0:
445      remaining_replicas = num_replicas
446      remaining_dims = topology_rank
447
448      # Choose dimensions as close to an equal cube as possible,
449      # in order of increasing dimension size. By visiting dimensions
450      # in increasing size, we assign the most constrained dimension
451      # first, so we won't make infeasible choices.
452      #
453      # As a secondary sort order, visit the last dimension (core index) first,
454      # then the other dimensions in increasing order. This means we try to use
455      # both cores on the same chip in preference to two cores on different
456      # chips.  We visit the x dimension first, and the z dimension last, so
457      # that we prefer to arrange adjacent replicas on the same machine when
458      # possible.
459      #
460      # For example, if num_replicas == 4, we prefer to use a replica_shape of
461      # (2,1,1,2) over (1,1,2,2).
462
463      for x, ni in sorted(((x, ((i + 1) % topology_rank))
464                           for (i, x) in enumerate(replica_counts))):
465        i = (ni + topology_rank - 1) % topology_rank
466        target_size = int(math.ceil(remaining_replicas**(1.0 / remaining_dims)))
467        replica_shape[i] = min(target_size, x)
468        remaining_replicas = ceil_of_ratio(remaining_replicas, replica_shape[i])
469        remaining_dims -= 1
470
471      assert remaining_replicas == 1 and remaining_dims == 0
472
473    # Assigns an offset to each replica such that no two replicas overlap.
474    replica_offsets = np.full([num_replicas, topology_rank], -1, dtype=np.int32)
475
476    enable_3d_tiling = (
477        topology_rank == 4 and
478        computation_shape[-1] == 2  # Only handle 3D case.
479        and np.prod(computation_stride) == 1  # Ensure no stride.
480        and num_replicas == max_replicas)  # Full replication.
481
482    if device_order_mode != DeviceOrderMode.AUTO:
483      if device_order_mode == DeviceOrderMode.RING and not enable_3d_tiling:
484        raise ValueError("cannot assign ring order in the given topology")
485      enable_3d_tiling = device_order_mode == DeviceOrderMode.RING
486
487    if enable_3d_tiling:
488      assignment = []
489      inner_ring = _ring_3d(computation_shape[0], computation_shape[1],
490                            computation_shape[2])
491      outer_ring = _ring_3d(replica_shape[0], replica_shape[1],
492                            replica_shape[2])
493
494      for replica in xrange(num_replicas):
495        outer_x, outer_y, outer_z = outer_ring[replica]
496        per_replica_assignment = []
497        for index in xrange(np.prod(computation_shape)):
498          inner_x, inner_y, inner_z = inner_ring[index // 2]
499          px = outer_x * computation_shape[0] + inner_x
500          py = outer_y * computation_shape[1] + inner_y
501          pz = outer_z * computation_shape[2] + inner_z
502          pi = index % 2
503          per_replica_assignment.append([px, py, pz, pi])
504        assignment.append(per_replica_assignment)
505    else:
506      for replica in xrange(num_replicas):
507        # Chooses a replica number in each axis.
508        t = replica
509        pos = []
510        # Visit the core number first.
511        for dim in np.concatenate([[replica_shape[-1]], replica_shape[:-1]]):
512          pos.append(t % dim)
513          t //= dim
514        replica_pos = np.concatenate([pos[1:], [pos[0]]])
515
516        # Determines where that replica starts in each axis.
517        outer = replica_pos // computation_stride
518        inner = replica_pos % computation_stride
519        replica_offsets[replica, :] = outer * computation_footprint + inner
520
521      # Computes a logical core -> physical core mapping for each replica.
522      indices = [
523          np.arange(0, computation_shape[i] * computation_stride[i],
524                    computation_stride[i]) for i in range(topology_rank)
525      ]
526      indices = np.concatenate(
527          [i[..., np.newaxis] for i in np.meshgrid(*indices, indexing="ij")],
528          axis=-1)
529      indices = indices.reshape((-1, topology_rank))
530      assignment = indices + replica_offsets[:, np.newaxis, :]
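      # Broadcasting `indices` (shape [num_cores_per_replica, topology_rank])
      # against `replica_offsets[:, np.newaxis, :]` (shape
      # [num_replicas, 1, topology_rank]) gives `assignment` the shape
      # [num_replicas, num_cores_per_replica, topology_rank].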
  else:
    # We have a slice with missing chips. We define a simple assignment by
    # ignoring computation stride. This assignment should enable a consistent
    # and correct device assignment on degraded slices. It is optimal when
    # weights are not sharded. But this device assignment may be sub-optimal for
    # other model parallelism scenarios.
    assert np.prod(computation_stride) == 1
    # Next, we check if we have sufficient devices.
    assert num_replicas * np.prod(
        computation_shape) <= topology.num_tasks * topology.num_tpus_per_task
    # Map replicas to physical devices in task order.
    device_coordinates = topology.device_coordinates
    assignment = []
    devices_per_replica = np.prod(computation_shape)
    for rindex in xrange(num_replicas):
      replica_assignment = []
      for index in xrange(devices_per_replica):
        logical_id = rindex * devices_per_replica + index
        # Pick logical cores in task order
        task = logical_id // topology.num_tpus_per_task
        device = logical_id % topology.num_tpus_per_task
        # Append physical cores to the replica assignment
        replica_assignment.append(device_coordinates[task, device, :])
      assignment.append(replica_assignment)

  return DeviceAssignment(topology, core_assignment=assignment)
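# Example sketch (illustrative only; assumes `topology` describes a TPU slice
# with mesh shape [2, 2, 1, 2] and no missing devices):
#
#   da = device_assignment(topology,
#                          computation_shape=[1, 1, 1, 2],
#                          num_replicas=4)
#   da.core_assignment.shape   # (4, 2, 4): replicas x cores/replica x mesh rank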