# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Helper library for sharding during TPU compilation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import tensor_shape

_DEFAULT_NUMBER_OF_SHARDS = 1
_DEFAULT_SHARD_DIMENSION = 0


# TODO(b/36777903) change other parts of tpu.py to use this class.
class ShardingPolicy(object):
  """An object used to hold the sharding policy for a Tensor."""

  def __init__(self):
    self._number_of_shards = None
    self._shard_dimension = None
    self._frozen = False

  def __str__(self):
    if self.number_of_shards is None or self.shard_dimension is None:
      return "ShardingPolicy(unset)"
    else:
      return ("ShardingPolicy(%d shards dimension %d)" %
              (self.number_of_shards, self.shard_dimension))

  def _fill_default_values(self):
    if self._number_of_shards is None:
      self._number_of_shards = _DEFAULT_NUMBER_OF_SHARDS
    if self._shard_dimension is None:
      self._shard_dimension = tensor_shape.as_dimension(
          _DEFAULT_SHARD_DIMENSION)

  def freeze(self):
    """Prevents further modification to the sharding policy.

    Any values that have not been set when freeze is called are set to
    defaults. If the ShardingPolicy is already frozen, this is a NoOp.
    """
    if not self._frozen:
      self._fill_default_values()
      self._frozen = True

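  # A minimal usage sketch (illustrative, not part of the library):
  # freezing an unset policy fills in the defaults above, and freezing
  # again is a NoOp.
  #
  #   policy = ShardingPolicy()
  #   str(policy)      # roughly "ShardingPolicy(unset)"
  #   policy.freeze()  # defaults: 1 shard along dimension 0
  #   str(policy)      # roughly "ShardingPolicy(1 shards dimension 0)"
  #   policy.freeze()  # NoOp; already frozen
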
  @property
  def number_of_shards(self):
    """Returns the number of shards in the policy or None if unspecified."""
    return self._number_of_shards

  def set_number_of_shards(self, number_of_shards):
    """Sets the number of shards for the current policy.

    If the policy has been frozen then number_of_shards must match the
    existing setting.

    Args:
      number_of_shards: The number of shards to use in the policy.

    Raises:
      ValueError: If the policy has been frozen and number_of_shards
        differs from the frozen value; or number_of_shards <= 0.
    """
    if self._frozen:
      if self._number_of_shards != number_of_shards:
        raise ValueError(
            "Can't set sharding policy to use %d shards since it has been "
            "frozen to use %d." % (number_of_shards, self._number_of_shards))
    else:
      if number_of_shards > 0:
        self._number_of_shards = number_of_shards
      else:
        raise ValueError(
            "Can't set sharding policy to use %s shards; value must be >0" %
            str(number_of_shards))

  @property
  def shard_dimension(self):
    """Returns the shard dimension of the policy or None if unspecified."""
    return self._shard_dimension

  def set_shard_dimension(self, shard_dimension):
    """Sets the shard dimension for the current policy.

    If the policy has been frozen then shard_dimension must match the
    existing setting.

    Args:
      shard_dimension: The shard dimension to use in the policy.

    Raises:
      ValueError: If the policy has been frozen and shard_dimension
        differs from the frozen value, or shard_dimension can't be
        interpreted as a Dimension.
    """
    if self._frozen:
      if self._shard_dimension != shard_dimension:
        raise ValueError(
            "Can't set shard dimension to %d since it has been frozen to "
            "use %d." % (shard_dimension, self._shard_dimension))
    else:
      self._shard_dimension = tensor_shape.as_dimension(shard_dimension)

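  # Sketch of the setters (illustrative): values may be changed freely
  # before freeze(); afterwards only the frozen values are accepted.
  #
  #   policy = ShardingPolicy()
  #   policy.set_number_of_shards(2)
  #   policy.set_shard_dimension(0)
  #   policy.freeze()
  #   policy.set_number_of_shards(2)  # OK, matches the frozen value
  #   policy.set_number_of_shards(4)  # raises ValueError
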
  def merge(self, other):
    """Merges another policy into the current policy.

    Args:
      other: The policy to merge into this one.

    Raises:
      ValueError: If this policy has been frozen and the merge conflicts
        with the frozen policy.
    """
    if other.number_of_shards is not None:
      self.set_number_of_shards(other.number_of_shards)
    if other.shard_dimension is not None:
      self.set_shard_dimension(other.shard_dimension)

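  # Sketch of merge (hypothetical policies): only the fields that are set
  # on `other` are applied, via the setters above, so merging into a
  # frozen policy raises ValueError on any conflict.
  #
  #   a = ShardingPolicy()
  #   a.set_number_of_shards(2)
  #   b = ShardingPolicy()
  #   b.set_shard_dimension(1)
  #   a.merge(b)  # a is now 2 shards along dimension 1
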
  def get_sharded_shape(self, shape, shard_index=None):
    """Returns the shape of a shard of a full Tensor.

    When given the shape of a 'full-size' Tensor, returns the shape of
    the sub-Tensor after it has been sharded. Returns None if the policy
    has not yet been given both a number of shards and a shard dimension.

    Args:
      shape: The shape of the full-size Tensor to be sharded.
      shard_index: The index of the shard whose shape should be returned.
        shard_index can be None for sharding policies that use the same
        shape for every shard.

    Returns:
      The shape of the sharded version of the Tensor.

    Raises:
      ValueError: If shard_index is None when shards are of different
        shapes; or shard_index is not None and not in the range
        [0, number_of_shards); or shape does not have at least
        self.shard_dimension+1 dimensions; or the value of shape's shard
        dimension is not a multiple of self.number_of_shards.
    """
    if self._shard_dimension is None or self._number_of_shards is None:
      # Don't raise an error if the config is unset.
      return None
    if shard_index is not None:
      if shard_index < 0 or shard_index >= self.number_of_shards:
        raise ValueError("shard_index %d, but must be in [0,%d)." %
                         (shard_index, self._number_of_shards))
    shape = tensor_shape.as_shape(shape)
    if self._number_of_shards == 1:
      # Don't do anything when there's only one shard.
      return shape
    ndims = shape.ndims
    if ndims is None:
      raise ValueError("shape must be a specified shape not Unknown")
    if ndims <= self._shard_dimension:
      raise ValueError("shape %s does not contain shard_dimension %d" %
                       (shape.as_list(), self._shard_dimension))
    dims = shape.as_list()
    if dims[self._shard_dimension] is None:
      raise ValueError("shape %s must have a fixed size for dimension %d "
                       "that is known at graph construction time." %
                       (shape.as_list(), self._shard_dimension))
    if (dims[self._shard_dimension] % self._number_of_shards) != 0:
      raise ValueError("shape %s cannot be sharded %d ways along dimension %d" %
                       (shape.as_list(), self._number_of_shards,
                        self._shard_dimension))
    dims[self._shard_dimension] //= self._number_of_shards
    return tensor_shape.as_shape(dims)

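  # Worked example (hypothetical shapes): with 2 shards along dimension 0,
  # an [8, 10] Tensor is sharded into [4, 10] pieces; the sharded dimension
  # must be statically known and divisible by the number of shards.
  #
  #   policy = ShardingPolicy()
  #   policy.set_number_of_shards(2)
  #   policy.set_shard_dimension(0)
  #   policy.get_sharded_shape([8, 10])  # -> TensorShape([4, 10])
  #   policy.get_sharded_shape([7, 10])  # raises ValueError: 7 % 2 != 0
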
  def _unshard_shape(self, shape):
    """Return the unsharded shape that would generate a given sharded shape.

    Args:
      shape: the sharded shape to unshard

    Returns:
      The unsharded shape.

    Raises:
      ValueError: if shape is unknown or does not contain
        self.shard_dimension
      TypeError: if shape is not convertible to a TensorShape
    """
    shape = tensor_shape.as_shape(shape)
    if self._number_of_shards == 1:
      # Don't do anything when there's only one shard.
      return shape
    ndims = shape.ndims
    if ndims is None:
      raise ValueError("shape must be a specified shape not Unknown")
    if ndims <= self._shard_dimension:
      raise ValueError("shape %s does not contain shard_dimension %d" %
                       (shape.as_list(), self._shard_dimension))
    dims = shape.as_list()
    dims[self._shard_dimension] *= self._number_of_shards
    return tensor_shape.as_shape(dims)

  def get_unsharded_shape(self, shapes):
    """Returns the shape of an unsharded Tensor given a list of shards.

    When given a list of shapes of shards, returns the shape of the
    unsharded Tensor that would generate the shards. Sets defaults for the
    policy if number_of_shards or shard_dimension is None.

    Args:
      shapes: The shapes of the Tensor shards to be combined.

    Returns:
      The shape of the unsharded version of the Tensor.

    Raises:
      ValueError: if shapes is not a list of length
        self.number_of_shards; or any element of shapes is not a valid
        shape consistent with the sharding policy; or the list of
        shapes is not a valid sharding of a full shape.
      TypeError: if an element of shapes is not convertible to a
        TensorShape
    """
    self._fill_default_values()
    if len(shapes) != self.number_of_shards:
      raise ValueError(
          "shapes is %s but must be a list of length number_of_shards=%d" % (
              str(shapes), self.number_of_shards))
    unsharded_shapes = [self._unshard_shape(s) for s in shapes]
    for i in xrange(self.number_of_shards - 1):
      if not unsharded_shapes[i].is_compatible_with(
          unsharded_shapes[self.number_of_shards - 1]):
        raise ValueError(
            "sharded shapes %s are not consistent shards of a full shape "
            "sharded %d ways along dimension %d" % (
                str(shapes), self.number_of_shards, self.shard_dimension))
    return unsharded_shapes[0]

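# A short end-to-end sketch (illustrative only): get_unsharded_shape is
# the inverse of get_sharded_shape, given one shape per shard.
#
#   policy = ShardingPolicy()
#   policy.set_number_of_shards(2)
#   policy.set_shard_dimension(0)
#   sharded = policy.get_sharded_shape([8, 10])     # TensorShape([4, 10])
#   policy.get_unsharded_shape([sharded, sharded])  # TensorShape([8, 10])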