# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

16"""Helper library for sharding during TPU compilation."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from six.moves import xrange  # pylint: disable=redefined-builtin
23
24from tensorflow.python.framework import tensor_shape
25
26_DEFAULT_NUMBER_OF_SHARDS = 1
27_DEFAULT_SHARD_DIMENSION = 0
28
29
# TODO(b/36777903) change other parts of tpu.py to use this class.
class ShardingPolicy(object):
  """An object used to hold the sharding policy for a Tensor."""

  def __init__(self):
    self._number_of_shards = None
    self._shard_dimension = None
    self._frozen = False

  def __str__(self):
    if self.number_of_shards is None or self.shard_dimension is None:
      return "ShardingPolicy(unset)"
    else:
      return ("ShardingPolicy(%d shards dimension %d)" %
              (self.number_of_shards, self.shard_dimension))

  def _fill_default_values(self):
    if self._number_of_shards is None:
      self._number_of_shards = _DEFAULT_NUMBER_OF_SHARDS
    if self._shard_dimension is None:
      self._shard_dimension = tensor_shape.as_dimension(
          _DEFAULT_SHARD_DIMENSION)

  def freeze(self):
    """Prevents further modification to the sharding policy.

    Any values that have not been set when freeze is called are set to
    defaults. If the ShardingPolicy is already frozen, this is a NoOp.
    """
    if not self._frozen:
      self._fill_default_values()
      self._frozen = True

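  # Illustrative sketch (comments only; not part of the original module):
  # freezing an unset policy fills in the module defaults, so afterwards the
  # policy reads as 1 shard along dimension 0.
  #
  #   policy = ShardingPolicy()
  #   policy.freeze()
  #   policy.number_of_shards  # => 1
  #   policy.shard_dimension   # => Dimension(0)
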
  @property
  def number_of_shards(self):
    """Returns the number of shards in the policy or None if unspecified."""
    return self._number_of_shards

  def set_number_of_shards(self, number_of_shards):
    """Sets the number of shards for the current policy.

    If the policy has been frozen then number_of_shards must match the
    existing setting.

    Args:
      number_of_shards: The number of shards to use in the policy.

    Raises:
      ValueError: If the policy has been frozen and number_of_shards
        differs from the frozen value; or number_of_shards <= 0.
    """
    if self._frozen:
      if self._number_of_shards != number_of_shards:
        raise ValueError(
            "Can't set sharding policy to use %d shards since it has been "
            "frozen to use %d." % (number_of_shards, self._number_of_shards))
    else:
      if number_of_shards > 0:
        self._number_of_shards = number_of_shards
      else:
        # Format the message explicitly; passing the value as a second
        # argument to ValueError would leave the message unformatted.
        raise ValueError(
            "Can't set sharding policy to use %s shards; value must be > 0."
            % str(number_of_shards))

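  # Illustrative sketch (comments only): once frozen, the setter accepts
  # only the frozen value; anything else raises ValueError.
  #
  #   policy = ShardingPolicy()
  #   policy.set_number_of_shards(2)
  #   policy.freeze()
  #   policy.set_number_of_shards(2)  # No-op; matches the frozen value.
  #   policy.set_number_of_shards(4)  # Raises ValueError.
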
  @property
  def shard_dimension(self):
    """Returns the shard dimension of the policy or None if unspecified."""
    return self._shard_dimension

  def set_shard_dimension(self, shard_dimension):
    """Sets the shard dimension for the current policy.

    If the policy has been frozen then shard_dimension must match the
    existing setting.

    Args:
      shard_dimension: The shard dimension to use in the policy.

    Raises:
      ValueError: If the policy has been frozen and shard_dimension
        differs from the frozen value, or shard_dimension can't be
        interpreted as a Dimension.
    """
    if self._frozen:
      if self._shard_dimension != shard_dimension:
        raise ValueError(
            "Can't set shard dimension to %d since it has been frozen to "
            "use %d." % (shard_dimension, self._shard_dimension))
    else:
      self._shard_dimension = tensor_shape.as_dimension(shard_dimension)

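  # Illustrative sketch (comments only): the dimension is normalized through
  # tensor_shape.as_dimension, so a plain int and a Dimension compare equal.
  #
  #   policy = ShardingPolicy()
  #   policy.set_shard_dimension(1)
  #   policy.shard_dimension == tensor_shape.as_dimension(1)  # => True
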
  def merge(self, other):
    """Merges another policy into the current policy.

    Args:
      other: The policy to merge into this one.

    Raises:
      ValueError: If this policy has been frozen and the merge conflicts
        with the frozen policy.
    """
    if other.number_of_shards is not None:
      self.set_number_of_shards(other.number_of_shards)
    if other.shard_dimension is not None:
      self.set_shard_dimension(other.shard_dimension)

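  # Illustrative sketch (comments only): merge copies only the fields the
  # other policy has set, so partial policies combine field by field.
  #
  #   a = ShardingPolicy()
  #   a.set_number_of_shards(2)
  #   b = ShardingPolicy()
  #   b.set_shard_dimension(0)
  #   a.merge(b)  # a now specifies 2 shards along dimension 0.
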
  def get_sharded_shape(self, shape, shard_index=None):
    """Returns the shape of a shard of a full Tensor.

    When given the shape of a 'full-size' Tensor, returns the shape of
    the sub-Tensor after it has been sharded. Returns None, rather than
    raising an error, if the policy's number of shards or shard dimension
    has not yet been set.

    Args:
      shape: The shape of the full-size Tensor to be sharded.
      shard_index: The index of the shard whose shape should be returned.
        shard_index can be None for sharding policies that use the same
        shape for every shard.

    Returns:
      The shape of the sharded version of the Tensor, or None if the
      policy is unset.

    Raises:
      ValueError: If shard_index is None when shards are of different
        shapes; or shard_index is not None and
        !(0 <= shard_index < number_of_shards); or shape does not have at
        least self.shard_dimension + 1 dimensions; or the value of
        shape's shard dimension is not a multiple of
        self.number_of_shards.
    """
    if self._shard_dimension is None or self._number_of_shards is None:
      # Don't raise an error if the config is unset.
      return None
    if shard_index is not None:
      if shard_index < 0 or shard_index >= self.number_of_shards:
        raise ValueError("shard_index %d out of range; must be in [0,%d)." %
                         (shard_index, self._number_of_shards))
    shape = tensor_shape.as_shape(shape)
    if self._number_of_shards == 1:
      # Don't do anything when there's only one shard.
      return shape
    ndims = shape.ndims
    if ndims is None:
      raise ValueError("shape must be a fully-specified shape, not unknown")
    if ndims <= self._shard_dimension:
      raise ValueError("shape %s does not contain shard_dimension %d" %
                       (shape.as_list(), self._shard_dimension))
    dims = shape.as_list()
    if dims[self._shard_dimension] is None:
      raise ValueError("shape %s must have a fixed size for dimension %d "
                       "that is known at graph construction time." %
                       (shape.as_list(), self._shard_dimension))
    if (dims[self._shard_dimension] % self._number_of_shards) != 0:
      raise ValueError("shape %s cannot be sharded %d ways along dimension %d" %
                       (shape.as_list(), self._number_of_shards,
                        self._shard_dimension))
    # Use integer division so the dimension stays an int; with
    # "from __future__ import division", "/" would produce a float.
    dims[self._shard_dimension] //= self._number_of_shards
    return tensor_shape.as_shape(dims)

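  # Illustrative sketch (comments only): with 2 shards along dimension 0, a
  # full shape of [8, 10] yields a per-shard shape of [4, 10].
  #
  #   policy = ShardingPolicy()
  #   policy.set_number_of_shards(2)
  #   policy.set_shard_dimension(0)
  #   policy.get_sharded_shape([8, 10])  # => TensorShape([4, 10])
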
  def _unshard_shape(self, shape):
    """Returns the unsharded shape that would generate a given sharded shape.

    Args:
      shape: The sharded shape to unshard.

    Returns:
      The unsharded shape.

    Raises:
      ValueError: If shape is unknown or does not contain
        self.shard_dimension.
      TypeError: If shape is not convertible to a TensorShape.
    """
    shape = tensor_shape.as_shape(shape)
    if self._number_of_shards == 1:
      # Don't do anything when there's only one shard.
      return shape
    ndims = shape.ndims
    if ndims is None:
      raise ValueError("shape must be a fully-specified shape, not unknown")
    if ndims <= self._shard_dimension:
      raise ValueError("shape %s does not contain shard_dimension %d" %
                       (shape.as_list(), self._shard_dimension))
    dims = shape.as_list()
    dims[self._shard_dimension] *= self._number_of_shards
    return tensor_shape.as_shape(dims)

  def get_unsharded_shape(self, shapes):
    """Returns the shape of an unsharded Tensor given a list of shards.

    When given a list of shapes of shards, returns the shape of the
    unsharded Tensor that would generate the shards. Sets defaults for the
    policy if number_of_shards or shard_dimension is None.

    Args:
      shapes: The shapes of the Tensor shards to be combined.

    Returns:
      The shape of the unsharded version of the Tensor.

    Raises:
      ValueError: If shapes is not a list of length
        self.number_of_shards; or any element of shapes is not a valid
        shape consistent with the sharding policy; or the list of
        shapes is not a valid sharding of a full shape.
      TypeError: If an element of shapes is not convertible to a
        TensorShape.
    """
    self._fill_default_values()
    if len(shapes) != self.number_of_shards:
      raise ValueError(
          "shapes is %s but must be a list of length number_of_shards=%d" % (
              str(shapes), self.number_of_shards))
    unsharded_shapes = [self._unshard_shape(s) for s in shapes]
    for i in xrange(self.number_of_shards - 1):
      if not unsharded_shapes[i].is_compatible_with(
          unsharded_shapes[self.number_of_shards - 1]):
        raise ValueError(
            "sharded shapes %s are not consistent shards of a full shape "
            "sharded %d ways along dimension %d" % (
                str(shapes), self.number_of_shards, self.shard_dimension))
    return unsharded_shapes[0]
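

# Minimal round-trip sketch (illustrative; not part of the original module).
# Shards an example full shape, then recovers it from the per-shard shapes.
# The shape values below are arbitrary.
if __name__ == "__main__":
  policy = ShardingPolicy()
  policy.set_number_of_shards(2)
  policy.set_shard_dimension(0)
  sharded = policy.get_sharded_shape([8, 10])
  print("sharded shape:", sharded)  # Expect a per-shard shape of (4, 10).
  full = policy.get_unsharded_shape([sharded] * policy.number_of_shards)
  print("unsharded shape:", full)  # Expect the full shape of (8, 10).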