# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Helper library for sharding during TPU compilation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import tensor_shape

_DEFAULT_NUMBER_OF_SHARDS = 1
_DEFAULT_SHARD_DIMENSION = 0


# TODO(b/36777903) change other parts of tpu.py to use this class.
class ShardingPolicy(object):
  """An object used to hold the sharding policy for a Tensor."""

  def __init__(self):
    self._number_of_shards = None
    self._shard_dimension = None
    self._frozen = False

  def __str__(self):
    if self.number_of_shards is None or self.shard_dimension is None:
      return "ShardingPolicy(unset)"
    else:
      return ("ShardingPolicy(%d shards dimension %d)" %
              (self.number_of_shards, self.shard_dimension))

  def _fill_default_values(self):
    if self._number_of_shards is None:
      self._number_of_shards = _DEFAULT_NUMBER_OF_SHARDS
    if self._shard_dimension is None:
      self._shard_dimension = tensor_shape.as_dimension(
          _DEFAULT_SHARD_DIMENSION)

  def freeze(self):
    """Prevents further modification of the sharding policy.

    Any values that have not been set when freeze is called are set to
    defaults. If the ShardingPolicy is already frozen, this is a no-op.
    """
    if not self._frozen:
      self._fill_default_values()
      self._frozen = True

  @property
  def number_of_shards(self):
    """Returns the number of shards in the policy, or None if unspecified."""
    return self._number_of_shards

  def set_number_of_shards(self, number_of_shards):
    """Sets the number of shards for the current policy.

    If the policy has been frozen then number_of_shards must match the
    existing setting.

    Args:
      number_of_shards: The number of shards to use in the policy.

    Raises:
      ValueError: If the policy has been frozen and number_of_shards
        differs from the frozen value; or number_of_shards <= 0.
    """
    if self._frozen:
      if self._number_of_shards != number_of_shards:
        raise ValueError(
            "Can't set sharding policy to use %d shards since it has been "
            "frozen to use %d." % (number_of_shards, self._number_of_shards))
    else:
      if number_of_shards > 0:
        self._number_of_shards = number_of_shards
      else:
        raise ValueError(
            "Can't set sharding policy to use %s shards; value must be > 0." %
            str(number_of_shards))
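
  # Illustrative sketch (not part of the original file): shard counts are
  # validated on assignment and pinned once the policy is frozen, e.g.
  #
  #   policy = ShardingPolicy()
  #   policy.set_number_of_shards(8)
  #   policy.freeze()
  #   policy.set_number_of_shards(8)  # OK: matches the frozen value.
  #   policy.set_number_of_shards(4)  # Raises ValueError.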

  @property
  def shard_dimension(self):
    """Returns the shard dimension of the policy, or None if unspecified."""
    return self._shard_dimension

  def set_shard_dimension(self, shard_dimension):
    """Sets the shard dimension for the current policy.

    If the policy has been frozen then shard_dimension must match the
    existing setting.

    Args:
      shard_dimension: The shard dimension to use in the policy.

    Raises:
      ValueError: If the policy has been frozen and shard_dimension
        differs from the frozen value, or shard_dimension can't be
        interpreted as a Dimension.
    """
    if self._frozen:
      if self._shard_dimension != shard_dimension:
        raise ValueError(
            "Can't set shard dimension to %d since it has been frozen to "
            "use %d." % (shard_dimension, self._shard_dimension))
    else:
      self._shard_dimension = tensor_shape.as_dimension(shard_dimension)

  def merge(self, other):
    """Merges another policy into the current policy.

    Args:
      other: The policy to merge into this one.

    Raises:
      ValueError: If this policy has been frozen and the merge conflicts with
        the frozen policy.
    """
    if other.number_of_shards is not None:
      self.set_number_of_shards(other.number_of_shards)
    if other.shard_dimension is not None:
      self.set_shard_dimension(other.shard_dimension)

  def get_sharded_shape(self, shape, shard_index=None):
    """Returns the shape of a shard of a full Tensor.

    When given the shape of a 'full-size' Tensor, returns the shape of
    the sub-Tensor after it has been sharded. Returns None if the policy
    has not been fully specified.

    Args:
      shape: The shape of the full-size Tensor to be sharded.
      shard_index: The index of the shard whose shape should be returned.
        shard_index can be None for sharding policies that use the same
        shape for every shard.

    Returns:
      The shape of the sharded version of the Tensor.

    Raises:
      ValueError: If shard_index is None when shards are of different
        shapes; or shard_index is not None and not in
        [0, number_of_shards); or shape does not have at least
        self.shard_dimension+1 dimensions; or the value of shape's shard
        dimension is not a multiple of self.number_of_shards.
    """
    if self._shard_dimension is None or self._number_of_shards is None:
      # Don't raise an error if the config is unset.
      return None
    if shard_index is not None:
      if shard_index < 0 or shard_index >= self.number_of_shards:
        raise ValueError("shard_index is %d, but must be in [0, %d)." %
                         (shard_index, self._number_of_shards))
    shape = tensor_shape.as_shape(shape)
    if self._number_of_shards == 1:
      # Don't do anything when there's only one shard.
      return shape
    ndims = shape.ndims
    if ndims is None:
      raise ValueError("shape must be a specified shape, not unknown")
    if ndims <= self._shard_dimension:
      raise ValueError("shape %s does not contain shard_dimension %d" %
                       (shape.as_list(), self._shard_dimension))
    dims = shape.as_list()
    if dims[self._shard_dimension] is None:
      raise ValueError("shape %s must have a fixed size for dimension %d "
                       "that is known at graph construction time." %
                       (shape.as_list(), self._shard_dimension))
    if (dims[self._shard_dimension] % self._number_of_shards) != 0:
      raise ValueError("shape %s cannot be sharded %d ways along dimension %d" %
                       (shape.as_list(), self._number_of_shards,
                        self._shard_dimension))
    # Use integer division so the dimension stays an int under
    # `from __future__ import division`.
    dims[self._shard_dimension] //= self._number_of_shards
    return tensor_shape.as_shape(dims)
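
  # Illustrative sketch (not part of the original file): with 4 shards along
  # dimension 0, a full shape of [8, 10] yields per-shard shapes of [2, 10]:
  #
  #   policy = ShardingPolicy()
  #   policy.set_number_of_shards(4)
  #   policy.set_shard_dimension(0)
  #   policy.get_sharded_shape([8, 10])  # => TensorShape([2, 10])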

  def _unshard_shape(self, shape):
    """Returns the unsharded shape that would generate a given sharded shape.

    Args:
      shape: The sharded shape to unshard.

    Returns:
      The unsharded shape.

    Raises:
      ValueError: If shape is unknown or does not contain
        self.shard_dimension.
      TypeError: If shape is not convertible to a TensorShape.
    """
    shape = tensor_shape.as_shape(shape)
    if self._number_of_shards == 1:
      # Don't do anything when there's only one shard.
      return shape
    ndims = shape.ndims
    if ndims is None:
      raise ValueError("shape must be a specified shape, not unknown")
    if ndims <= self._shard_dimension:
      raise ValueError("shape %s does not contain shard_dimension %d" %
                       (shape.as_list(), self._shard_dimension))
    dims = shape.as_list()
    dims[self._shard_dimension] *= self._number_of_shards
    return tensor_shape.as_shape(dims)

  def get_unsharded_shape(self, shapes):
    """Returns the shape of an unsharded Tensor given a list of shards.

    When given a list of shapes of shards, returns the shape of the
    unsharded Tensor that would generate the shards. Sets defaults for the
    policy if number_of_shards or shard_dimension is None.

    Args:
      shapes: The shapes of the Tensor shards to be combined.

    Returns:
      The shape of the unsharded version of the Tensor.

    Raises:
      ValueError: If shapes is not a list of length self.number_of_shards;
        or any element of shapes is not a valid shape consistent with the
        sharding policy; or the list of shapes is not a valid sharding of
        a full shape.
      TypeError: If an element of shapes is not convertible to a
        TensorShape.
    """
    self._fill_default_values()
    if len(shapes) != self.number_of_shards:
      raise ValueError(
          "shapes is %s but must be a list of length number_of_shards=%d" % (
              str(shapes), self.number_of_shards))
    unsharded_shapes = [self._unshard_shape(s) for s in shapes]
    for i in xrange(self.number_of_shards - 1):
      if not unsharded_shapes[i].is_compatible_with(
          unsharded_shapes[self.number_of_shards - 1]):
        raise ValueError(
            "sharded shapes %s are not consistent shards of a full shape "
            "sharded %d ways along dimension %d" % (
                str(shapes), self.number_of_shards, self.shard_dimension))
    return unsharded_shapes[0]
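

# Illustrative usage sketch, not part of the original library: round-trips a
# shape through a frozen policy. Assumes TensorFlow (and thus
# tensorflow.python.framework.tensor_shape) is importable.
if __name__ == "__main__":
  policy = ShardingPolicy()
  policy.set_number_of_shards(4)
  policy.set_shard_dimension(0)
  policy.freeze()
  per_shard = policy.get_sharded_shape([8, 10])
  print("per-shard shape:", per_shard)  # => (2, 10)
  full = policy.get_unsharded_shape([per_shard] * 4)
  print("full shape:", full)  # => (8, 10)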