• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ======================================
15"""Defines the `Topology` class, that describes a TPU fabric topology."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22from six.moves import xrange  # pylint: disable=redefined-builtin
23
24from tensorflow.core.protobuf.tpu import topology_pb2
25from tensorflow.python.util.tf_export import tf_export
26
27
28def _tpu_device_name(job, task, device):
29  """Returns the device name for the TPU `device` on `task` of `job`."""
30  if job is None:
31    return "/task:%d/device:TPU:%d" % (task, device)
32  else:
33    return "/job:%s/task:%d/device:TPU:%d" % (job, task, device)
34
35
36def _tpu_host_device_name(job, task):
37  """Returns the device name for the CPU device on `task` of `job`."""
38  if job is None:
39    return "/task:%d/device:CPU:0" % task
40  else:
41    return "/job:%s/task:%d/device:CPU:0" % (job, task)
42
43
@tf_export("tpu.experimental.Topology")
class Topology(object):
  """Describes a set of TPU devices.

  Represents both the shape of the physical mesh, and the mapping between
  TensorFlow TPU devices and physical mesh coordinates.
  """

  def __init__(self, serialized=None, mesh_shape=None, device_coordinates=None):
    """Builds a Topology object.

    If `serialized` is not `None`, the topology is parsed from `serialized` and
    the other arguments are ignored. Otherwise, the topology is computed from
    `mesh_shape` and `device_coordinates`.

    Args:
      serialized: A serialized `TopologyProto`, or `None`. If not `None`, the
        serialized proto is parsed to discover the topology.
      mesh_shape: A sequence of 4 positive integers, or `None`. If not `None`,
        the shape of the TPU topology, in number of cores. Ignored if
        `serialized` is not `None`.
      device_coordinates: A rank 3 numpy array of shape
        `[num_tasks, num_tpu_devices_per_task, len(mesh_shape)]` that describes
        the mapping from TensorFlow TPU devices to TPU fabric coordinates, or
        `None`. Ignored if `serialized` is not `None`.

    Raises:
      ValueError: If `serialized` does not describe a well-formed topology.
      ValueError: If `serialized` is `None` and either `mesh_shape` or
        `device_coordinates` is `None`.
      ValueError: If `serialized` is `None` and `mesh_shape` is not a sequence
        of 4 positive integers.
      ValueError: If `serialized` is `None` and `device_coordinates` is not a
        rank 3 int32 array whose minor dimension equals the mesh rank.
      ValueError: If any device coordinate is negative or not smaller than the
        corresponding `mesh_shape` entry.
    """
    self._serialized = serialized

    # Compare against `None` explicitly, as documented: an empty byte string is
    # still a (malformed) serialized proto and must be rejected by the parser
    # rather than silently falling through to the `mesh_shape` path.
    if serialized is not None:
      self._parse_topology(serialized)
    else:
      if mesh_shape is None or device_coordinates is None:
        raise ValueError(
            "`mesh_shape` and `device_coordinates` are required when "
            "`serialized` is None; got mesh_shape is None: {}, "
            "device_coordinates is None: {}".format(
                mesh_shape is None, device_coordinates is None))

      self._mesh_shape = np.asarray(mesh_shape, dtype=np.int32)
      self._device_coordinates = np.asarray(device_coordinates, np.int32)
      # Check the array shape rather than calling len(): len() raises
      # TypeError on 0-d input and would accept a nested 4-row sequence.
      if self._mesh_shape.shape != (4,) or np.any(self._mesh_shape < 1):
        raise ValueError("`mesh_shape` must be a sequence of 4 positive "
                         "entries; got {}".format(self._mesh_shape))

      # Only read shape[2] after the rank is known to be 3; formatting it for a
      # lower-rank input would raise IndexError instead of ValueError.
      if (self._device_coordinates.ndim != 3 or
          self._device_coordinates.shape[2] != self.mesh_rank):
        raise ValueError(
            "`device_coordinates` must be a rank 3 int32 array with minor "
            "dimension equal to the mesh shape rank ({}); got shape {}".format(
                self.mesh_rank, self._device_coordinates.shape))

    self._check_coordinates_within_mesh()
    self._topology_tasks, self._topology_devices = self._invert_topology()

    # Mesh coordinates that no (task, device) pair maps to.
    self._missing_devices = np.argwhere(self._topology_tasks < 0)

  def _parse_topology(self, serialized):
    """Parses a serialized `TopologyProto` into `self`.

    Sets `self._mesh_shape` and `self._device_coordinates`.

    Args:
      serialized: The serialized `TopologyProto`.

    Raises:
      ValueError: If the proto does not describe a well-formed topology.
    """
    proto = topology_pb2.TopologyProto()
    proto.ParseFromString(serialized)

    self._mesh_shape = np.array(proto.mesh_shape, dtype=np.int32)
    if len(self._mesh_shape) != 4 or np.any(self._mesh_shape < 1):
      raise ValueError("`mesh_shape` must be a vector of size 4 with positive "
                       "entries; got {}".format(self._mesh_shape))

    if proto.num_tasks < 0:
      raise ValueError("`num_tasks` must be >= 0; got {}".format(
          proto.num_tasks))
    if proto.num_tpu_devices_per_task < 0:
      raise ValueError("`num_tpu_devices_per_task` must be >= 0; got {}".format(
          proto.num_tpu_devices_per_task))

    mesh_rank = len(proto.mesh_shape)
    expected_coordinates_size = (
        proto.num_tasks * proto.num_tpu_devices_per_task * mesh_rank)
    if len(proto.device_coordinates) != expected_coordinates_size:
      raise ValueError("`device_coordinates` must have shape num_tasks ({}) * "
                       "num_tpu_devices_per_task ({}) * len(mesh_shape) ({}); "
                       "got shape {}".format(proto.num_tasks,
                                             proto.num_tpu_devices_per_task,
                                             mesh_rank,
                                             len(proto.device_coordinates)))

    coords = np.array(proto.device_coordinates, dtype=np.int32)
    if np.any(coords < 0):
      raise ValueError("`device_coordinates` must be >= 0")
    coords = coords.reshape((proto.num_tasks, proto.num_tpu_devices_per_task,
                             mesh_rank))
    self._device_coordinates = coords

  def _check_coordinates_within_mesh(self):
    """Raises `ValueError` if any device coordinate lies outside the mesh.

    Without this check, a negative coordinate would silently wrap around when
    used as an index in `_invert_topology`. A coordinate greater than or equal
    to the mesh extent would surface there as an `IndexError`.
    """
    coords = self._device_coordinates
    # The rank 1 mesh shape broadcasts against the minor (axis) dimension.
    out_of_range = np.any((coords < 0) | (coords >= self._mesh_shape), axis=-1)
    if np.any(out_of_range):
      task, device = np.argwhere(out_of_range)[0]
      raise ValueError(
          "`device_coordinates` must lie within `mesh_shape` {}; task {} "
          "device {} has coordinates {}".format(
              self._mesh_shape, task, device,
              coords[task, device].tolist()))

  def _invert_topology(self):
    """Inverts a [task, device, axis] topology into mesh -> task/device maps.

    Returns:
      A pair `(tasks, devices)` of int32 arrays shaped like `mesh_shape`.
      `tasks[c]` and `devices[c]` hold the task index and TPU ordinal of the
      device at mesh coordinate `c`, or -1 if no device maps to `c`. If several
      devices map to the same coordinate, the one visited last (highest task,
      then highest device) wins.
    """
    tasks = np.full(list(self.mesh_shape), -1, dtype=np.int32)
    devices = np.full(list(self.mesh_shape), -1, dtype=np.int32)
    for task in xrange(self.device_coordinates.shape[0]):
      for device in xrange(self.device_coordinates.shape[1]):
        coordinate = tuple(self.device_coordinates[task, device, :])
        tasks[coordinate] = task
        devices[coordinate] = device
    return tasks, devices

  @property
  def mesh_shape(self):
    """A rank 1 int32 array describing the shape of the TPU topology."""
    return self._mesh_shape

  @property
  def mesh_rank(self):
    """Returns the number of dimensions in the mesh."""
    return len(self._mesh_shape)

  @property
  def device_coordinates(self):
    """Describes the mapping from TPU devices to topology coordinates.

    Returns:
      A rank 3 int32 array with shape `[tasks, devices, axis]`.
      `tasks` is the number of tasks in the TPU cluster, `devices` is the number
      of TPU devices per task, and `axis` is the number of axes in the TPU
      cluster topology. Each entry gives the `axis`-th coordinate in the
      topology of a task/device pair. TPU topologies are 4-dimensional, with
      dimensions `(x, y, z, core number)`.
    """
    return self._device_coordinates

  @property
  def missing_devices(self):
    """Mesh coordinates that are not occupied by any TPU device.

    Returns:
      An integer array of shape `[num_missing, mesh_rank]`, one row per mesh
      coordinate that no (task, device) pair maps to.
    """
    return self._missing_devices

  def task_ordinal_at_coordinates(self, device_coordinates):
    """Returns the TensorFlow task number attached to `device_coordinates`.

    Args:
      device_coordinates: An integer sequence describing a device's physical
        coordinates in the TPU fabric.

    Returns:
      Returns the TensorFlow task number that contains the TPU device with those
      physical coordinates, or -1 if no device occupies them.
    """
    return self._topology_tasks[tuple(device_coordinates)]

  def tpu_device_ordinal_at_coordinates(self, device_coordinates):
    """Returns the TensorFlow device number at `device_coordinates`.

    Args:
      device_coordinates: An integer sequence describing a device's physical
        coordinates in the TPU fabric.

    Returns:
      Returns the TensorFlow device number, within its task, of the device with
      those physical coordinates, or -1 if no device occupies them.
    """
    return self._topology_devices[tuple(device_coordinates)]

  def cpu_device_name_at_coordinates(self, device_coordinates, job=None):
    """Returns the CPU device attached to a logical core."""
    return _tpu_host_device_name(
        job, self._topology_tasks[tuple(device_coordinates)])

  def tpu_device_name_at_coordinates(self, device_coordinates, job=None):
    """Returns the name of the TPU device assigned to a logical core."""
    return _tpu_device_name(job,
                            self._topology_tasks[tuple(device_coordinates)],
                            self._topology_devices[tuple(device_coordinates)])

  @property
  def num_tasks(self):
    """Returns the number of TensorFlow tasks in the TPU slice."""
    return self._device_coordinates.shape[0]

  @property
  def num_tpus_per_task(self):
    """Returns the number of TPU devices per task in the TPU slice."""
    return self._device_coordinates.shape[1]

  def serialized(self):
    """Returns the serialized form of the topology.

    If the topology was built from a serialized proto, those bytes are
    returned. Otherwise a `TopologyProto` is built from `mesh_shape` and
    `device_coordinates` on the first call and cached.
    """
    if self._serialized is None:
      proto = topology_pb2.TopologyProto()
      # `tolist()` hands plain Python ints to the protobuf repeated fields.
      proto.mesh_shape[:] = self._mesh_shape.tolist()
      proto.num_tasks = self._device_coordinates.shape[0]
      proto.num_tpu_devices_per_task = self._device_coordinates.shape[1]
      proto.device_coordinates.extend(
          self._device_coordinates.flatten().tolist())
      self._serialized = proto.SerializeToString()

    return self._serialized
236