# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
"""Library of TPU helper functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu.topology import Topology
from tensorflow.python.util.tf_export import tf_export


SINGLE_CORE_ASSIGNMENT = [[[0, 0, 0]]]


def _compute_task_and_cores_to_replicas(core_assignment, topology):
  """Computes a nested dict which maps task and logical core to replicas."""
  task_and_cores_to_replicas = {}
  for replica in xrange(core_assignment.shape[0]):
    for logical_core in xrange(core_assignment.shape[1]):
      coordinates = core_assignment[replica, logical_core, :]
      task_id = topology.task_ordinal_at_coordinates(coordinates)
      if task_id not in task_and_cores_to_replicas:
        task_and_cores_to_replicas[task_id] = {}
      if logical_core not in task_and_cores_to_replicas[task_id]:
        task_and_cores_to_replicas[task_id][logical_core] = set()

      task_and_cores_to_replicas[task_id][logical_core].add(replica)

  task_to_sorted_replica_id = {}

  for task, core_to_replicas in task_and_cores_to_replicas.items():
    core_to_sorted_replicas = {}
    for core, replicas in core_to_replicas.items():
      core_to_sorted_replicas[core] = sorted(replicas)

    task_to_sorted_replica_id[task] = core_to_sorted_replicas
  return task_to_sorted_replica_id


@tf_export("tpu.experimental.DeviceAssignment")
class DeviceAssignment(object):
  """Mapping from logical cores in a computation to the physical TPU topology.

  Prefer to use the `DeviceAssignment.build()` helper to construct a
  `DeviceAssignment`; it is easier, if less flexible, than constructing a
  `DeviceAssignment` directly.
66  """
67
68  def __init__(self, topology, core_assignment):
69    """Constructs a `DeviceAssignment` object.
70
71    Args:
72      topology: A `Topology` object that describes the physical TPU topology.
73      core_assignment: A logical to physical core mapping, represented as a
74        rank 3 numpy array. See the description of the `core_assignment`
75        property for more details.
76
77    Raises:
      ValueError: If `topology` is not a `Topology` object.
      ValueError: If `core_assignment` is not a rank 3 numpy array.
    """
    if not isinstance(topology, Topology):
      raise ValueError("topology must be a Topology object, got {}".format(
          type(topology)))
    core_assignment = np.asarray(core_assignment, dtype=np.int32)

    self._topology = topology

    if core_assignment.ndim != 3:
      raise ValueError("core_assignment must be a rank 3 numpy array, "
                       "got shape {}".format(core_assignment.shape))

    self._num_replicas = core_assignment.shape[0]
    self._num_cores_per_replica = core_assignment.shape[1]

    if core_assignment.shape[-1] != topology.mesh_rank:
      raise ValueError(
          "minor dimension of core_assignment must have size equal to topology "
          "rank ({}), got shape {}".format(topology.mesh_rank,
                                           core_assignment.shape))

    self._core_assignment = core_assignment
    self._task_and_cores_to_replicas = _compute_task_and_cores_to_replicas(
        self._core_assignment, topology)

  @property
  def topology(self):
    """A `Topology` that describes the TPU topology."""
    return self._topology

  @property
  def num_cores_per_replica(self):
    """The number of cores per replica."""
    return self._num_cores_per_replica

  @property
  def num_replicas(self):
    """The number of replicas of the computation."""
    return self._num_replicas

  @property
  def core_assignment(self):
    """The logical to physical core mapping.

    Returns:
      An integer numpy array of rank 3, with shape
      `[num_replicas, num_cores_per_replica, topology_rank]`. Maps
      (replica, logical core) pairs to physical topology coordinates.
    """
    return self._core_assignment

  def coordinates(self, replica, logical_core):
    """Returns the physical topology coordinates of a logical core."""
    return tuple(self.core_assignment[replica, logical_core, :])

  def lookup_replicas(self, task_id, logical_core):
136    """Lookup replica ids by task number and logical core.

    Args:
      task_id: TensorFlow task number.
      logical_core: An integer, identifying a logical core.
    Returns:
      A sorted list of the replicas that are attached to that task and
      logical_core.
    Raises:
      ValueError: If no replica exists in the task which contains the logical
      core.
    """
    try:
      return self._task_and_cores_to_replicas[task_id][logical_core]
    except KeyError:
      raise ValueError(
          "Cannot find any replica in task {} that contains logical_core {}"
          .format(task_id, logical_core))

  def tpu_ordinal(self, replica=0, logical_core=0):
    """Returns the ordinal of the TPU device assigned to a logical core."""
    coordinates = self.coordinates(replica, logical_core)
    return self._topology.tpu_device_ordinal_at_coordinates(coordinates)

  def host_device(self, replica=0, logical_core=0, job=None):
    """Returns the CPU device attached to a logical core."""
    coordinates = self.coordinates(replica, logical_core)
    return self._topology.cpu_device_name_at_coordinates(coordinates, job=job)

  def tpu_device(self, replica=0, logical_core=0, job=None):
    """Returns the name of the TPU device assigned to a logical core."""
    coordinates = self.coordinates(replica, logical_core)
    return self._topology.tpu_device_name_at_coordinates(coordinates, job=job)

  @staticmethod
  def build(topology,
            computation_shape=None,
            computation_stride=None,
            num_replicas=1):
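    """Returns a `DeviceAssignment` computed by `device_assignment()`.

    This is a thin wrapper; see `device_assignment()` below for the meaning of
    the arguments.
    """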
    return device_assignment(topology, computation_shape, computation_stride,
                             num_replicas)

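# A minimal usage sketch, assuming a TPU system has already been initialized
# and `serialized_topology` (a hypothetical variable) holds the serialized
# `TopologyProto` returned by evaluating `initialize_system()`:
#
#   topology = Topology(serialized=serialized_topology)
#   da = DeviceAssignment.build(topology, num_replicas=2)
#   da.num_replicas            # -> 2
#   da.tpu_device(replica=0)   # -> name of replica 0's first TPU device
#   da.host_device(replica=0)  # -> name of the attached CPU host device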

def _ring_2d(height, width):
  """Ring-order of a height x width mesh.

  For example, in a 4x4 mesh, this returns the following order.
    0 -- 1 -- 2 -- 3
    |    |    |    |
    15-- 6 -- 5 -- 4
    |    |    |    |
    14-- 7 -- 8 -- 9
    |    |    |    |
    13-- 12-- 11-- 10

  Args:
    height: An integer, the height of the mesh.
    width: An integer, the width of the mesh.

  Returns:
    A list of (y, x) pairs in ring order.
  """
  if height == 1:
    return [(0, i) for i in range(width)]
  if width == 1:
    return [(i, 0) for i in range(height)]
  if height % 2 != 0:
    logging.warning("Odd height %d; falling back to a non-ring order.", height)
    return [(i % height, i // height) for i in range(width * height)]
  ret = [(0, 0)]
  for i in range(height // 2):
    for j in range(1, width):
      ret.append((2 * i, j))
    for j in range(width - 1, 0, -1):
      ret.append((2 * i + 1, j))
  for i in range(height - 1, 0, -1):
    ret.append((i, 0))
  return ret


def device_assignment(topology,
                      computation_shape=None,
                      computation_stride=None,
                      num_replicas=1):
220  """Computes a device_assignment of a computation across a TPU topology.
221
222  Attempts to choose a compact grid of cores for locality.
223
  Returns a `DeviceAssignment` that describes which physical cores in the
  topology are assigned to each logical core of each replica.

  `computation_shape` and `computation_stride` values should be powers of 2 for
  optimal packing.

  Args:
    topology: A `Topology` object that describes the TPU cluster topology.
      To obtain a TPU topology, evaluate the `Tensor` returned by
      `initialize_system` using `Session.run`. Either a serialized
      `TopologyProto` or a `Topology` object may be passed. Note: you must
      evaluate the `Tensor` first; you cannot pass an unevaluated `Tensor` here.
    computation_shape: A rank 1 int32 numpy array with size equal to the
      topology rank, describing the shape of the computation's block of cores.
      If None, the `computation_shape` is `[1] * topology_rank`.
    computation_stride: A rank 1 int32 numpy array of size `topology_rank`,
      describing the inter-core spacing of the `computation_shape` cores in the
      TPU topology. If None, the `computation_stride` is `[1] * topology_rank`.
    num_replicas: The number of computation replicas to run. The replicas will
      be packed into the free spaces of the topology.

  Returns:
    A DeviceAssignment object, which describes the mapping between the logical
    cores in each computation replica and the physical cores in the TPU
    topology.

  Raises:
    ValueError: If `topology` is not a valid `Topology` object.
    ValueError: If `computation_shape` or `computation_stride` are not rank 1
      int32 numpy arrays of size `topology_rank` with all values positive.
    ValueError: If the computation's replicas cannot fit into the TPU topology.
  """
  # Deserialize the Topology proto, if it is a string.
  if isinstance(topology, bytes):
    topology = Topology(serialized=topology)

  if not isinstance(topology, Topology):
    raise ValueError("`topology` is not a Topology object; got {}".format(
        type(topology)))

  topology_rank = len(topology.mesh_shape)
  mesh_shape = topology.mesh_shape
  if computation_shape is None:
    computation_shape = np.array([1] * topology_rank, dtype=np.int32)
  else:
    computation_shape = np.asarray(computation_shape, dtype=np.int32)

  if computation_stride is None:
    computation_stride = np.array([1] * topology_rank, dtype=np.int32)
  else:
    computation_stride = np.asarray(computation_stride, dtype=np.int32)

  if computation_shape.shape != (topology_rank,):
    raise ValueError("computation_shape must have shape [{}]; got {}".format(
        topology_rank, computation_shape.shape))
  if computation_stride.shape != (topology_rank,):
    raise ValueError("computation_stride must have shape [{}]; got {}".format(
        topology_rank, computation_stride.shape))

  if any(computation_shape < 1):
    raise ValueError(
        "computation_shape must be positive; got computation_shape={}".format(
            computation_shape))
  if any(computation_stride < 1):
    raise ValueError(
        "computation_stride must be positive; got computation_stride={}".format(
            computation_stride))

  # Computes the physical size of one computation instance.
  computation_footprint = computation_shape * computation_stride
  if any(computation_footprint > mesh_shape):
    raise ValueError(
        "computation footprint {} does not fit in TPU topology shape {}".format(
            computation_footprint, mesh_shape))

  # Computes how many copies of the computation footprint fit in the mesh.
  block_counts = mesh_shape // computation_footprint

  replica_counts = block_counts * computation_stride
  max_replicas = np.prod(replica_counts)
  if num_replicas > max_replicas:
    raise ValueError(
        "requested {} replicas but only {} replicas with shape {} and "
        "computation_stride {} fit in a TPU mesh of shape {}".format(
            num_replicas, max_replicas, computation_shape, computation_stride,
            mesh_shape))

  def ceil_of_ratio(n, m):
    return (n + m - 1) // m

  if topology.missing_devices.size == 0:
    replica_shape = [0] * topology_rank
    if num_replicas > 0:
      remaining_replicas = num_replicas
      remaining_dims = topology_rank

      # Choose dimensions as close to an equal cube as possible,
      # in order of increasing dimension size. By visiting dimensions
      # in increasing size, we assign the most constrained dimension
      # first, so we won't make infeasible choices.
      #
      # As a secondary sort order, visit the dimensions in reverse
      # order. This means we try to use both cores on the same chip
      # in preference to two cores on different chips.
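      #
      # For example, with replica_counts == [4, 4, 2] and num_replicas == 8,
      # the dimensions are visited as 2, 4, 4 and the loop below produces
      # replica_shape == [2, 2, 2].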

      for x, ni in sorted(((x, -i) for (i, x) in enumerate(replica_counts))):
        i = -ni
        target_size = int(math.ceil(remaining_replicas**(1.0 / remaining_dims)))
        replica_shape[i] = min(target_size, x)
        remaining_replicas = ceil_of_ratio(remaining_replicas, replica_shape[i])
        remaining_dims -= 1

      assert remaining_replicas == 1 and remaining_dims == 0

    # Assigns an offset to each replica such that no two replicas overlap.
    replica_offsets = np.full([num_replicas, topology_rank], -1, dtype=np.int32)

    # TODO(ylc): Revisit here when topology_rank > 3.
    enable_2d_tiling = (
        topology_rank == 3 and
        computation_shape[-1] == 2  # Only handle 2D case.
        and np.prod(computation_stride) == 1  # Ensure no stride.
        and num_replicas == max_replicas)  # Full replication.
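    # When 2D tiling is enabled, replicas are assigned in ring order over
    # blocks of the (x, y) chip grid; within each block, logical cores follow
    # a ring as well, using both cores (z = 0 and 1) of each chip.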
    logging.info("enable_2d_tiling: {}".format(enable_2d_tiling))
    if enable_2d_tiling:
      assignment = []
      inner_ring = _ring_2d(computation_shape[0], computation_shape[1])
      outer_ring = _ring_2d(replica_shape[0], replica_shape[1])

      for replica in xrange(num_replicas):
        outer_x, outer_y = outer_ring[replica]
        per_replica_assignment = []
        for index in xrange(np.prod(computation_shape)):
          inner_x, inner_y = inner_ring[index // 2]
          px = outer_x * computation_shape[0] + inner_x
          py = outer_y * computation_shape[1] + inner_y
          pz = index % 2
          per_replica_assignment.append([px, py, pz])
        assignment.append(per_replica_assignment)
    else:
      for replica in xrange(num_replicas):
        # Chooses a replica number in each axis.
        t = replica
        pos = []
        for dim in replica_shape[::-1]:
          pos.append(t % dim)
          t //= dim
        replica_pos = np.array(pos[::-1], dtype=np.int32)

        # Determines where that replica starts in each axis.
        outer = replica_pos // computation_stride
        inner = replica_pos % computation_stride
        replica_offsets[replica, :] = outer * computation_footprint + inner

      # Computes a logical core -> physical core mapping for each replica.
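      # `indices` enumerates the logical-core coordinates within a single
      # computation block; adding `replica_offsets` then translates that block
      # to each replica's position in the mesh.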
      indices = [
          np.arange(0, computation_shape[i] * computation_stride[i],
                    computation_stride[i]) for i in xrange(topology_rank)
      ]
      indices = np.concatenate(
          [i[..., np.newaxis] for i in np.meshgrid(*indices, indexing="ij")],
          axis=-1)
      indices = indices.reshape((-1, topology_rank))
      assignment = indices + replica_offsets[:, np.newaxis, :]
  else:
    # We have a slice with missing chips, so define a simple assignment that
    # ignores the computation stride. This yields a consistent and correct
    # device assignment on degraded slices. It is optimal when weights are not
    # sharded, but may be sub-optimal for other model-parallelism scenarios.
    assert np.prod(computation_stride) == 1
    # Next, we check if we have sufficient devices.
    assert num_replicas * np.prod(
        computation_shape) <= topology.num_tasks * topology.num_tpus_per_task
    # Map replicas to physical devices in task order.
    device_coordinates = topology.device_coordinates
    assignment = []
    devices_per_replica = np.prod(computation_shape)
    for rindex in xrange(num_replicas):
      replica_assignment = []
      for index in xrange(devices_per_replica):
        logical_id = rindex * devices_per_replica + index
        # Pick logical cores in task order
        task = logical_id // topology.num_tpus_per_task
        device = logical_id % topology.num_tpus_per_task
        # Append physical cores to the replica assignment
        replica_assignment.append(device_coordinates[task, device, :])
      assignment.append(replica_assignment)

  return DeviceAssignment(topology, core_assignment=assignment)