# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Helper library for sharding during TPU compilation."""

from tensorflow.python.framework import tensor_shape

_DEFAULT_NUMBER_OF_SHARDS = 1
_DEFAULT_SHARD_DIMENSION = 0


# TODO(b/36777903) change other parts of tpu.py to use this class.
class ShardingPolicy(object):
  """An object used to hold the sharding policy for a Tensor."""

  def __init__(self):
    self._number_of_shards = None
    self._number_of_partitions = 1
    self._shard_dimension = None
    self._frozen = False

  def __str__(self):
    if self.number_of_shards is None or self.shard_dimension is None:
      return "ShardingPolicy(unset)"
    else:
      return ("ShardingPolicy(%d shards dimension %d)" %
              (self.number_of_shards, self.shard_dimension))

  def _fill_default_values(self):
    """Fills in default values for any policy fields that are still unset."""
    if self._number_of_shards is None:
      self._number_of_shards = _DEFAULT_NUMBER_OF_SHARDS
    if self._shard_dimension is None:
      self._shard_dimension = tensor_shape.as_dimension(
          _DEFAULT_SHARD_DIMENSION)

  def freeze(self):
    """Prevents further modification to the sharding policy.

    Any values that have not been set when freeze is called are set to
    defaults. If the ShardingPolicy is already frozen, this is a no-op.
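
    Example (an illustrative sketch, not from the original docs, assuming
    a freshly constructed policy):

    >>> policy = ShardingPolicy()
    >>> policy.freeze()
    >>> policy.number_of_shards
    1
    >>> policy.shard_dimension == 0
    True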
58    """
59    if not self._frozen:
60      self._fill_default_values()
61      self._frozen = True
62
  @property
  def number_of_shards(self):
    """Returns the number of shards in the policy or None if unspecified."""
    return self._number_of_shards

  def set_number_of_shards(self, number_of_shards):
    """Sets the number of shards for the current policy.

    If the policy has been frozen then number_of_shards must match the
    existing setting.

    Args:
      number_of_shards: The number of shards to use in the policy.

    Raises:
      ValueError: If the policy has been frozen and number_of_shards
        differs from the frozen value; or number_of_shards <= 0.
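
    Example (an illustrative sketch, not from the original docs):

    >>> policy = ShardingPolicy()
    >>> policy.set_number_of_shards(8)
    >>> policy.number_of_shards
    8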
80    """
81    if self._frozen:
82      if self._number_of_shards != number_of_shards:
83        raise ValueError(
84            f"Can't set sharding policy to use {number_of_shards} shards since "
85            f"it has been frozen to use {self._number_of_shards}")
86    else:
87      if number_of_shards > 0:
88        self._number_of_shards = number_of_shards
89      else:
90        raise ValueError(
91            "Can't set sharding policy to use {number_of_shards} shards; "
92            "value must be > 0")
93
  @property
  def number_of_partitions(self):
    """Returns the number of partitions in the policy (defaults to 1)."""
    return self._number_of_partitions

  def set_number_of_partitions(self, number_of_partitions):
    """Sets the number of partitions for the current policy.

    If the policy has been frozen then number_of_partitions must match the
    existing setting.

    Args:
      number_of_partitions: The number of partitions to use in the policy.

    Raises:
      ValueError: If the policy has been frozen and number_of_partitions
        differs from the frozen value.
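
    Example (an illustrative sketch, not from the original docs):

    >>> policy = ShardingPolicy()
    >>> policy.set_number_of_partitions(2)
    >>> policy.number_of_partitions
    2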
111    """
112    if self._frozen:
113      if self._number_of_partitions != number_of_partitions:
114        raise ValueError(
115            f"Can't set number_of_partitions to {number_of_partitions} since "
116            f"it has been frozen to use {self._number_of_partitions}.")
117    else:
118      self._number_of_partitions = number_of_partitions
119
  @property
  def shard_dimension(self):
    """Returns the shard dimension of the policy or None if unspecified."""
    return self._shard_dimension

  def set_shard_dimension(self, shard_dimension):
    """Sets the shard dimension for the current policy.

    If the policy has been frozen then shard_dimension must match the
    existing setting.

    Args:
      shard_dimension: The shard dimension to use in the policy.

    Raises:
      ValueError: If the policy has been frozen and shard_dimension
        differs from the frozen value, or shard_dimension can't be
        interpreted as a Dimension.
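
    Example (an illustrative sketch, not from the original docs):

    >>> policy = ShardingPolicy()
    >>> policy.set_shard_dimension(1)
    >>> policy.shard_dimension == 1
    True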
138    """
139    if self._frozen:
140      if self._shard_dimension != shard_dimension:
141        raise ValueError(
142            "Can't set shard dimension to %d since it has been frozen to "
143            "use %d." % (shard_dimension, self._shard_dimension))
144    else:
145      self._shard_dimension = tensor_shape.as_dimension(shard_dimension)
146
  def merge(self, other):
    """Merges another policy into the current policy.

    Args:
      other: The policy to merge into this one.

    Raises:
      ValueError: If this policy has been frozen and the merge conflicts with
        the frozen policy.
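
    Example (an illustrative sketch, not from the original docs):

    >>> a = ShardingPolicy()
    >>> a.set_number_of_shards(4)
    >>> b = ShardingPolicy()
    >>> b.merge(a)
    >>> b.number_of_shards
    4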
156    """
157    if other.number_of_shards is not None:
158      self.set_number_of_shards(other.number_of_shards)
159    if other.shard_dimension is not None:
160      self.set_shard_dimension(other.shard_dimension)
161
  def get_unpartitioned_shape(self, shape):
    """Returns the shape of an unpartitioned Tensor.

    When given the per-partition shape of a Tensor, returns the full
    shape of the corresponding unpartitioned Tensor.

    Args:
      shape: The per-partition shape of the Tensor.

    Returns:
      The shape of the unpartitioned version of the Tensor.

    Raises:
      ValueError: If shape has an unknown size for the shard dimension.
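
    Example (an illustrative sketch, not from the original docs):

    >>> policy = ShardingPolicy()
    >>> policy.set_shard_dimension(0)
    >>> policy.set_number_of_partitions(2)
    >>> policy.get_unpartitioned_shape([4, 3]).as_list()
    [8, 3]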
176    """
177    shape = tensor_shape.as_shape(shape)
178    dims = shape.as_list()
179    if (self._shard_dimension is None or self._number_of_partitions is None or
180        not dims):
181      return None
182    if dims[self._shard_dimension] is None:
183      raise ValueError(f"Shape {shape.as_list()} must have a fixed size for "
184                       f"dimension {self._shard_dimension} that is known. ")
185    if self._number_of_partitions > 1:
186      dims[self._shard_dimension] *= self._number_of_partitions
187    return tensor_shape.as_shape(dims)
188
  def get_sharded_shape(self, shape, shard_index=None):
    """Returns the shape of a shard of a full Tensor.

    When given the shape of a 'full-size' Tensor, returns the shape of
    the sub-Tensor after it has been sharded. If the number of shards or
    the shard dimension has not yet been set, returns None rather than
    raising.

    Args:
      shape: The shape of the full-size Tensor to be sharded.
      shard_index: The index of the shard whose shape should be returned.
        shard_index can be None for sharding policies that use the same shape
        for every shard.

    Returns:
      The shape of the sharded version of the Tensor.

    Raises:
      ValueError: If shard_index is None when shards are of different
        shapes; or shard_index is not None and lies outside
        [0, number_of_shards); or shape has fewer than
        shard_dimension + 1 dimensions; or the size of shape's shard
        dimension is not a multiple of number_of_shards.
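
    Example (an illustrative sketch, not from the original docs):

    >>> policy = ShardingPolicy()
    >>> policy.set_number_of_shards(2)
    >>> policy.set_shard_dimension(0)
    >>> policy.get_sharded_shape([8, 16]).as_list()
    [4, 16]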
212    """
213    if self._shard_dimension is None or self._number_of_shards is None:
214      # Don't raise an error if the config is unset.
215      return None
216    if shard_index is not None:
217      if shard_index < 0 or shard_index >= self.number_of_shards:
218        raise ValueError(
219            f"Requested shard_index {shard_index}, but shard_index must be in "
220            f"[0,{self._number_of_shards}).")
221    shape = tensor_shape.as_shape(shape)
222    if self._number_of_shards == 1:
223      # Don't do anything when there's only one shard.
224      return shape
225    ndims = shape.ndims
226    if ndims is None:
227      raise ValueError(f"Shape {shape} must be a known shape.")
228    if ndims <= self._shard_dimension:
229      raise ValueError(
230          f"Shape {shape.as_list()} does not contain shard_dimension "
231          f"{self._shard_dimension}")
232    dims = shape.as_list()
233    if dims[self._shard_dimension] is None:
234      raise ValueError(
235          f"Shape {shape.as_list()} must have a fixed size for dimension "
236          f"{self._shard_dimension} that is known at construction time.")
237    if (dims[self._shard_dimension] % self._number_of_shards) != 0:
238      raise ValueError(
239          f"Shape {shape.as_list()} cannot be sharded {self._number_of_shards} "
240          f"ways along dimension {self._shard_dimension}")
241    dims[self._shard_dimension] //= self._number_of_shards
242    return tensor_shape.TensorShape(dims)
243
  def _unshard_shape(self, shape):
    """Returns the unsharded shape that would generate a given sharded shape.

    Args:
      shape: The sharded shape to unshard.

    Returns:
      The unsharded shape.

    Raises:
      ValueError: If shape is unknown or does not contain
        self.shard_dimension.
      TypeError: If shape is not convertible to a TensorShape.
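
    Example (an illustrative sketch, not from the original docs):

    >>> policy = ShardingPolicy()
    >>> policy.set_number_of_shards(2)
    >>> policy.set_shard_dimension(0)
    >>> policy._unshard_shape([4, 16]).as_list()
    [8, 16]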
257    """
258    shape = tensor_shape.as_shape(shape)
259    if self._number_of_shards == 1:
260      # Don't do anything when there's only one shard.
261      return shape
262    ndims = shape.ndims
263    if ndims is None:
264      raise ValueError(f"Shape {shape} must be statically known.")
265    if ndims <= self._shard_dimension:
266      raise ValueError(f"Shape {shape.as_list()} does not contain "
267                       f"shard_dimension {self._shard_dimension}. "
268                       f"Rank is too small.")
269    dims = shape.as_list()
270    dims[self._shard_dimension] *= self._number_of_shards
271    return tensor_shape.TensorShape(dims)
272
  def get_unsharded_shape(self, shapes):
    """Returns the shape of an unsharded Tensor given a list of shards.

    When given a list of shapes of shards, returns the shape of the
    unsharded Tensor that would generate the shards. Sets defaults for the
    policy if number_of_shards or shard_dimension is None.

    Args:
      shapes: The shapes of the Tensor shards to be combined.

    Returns:
      The shape of the unsharded version of the Tensor.

    Raises:
      ValueError: If shapes is not a list of length
        self.number_of_shards; or any element of shapes is not a valid
        shape consistent with the sharding policy; or the list of
        shapes is not a valid sharding of a full shape.
      TypeError: If an element of shapes is not convertible to a
        TensorShape.
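
    Example (an illustrative sketch, not from the original docs):

    >>> policy = ShardingPolicy()
    >>> policy.set_number_of_shards(2)
    >>> policy.set_shard_dimension(0)
    >>> policy.get_unsharded_shape([[4, 3], [4, 3]]).as_list()
    [8, 3]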
293    """
294    self._fill_default_values()
295    if len(shapes) != self.number_of_shards:
296      raise ValueError(
297          f"Shapes {shapes} is length {len(shapes)} but must be a list of "
298          f"length number_of_shards={self.number_of_shards}")
299    unsharded_shapes = [self._unshard_shape(s) for s in shapes]
300    for i in xrange(self.number_of_shards - 1):
301      if not unsharded_shapes[i].is_compatible_with(
302          unsharded_shapes[self.number_of_shards - 1]):
303        raise ValueError(
304            f"Sharded shapes {shapes} are not consistent shards of a full shape "
305            f"sharded {self.number_of_shards} ways along "
306            f"dimension {self.shard_dimension}.")
307    return unsharded_shapes[0]
308
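

if __name__ == "__main__":
  # Illustrative smoke test, not part of the original TensorFlow module.
  # It exercises the typical shard/unshard round trip described above.
  p = ShardingPolicy()
  p.set_number_of_shards(4)
  p.set_shard_dimension(0)
  p.freeze()
  sharded = p.get_sharded_shape([16, 8])
  assert sharded.as_list() == [4, 8]
  assert p.get_unsharded_shape([sharded] * 4).as_list() == [16, 8]
  print("ShardingPolicy round-trip OK:", p)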