• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ======================================
15"""Defines the `Topology` class, that describes a TPU fabric topology."""
16
17import numpy as np
18
19from tensorflow.core.protobuf.tpu import topology_pb2
20from tensorflow.python.util.tf_export import tf_export
21
22
def _tpu_device_name(job, task, device):
  """Formats the TensorFlow device string for one TPU device.

  Args:
    job: Job name, or `None` to omit the `/job:` component.
    task: Task index within the job.
    device: TPU device index within the task.

  Returns:
    A string such as `/job:worker/task:0/device:TPU:1`, or
    `/task:0/device:TPU:1` when `job` is `None`.
  """
  job_prefix = "" if job is None else "/job:%s" % (job,)
  return job_prefix + "/task:%d/device:TPU:%d" % (task, device)
29
30
def _tpu_host_device_name(job, task):
  """Formats the TensorFlow device string of the host CPU for a task.

  Args:
    job: Job name, or `None` to omit the `/job:` component.
    task: Task index within the job.

  Returns:
    A string such as `/job:worker/task:3/device:CPU:0`, or
    `/task:3/device:CPU:0` when `job` is `None`.
  """
  if job is not None:
    return "/job:%s/task:%d/device:CPU:0" % (job, task)
  return "/task:%d/device:CPU:0" % task
37
38
@tf_export("tpu.experimental.Topology")
class Topology(object):
  """Describes a set of TPU devices.

  Represents both the shape of the physical mesh, and the mapping between
  TensorFlow TPU devices and physical mesh coordinates.
  """

  def __init__(self, serialized=None, mesh_shape=None, device_coordinates=None):
    """Builds a Topology object.

    If `serialized` is non-empty, the topology is parsed from `serialized` and
    the other arguments are ignored. Otherwise (`None` or empty), the topology
    is computed from `mesh_shape` and `device_coordinates`.

    Args:
      serialized: A serialized `TopologyProto`, or `None`. If non-empty, the
        serialized proto is parsed to discover the topology.
      mesh_shape: A sequence of 4 positive integers, or `None`. The shape of
        the TPU topology, in number of cores. Ignored if `serialized` is
        non-empty.
      device_coordinates: A rank 3 numpy array that describes the mapping from
        TensorFlow TPU devices to TPU fabric coordinates, or `None`. If
        specified, array is a rank 3 int32 array with shape
        `[tasks, devices, axis]`. `tasks` is the number of tasks in the TPU
        cluster, `devices` is the number of TPU devices per task, and `axis` is
        the number of axes in the TPU cluster topology. Each entry gives the
        `axis`-th coordinate in the topology of a task/device pair. TPU
        topologies are 4-dimensional, with dimensions `(x, y, z, core number)`.
        Ignored if `serialized` is non-empty.

    Raises:
      ValueError: If `serialized` does not describe a well-formed topology.
      ValueError: If `serialized` is `None`/empty and `mesh_shape` is not a
        sequence of 4 positive integers.
      ValueError: If `serialized` is `None`/empty and `device_coordinates` is
        not a rank 3 int32 array that describes a valid coordinate mapping
        (minor dimension equal to the mesh rank, every coordinate inside
        `mesh_shape`).
    """
    if serialized:
      self._serialized = serialized
      self._parse_topology(serialized)
    else:
      # A falsy `serialized` (e.g. b"") is not cached, so `serialized()` builds
      # the proto from the fields below instead of returning the empty value.
      self._serialized = None
      self._mesh_shape = np.asarray(mesh_shape, dtype=np.int32)
      self._device_coordinates = np.asarray(device_coordinates, np.int32)
      # Checking the full shape (rather than `len`) reports scalar or nested
      # input as a ValueError instead of a TypeError / ambiguous-truth error.
      if self._mesh_shape.shape != (4,) or np.any(self._mesh_shape < 1):
        raise ValueError("`mesh_shape` must be a sequence of 4 positive "
                         f"entries; got `mesh_shape={self._mesh_shape}`")

      coords_shape = self._device_coordinates.shape
      if len(coords_shape) != 3 or coords_shape[2] != len(self._mesh_shape):
        # Only the whole shape and its length are formatted: indexing
        # `shape[2]` would raise IndexError for rank < 3 and mask this error.
        raise ValueError(
            "`device_coordinates` must be a rank 3 int32 array with minor "
            "dimension equal to the `mesh_shape` rank; got "
            f"device_coordinates.shape={coords_shape}, "
            f"len(device_coordinates.shape)={len(coords_shape)}, "
            f"mesh_shape={self._mesh_shape}, "
            f"len(mesh_shape)={len(self._mesh_shape)}")

    self._topology_tasks, self._topology_devices = self._invert_topology()

    # Coordinates of mesh positions that no task/device maps to.
    self._missing_devices = np.argwhere(self._topology_tasks < 0)

  def _parse_topology(self, serialized):
    """Parses a serialized `TopologyProto` into `self`.

    Sets `self._mesh_shape` and `self._device_coordinates`.

    Args:
      serialized: Bytes of a serialized `topology_pb2.TopologyProto`.

    Raises:
      ValueError: If the proto's mesh shape, task/device counts or device
        coordinates are invalid or mutually inconsistent.
    """
    proto = topology_pb2.TopologyProto()
    proto.ParseFromString(serialized)

    self._mesh_shape = np.array(proto.mesh_shape, dtype=np.int32)
    if len(self._mesh_shape) != 4 or np.any(self._mesh_shape < 1):
      raise ValueError("`mesh_shape` must be a vector of size 4 with positive "
                       "entries; got {}".format(self._mesh_shape))

    if proto.num_tasks < 0:
      raise ValueError("`num_tasks` must be >= 0; got {}".format(
          proto.num_tasks))
    if proto.num_tpu_devices_per_task < 0:
      raise ValueError("`num_tpu_devices_per_task` must be >= 0; got {}".format(
          proto.num_tpu_devices_per_task))

    mesh_rank = len(proto.mesh_shape)
    expected_coordinates_size = (
        proto.num_tasks * proto.num_tpu_devices_per_task * mesh_rank)
    if len(proto.device_coordinates) != expected_coordinates_size:
      raise ValueError("`device_coordinates` must have shape num_tasks ({}) * "
                       "num_tpu_devices_per_task ({}) * len(mesh_shape) ({}); "
                       "got shape {}".format(proto.num_tasks,
                                             proto.num_tpu_devices_per_task,
                                             mesh_rank,
                                             len(proto.device_coordinates)))

    coords = np.array(proto.device_coordinates, dtype=np.int32)
    if np.any(coords < 0):
      raise ValueError(
          "All values in `device_coordinates` must be >= 0, got {}"
          .format(coords))
    # The proto stores coordinates flattened; restore [task, device, axis].
    self._device_coordinates = coords.reshape(
        (proto.num_tasks, proto.num_tpu_devices_per_task, mesh_rank))

  def _invert_topology(self):
    """Inverts a [task, device, axis] topology into coordinate -> id maps.

    Returns:
      A pair `(tasks, devices)` of int32 arrays, each of shape `mesh_shape`.
      `tasks[c]` and `devices[c]` hold the task index and per-task device index
      of the device at physical coordinate `c`, or -1 if no device maps to `c`.

    Raises:
      ValueError: If any device coordinate lies outside `mesh_shape`.
    """
    mesh_shape = self.mesh_shape
    coords = self.device_coordinates
    tasks = np.full(list(mesh_shape), -1, dtype=np.int32)
    devices = np.full(list(mesh_shape), -1, dtype=np.int32)

    # Reject out-of-range coordinates explicitly: negative values would
    # otherwise silently wrap around through numpy negative indexing, and
    # values past the mesh edge would surface as an opaque IndexError.
    out_of_range = np.logical_or(coords < 0, coords >= mesh_shape)
    if np.any(out_of_range):
      bad_pairs = np.argwhere(np.any(out_of_range, axis=-1)).tolist()
      raise ValueError(
          f"`device_coordinates` must lie within `mesh_shape={mesh_shape}`; "
          f"got out-of-range coordinates for (task, device) pairs {bad_pairs}")

    for task in range(coords.shape[0]):
      for device in range(coords.shape[1]):
        position = tuple(coords[task, device, :])
        tasks[position] = task
        devices[position] = device
    return tasks, devices

  @property
  def mesh_shape(self):
    """A rank 1 int32 array describing the shape of the TPU topology."""
    return self._mesh_shape

  @property
  def mesh_rank(self):
    """Returns the number of dimensions in the mesh."""
    return len(self._mesh_shape)

  @property
  def device_coordinates(self):
    """Describes the mapping from TPU devices to topology coordinates.

    Returns:
      A rank 3 int32 array with shape `[tasks, devices, axis]`.
      `tasks` is the number of tasks in the TPU cluster, `devices` is the number
      of TPU devices per task, and `axis` is the number of axes in the TPU
      cluster topology. Each entry gives the `axis`-th coordinate in the
      topology of a task/device pair. TPU topologies are 4-dimensional, with
      dimensions `(x, y, z, core number)`.
    """
    return self._device_coordinates

  @property
  def missing_devices(self):
    """Coordinates of mesh positions that no TPU device maps to.

    Returns:
      An integer array of shape `[num_missing, mesh_rank]`, one row of mesh
      coordinates per unoccupied position.
    """
    return self._missing_devices

  def task_ordinal_at_coordinates(self, device_coordinates):
    """Returns the TensorFlow task number attached to `device_coordinates`.

    Args:
      device_coordinates: An integer sequence describing a device's physical
        coordinates in the TPU fabric.

    Returns:
      The TensorFlow task number that contains the TPU device with those
      physical coordinates, or -1 if no device occupies them.
    """
    return self._topology_tasks[tuple(device_coordinates)]

  def tpu_device_ordinal_at_coordinates(self, device_coordinates):
    """Returns the TensorFlow device number at `device_coordinates`.

    Args:
      device_coordinates: An integer sequence describing a device's physical
        coordinates in the TPU fabric.

    Returns:
      The TensorFlow device number, within its task, of the TPU device with
      those physical coordinates, or -1 if no device occupies them.
    """
    return self._topology_devices[tuple(device_coordinates)]

  def cpu_device_name_at_coordinates(self, device_coordinates, job=None):
    """Returns the name of the host CPU device attached to a logical core.

    Args:
      device_coordinates: An integer sequence of physical mesh coordinates.
      job: Optional job name to include in the device name.

    Returns:
      A device name string. If no device occupies the coordinates, the task
      component is -1.
    """
    return _tpu_host_device_name(
        job, self._topology_tasks[tuple(device_coordinates)])

  def tpu_device_name_at_coordinates(self, device_coordinates, job=None):
    """Returns the name of the TPU device assigned to a logical core.

    Args:
      device_coordinates: An integer sequence of physical mesh coordinates.
      job: Optional job name to include in the device name.

    Returns:
      A device name string. If no device occupies the coordinates, the task
      and device components are -1.
    """
    position = tuple(device_coordinates)
    return _tpu_device_name(job, self._topology_tasks[position],
                            self._topology_devices[position])

  @property
  def num_tasks(self):
    """Returns the number of TensorFlow tasks in the TPU slice."""
    return self._device_coordinates.shape[0]

  @property
  def num_tpus_per_task(self):
    """Returns the number of TPU devices per task in the TPU slice."""
    return self._device_coordinates.shape[1]

  def serialized(self):
    """Returns the serialized form of the topology.

    When the topology was built from a non-empty serialized proto, those bytes
    are returned. Otherwise a `TopologyProto` is built from the mesh
    description on first call and cached.
    """
    if self._serialized is None:
      proto = topology_pb2.TopologyProto()
      # `tolist()` hands protobuf plain Python ints rather than numpy scalars.
      proto.mesh_shape[:] = self._mesh_shape.tolist()
      proto.num_tasks = self.num_tasks
      proto.num_tpu_devices_per_task = self.num_tpus_per_task
      proto.device_coordinates.extend(
          self._device_coordinates.flatten().tolist())
      self._serialized = proto.SerializeToString()

    return self._serialized
240