• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ======================================
15"""Defines the `Topology` class, that describes a TPU fabric topology."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22from six.moves import xrange  # pylint: disable=redefined-builtin
23
24from tensorflow.core.protobuf.tpu import topology_pb2
25
26
27def _tpu_device_name(job, task, device):
28  """Returns the device name for the TPU `device` on `task` of `job`."""
29  if job is None:
30    return "/task:%d/device:TPU:%d" % (task, device)
31  else:
32    return "/job:%s/task:%d/device:TPU:%d" % (job, task, device)
33
34
35def _tpu_host_device_name(job, task):
36  """Returns the device name for the CPU device on `task` of `job`."""
37  if job is None:
38    return "/task:%d/device:CPU:0" % task
39  else:
40    return "/job:%s/task:%d/device:CPU:0" % (job, task)
41
42
class Topology(object):
  """Describes a set of TPU devices.

  Represents both the shape of the physical mesh, and the mapping between
  TensorFlow TPU devices to physical mesh coordinates.
  """

  def __init__(self, serialized=None, mesh_shape=None, device_coordinates=None):
    """Builds a Topology object.

    If `serialized` is given (and non-empty), the topology is parsed from
    `serialized` and the other arguments are ignored. Otherwise, the topology
    is computed from `mesh_shape` and `device_coordinates`.

    Args:
      serialized: A serialized `TopologyProto`, or `None`. If given, the
        serialized proto is parsed to discover the topology. An empty value
        is treated like `None`.
      mesh_shape: A sequence of 3 positive integers, or `None`. If not `None`,
        the shape of the TPU topology, in number of cores. Ignored if
        `serialized` is given.
      device_coordinates: A rank 3 numpy array that describes the mapping from
        TensorFlow TPU devices to TPU fabric coordinates, or `None`. Ignored
        if `serialized` is given.

    Raises:
      ValueError: If `serialized` does not describe a well-formed topology.
      ValueError: If `serialized` is not given and `mesh_shape` is not a
        sequence of 3 positive integers.
      ValueError: If `serialized` is not given and `device_coordinates` is not
        a rank 3 numpy int32 array that describes a valid coordinate mapping.
      ValueError: If any device coordinate lies outside `mesh_shape`.
    """
    # Only cache a non-empty serialization; an empty one must not be handed
    # back by `serialized()` as if it described this topology.
    self._serialized = serialized or None

    # Truthiness (rather than `is not None`) is deliberate: an empty
    # `serialized` falls through to the explicit arguments.
    if serialized:
      self._parse_topology(serialized)
    else:
      # Reject missing arguments up front so the caller gets the documented
      # ValueError instead of a TypeError from the numpy conversion.
      if mesh_shape is None:
        raise ValueError("`mesh_shape` must be a sequence of 3 positive "
                         "entries; got None")
      if device_coordinates is None:
        raise ValueError("`device_coordinates` must be a rank 3 int32 array "
                         "with minor dimension equal to the mesh shape rank")

      self._mesh_shape = np.asarray(mesh_shape, dtype=np.int32)
      self._device_coordinates = np.asarray(device_coordinates, np.int32)

      # The `ndim` test keeps scalars and nested sequences from reaching
      # `len()` or the element-wise comparison with a confusing error.
      if (self._mesh_shape.ndim != 1 or len(self._mesh_shape) != 3 or
          np.any(self._mesh_shape < 1)):
        raise ValueError("`mesh_shape` must be a sequence of 3 positive "
                         "entries; got {}".format(self._mesh_shape))

      if (len(self._device_coordinates.shape) != 3 or
          self._device_coordinates.shape[2] != len(self._mesh_shape)):
        raise ValueError("`device_coordinates` must be a rank 3 int32 array "
                         "with minor dimension equal to the mesh shape rank")

    self._topology_tasks, self._topology_devices = self._invert_topology()

    # Coordinates of devices that are missing, i.e. mesh positions to which
    # no task/device pair was mapped.
    self._missing_devices = np.argwhere(self._topology_tasks < 0)

  def _parse_topology(self, serialized):
    """Parses a serialized `TopologyProto` into `self`.

    Sets `self._mesh_shape` and `self._device_coordinates`.

    Args:
      serialized: A serialized `TopologyProto`.

    Raises:
      ValueError: If the proto's mesh shape, task/device counts, or device
        coordinates are inconsistent or out of range.
    """
    proto = topology_pb2.TopologyProto()
    proto.ParseFromString(serialized)

    self._mesh_shape = np.array(proto.mesh_shape, dtype=np.int32)
    if len(self._mesh_shape) != 3 or np.any(self._mesh_shape < 1):
      raise ValueError("`mesh_shape` must be a vector of size 3 with positive "
                       "entries; got {}".format(self._mesh_shape))

    if proto.num_tasks < 0:
      raise ValueError("`num_tasks` must be >= 0; got {}".format(
          proto.num_tasks))
    if proto.num_tpu_devices_per_task < 0:
      raise ValueError("`num_tpu_devices_per_task` must be >= 0; got {}".format(
          proto.num_tpu_devices_per_task))

    mesh_rank = len(proto.mesh_shape)
    expected_coordinates_size = (
        proto.num_tasks * proto.num_tpu_devices_per_task * mesh_rank)
    if len(proto.device_coordinates) != expected_coordinates_size:
      # The third value is the mesh rank, matching the `len(mesh_shape)`
      # label in the message.
      raise ValueError("`device_coordinates` must have shape num_tasks ({}) * "
                       "num_tpu_devices_per_task ({}) * len(mesh_shape) ({}); "
                       "got shape {}".format(proto.num_tasks,
                                             proto.num_tpu_devices_per_task,
                                             mesh_rank,
                                             len(proto.device_coordinates)))

    coords = np.array(proto.device_coordinates, dtype=np.int32)
    if np.any(coords < 0):
      raise ValueError("`device_coordinates` must be >= 0")
    # The proto stores the coordinates flat; restore [task, device, axis].
    coords = coords.reshape((proto.num_tasks, proto.num_tpu_devices_per_task,
                             mesh_rank))
    self._device_coordinates = coords

  def _invert_topology(self):
    """Inverts a [task,device,axis] topology to [x,y,z] -> task/device maps.

    Returns:
      A pair `(tasks, devices)` of int32 arrays shaped like the mesh. Each
      entry holds the task number / device ordinal mapped to that mesh
      position, or -1 if nothing is mapped there.

    Raises:
      ValueError: If any coordinate is negative or not smaller than the
        corresponding `mesh_shape` entry.
    """
    coords = self._device_coordinates
    # Validate the range explicitly: a negative index would otherwise wrap
    # around silently and a too-large one would surface as an IndexError.
    if np.any(coords < 0) or np.any(coords >= self._mesh_shape):
      raise ValueError("`device_coordinates` entries must satisfy "
                       "0 <= coordinate < mesh_shape ({}) on every "
                       "axis".format(self._mesh_shape))

    mesh_dims = self._mesh_shape.tolist()
    tasks = np.full(mesh_dims, -1, dtype=np.int32)
    devices = np.full(mesh_dims, -1, dtype=np.int32)
    # `np.ndindex` walks (task, device) pairs in row-major order, i.e. task
    # in the outer position and device in the inner one.
    for task, device in np.ndindex(*coords.shape[:2]):
      position = tuple(coords[task, device])
      tasks[position] = task
      devices[position] = device
    return tasks, devices

  @property
  def mesh_shape(self):
    """A rank 1 int32 array describing the shape of the TPU topology."""
    return self._mesh_shape

  @property
  def mesh_rank(self):
    """Returns the number of dimensions in the mesh."""
    return len(self._mesh_shape)

  @property
  def device_coordinates(self):
    """Describes the mapping from TPU devices to topology coordinates.

    Returns:
      A rank 3 int32 array with shape `[tasks, devices, axis]`.
      `tasks` is the number of tasks in the TPU cluster, `devices` is the number
      of TPU devices per task, and `axis` is the number of axes in the TPU
      cluster topology. Each entry gives the `axis`-th coordinate in the
      topology of a task/device pair. TPU topologies are 3-dimensional, with
      dimensions `(x, y, core number)`.
    """
    return self._device_coordinates

  @property
  def missing_devices(self):
    """Array of mesh coordinates that have no device mapped to them."""
    return self._missing_devices

  def task_ordinal_at_coordinates(self, device_coordinates):
    """Returns the TensorFlow task number attached to `device_coordinates`.

    Args:
      device_coordinates: An integer sequence describing a device's physical
        coordinates in the TPU fabric.

    Returns:
      Returns the TensorFlow task number that contains the TPU device with those
      physical coordinates, or -1 if no device is mapped there.
    """
    return self._topology_tasks[tuple(device_coordinates)]

  def tpu_device_ordinal_at_coordinates(self, device_coordinates):
    """Returns the TensorFlow device number at `device_coordinates`.

    Args:
      device_coordinates: An integer sequence describing a device's physical
        coordinates in the TPU fabric.

    Returns:
      Returns the TensorFlow device number, within its task, of the device
      with those physical coordinates, or -1 if no device is mapped there.
    """
    return self._topology_devices[tuple(device_coordinates)]

  def cpu_device_name_at_coordinates(self, device_coordinates, job=None):
    """Returns the CPU device attached to a logical core.

    Args:
      device_coordinates: An integer sequence describing a device's physical
        coordinates in the TPU fabric.
      job: Optional job name to include in the device name.

    Returns:
      The device name of the CPU on the task that owns those coordinates.
    """
    return _tpu_host_device_name(
        job, self._topology_tasks[tuple(device_coordinates)])

  def tpu_device_name_at_coordinates(self, device_coordinates, job=None):
    """Returns the name of the TPU device assigned to a logical core.

    Args:
      device_coordinates: An integer sequence describing a device's physical
        coordinates in the TPU fabric.
      job: Optional job name to include in the device name.

    Returns:
      The device name of the TPU device mapped to those coordinates.
    """
    position = tuple(device_coordinates)
    return _tpu_device_name(job,
                            self._topology_tasks[position],
                            self._topology_devices[position])

  @property
  def num_tasks(self):
    """Returns the number of TensorFlow tasks in the TPU slice."""
    return self._device_coordinates.shape[0]

  @property
  def num_tpus_per_task(self):
    """Returns the number of TPU devices per task in the TPU slice."""
    return self._device_coordinates.shape[1]

  def serialized(self):
    """Returns the serialized form of the topology.

    The serialization is built on first use and cached; a topology that was
    constructed from a serialized proto returns that value unchanged.
    """
    if self._serialized is None:
      proto = topology_pb2.TopologyProto()
      # `.tolist()` hands plain Python ints to the proto fields.
      proto.mesh_shape[:] = self._mesh_shape.tolist()
      proto.num_tasks = self._device_coordinates.shape[0]
      proto.num_tpu_devices_per_task = self._device_coordinates.shape[1]
      proto.device_coordinates.extend(
          self._device_coordinates.flatten().tolist())
      self._serialized = proto.SerializeToString()

    return self._serialized
229