# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Helper library for sharding during TPU compilation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import tensor_shape

_DEFAULT_NUMBER_OF_SHARDS = 1
_DEFAULT_SHARD_DIMENSION = 0


# TODO(b/36777903) change other parts of tpu.py to use this class.
class ShardingPolicy(object):
  """An object used to hold the sharding policy for a Tensor."""

  def __init__(self):
    self._number_of_shards = None
    self._number_of_partitions = 1
    self._shard_dimension = None
    self._frozen = False

  def __str__(self):
    if self.number_of_shards is None or self.shard_dimension is None:
      return "ShardingPolicy(unset)"
    else:
      return ("ShardingPolicy(%d shards dimension %d)" %
              (self.number_of_shards, self.shard_dimension))

  def _fill_default_values(self):
    if self._number_of_shards is None:
      self._number_of_shards = _DEFAULT_NUMBER_OF_SHARDS
    if self._shard_dimension is None:
      self._shard_dimension = tensor_shape.as_dimension(
          _DEFAULT_SHARD_DIMENSION)

  def freeze(self):
    """Prevents further modification to the sharding policy.

    Any values that have not been set when freeze is called are set to
    defaults. If the ShardingPolicy is already frozen, this is a NoOp.
    """
    if not self._frozen:
      self._fill_default_values()
      self._frozen = True

  @property
  def number_of_shards(self):
    """Returns the number of shards in the policy or None if unspecified."""
    return self._number_of_shards

  def set_number_of_shards(self, number_of_shards):
    """Sets the number of shards for the current policy.

    If the policy has been frozen then number_of_shards must match the
    existing setting.

    Args:
      number_of_shards: The number of shards to use in the policy.

    Raises:
      ValueError: If the policy has been frozen and number_of_shards
        differs from the frozen value; or number_of_shards <= 0.
    """
    if self._frozen:
      if self._number_of_shards != number_of_shards:
        raise ValueError(
            f"Can't set sharding policy to use {number_of_shards} shards since "
            f"it has been frozen to use {self._number_of_shards}")
    else:
      if number_of_shards > 0:
        self._number_of_shards = number_of_shards
      else:
        raise ValueError(
            f"Can't set sharding policy to use {number_of_shards} shards; "
            "value must be > 0")
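
  # A minimal usage sketch (illustrative comment only, not part of the API):
  # once a policy is frozen, the setters accept only the already-frozen
  # values.
  #
  #   policy = ShardingPolicy()
  #   policy.set_number_of_shards(8)
  #   policy.freeze()
  #   policy.set_number_of_shards(8)  # OK: matches the frozen value.
  #   policy.set_number_of_shards(4)  # Raises ValueError.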

  @property
  def number_of_partitions(self):
    """Returns the number of partitions in the policy."""
    return self._number_of_partitions

  def set_number_of_partitions(self, number_of_partitions):
    """Sets the number of partitions for the current policy.

    If the policy has been frozen then number_of_partitions must match the
    existing setting.

    Args:
      number_of_partitions: The number of partitions to use in the policy.

    Raises:
      ValueError: If the policy has been frozen and number_of_partitions
        differs from the frozen value.
    """
    if self._frozen:
      if self._number_of_partitions != number_of_partitions:
        raise ValueError(
            f"Can't set number_of_partitions to {number_of_partitions} since "
            f"it has been frozen to use {self._number_of_partitions}.")
    else:
      self._number_of_partitions = number_of_partitions

  @property
  def shard_dimension(self):
    """Returns the shard dimension of the policy or None if unspecified."""
    return self._shard_dimension

  def set_shard_dimension(self, shard_dimension):
    """Sets the shard dimension for the current policy.

    If the policy has been frozen then shard_dimension must match the
    existing setting.

    Args:
      shard_dimension: The shard dimension to use in the policy.

    Raises:
      ValueError: If the policy has been frozen and shard_dimension
        differs from the frozen value, or shard_dimension can't be
        interpreted as a Dimension.
    """
    if self._frozen:
      if self._shard_dimension != shard_dimension:
        raise ValueError(
            f"Can't set shard dimension to {shard_dimension} since it has "
            f"been frozen to use {self._shard_dimension}.")
    else:
      self._shard_dimension = tensor_shape.as_dimension(shard_dimension)

  def merge(self, other):
    """Merges the policy of another policy into the current policy.

    Args:
      other: The policy to merge into this one.

    Raises:
      ValueError: If this policy has been frozen and the merge conflicts with
        the frozen policy.
    """
    if other.number_of_shards is not None:
      self.set_number_of_shards(other.number_of_shards)
    if other.shard_dimension is not None:
      self.set_shard_dimension(other.shard_dimension)

  def get_unpartitioned_shape(self, shape):
    """Returns the shape of an unpartitioned Tensor.

    When given the shape of a 'sharded-size' Tensor, returns the full
    shape of its unpartitioned Tensor.

    Args:
      shape: The shape of the sharded Tensor.

    Returns:
      The shape of the unpartitioned version of the Tensor.

    Raises:
      ValueError: If shape has an unknown sharded dimension.
    """
    shape = tensor_shape.as_shape(shape)
    dims = shape.as_list()
    if (self._shard_dimension is None or self._number_of_partitions is None or
        not dims):
      return None
    if dims[self._shard_dimension] is None:
      raise ValueError(f"Shape {shape.as_list()} must have a fixed size for "
                       f"dimension {self._shard_dimension} that is known.")
    if self._number_of_partitions > 1:
      dims[self._shard_dimension] *= self._number_of_partitions
    return tensor_shape.as_shape(dims)
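
  # Illustrative sketch (comment only): with 2 partitions along dimension 0,
  # a per-partition shape of [4, 16] corresponds to an unpartitioned shape
  # of [8, 16].
  #
  #   policy = ShardingPolicy()
  #   policy.set_shard_dimension(0)
  #   policy.set_number_of_partitions(2)
  #   policy.get_unpartitioned_shape([4, 16])  # => TensorShape([8, 16])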

  def get_sharded_shape(self, shape, shard_index=None):
    """Returns the shape of a shard of a full Tensor.

    When given the shape of a 'full-size' Tensor, returns the shape of
    the sub-Tensor after it has been sharded. Returns None if the policy's
    number of shards or shard dimension is still unset.

    Args:
      shape: The shape of the full-size Tensor to be sharded.
      shard_index: The index of the shard whose shape should be returned.
        shard_index can be None for sharding policies that use the same
        shape for every shard.

    Returns:
      The shape of the sharded version of the Tensor.

    Raises:
      ValueError: If shard_index is None when shards are of different
        shapes; or shard_index is not None and
        !(0 <= shard_index < number_of_shards); or shape does not have at
        least self.shard_dimension+1 dimensions; or the value of
        shape's shard dimension is not a multiple of
        self.number_of_shards.
    """
    if self._shard_dimension is None or self._number_of_shards is None:
      # Don't raise an error if the config is unset.
      return None
    if shard_index is not None:
      if shard_index < 0 or shard_index >= self.number_of_shards:
        raise ValueError(
            f"Requested shard_index {shard_index}, but shard_index must be in "
            f"[0,{self._number_of_shards}).")
    shape = tensor_shape.as_shape(shape)
    if self._number_of_shards == 1:
      # Don't do anything when there's only one shard.
      return shape
    ndims = shape.ndims
    if ndims is None:
      raise ValueError(f"Shape {shape} must be a known shape.")
    if ndims <= self._shard_dimension:
      raise ValueError(
          f"Shape {shape.as_list()} does not contain shard_dimension "
          f"{self._shard_dimension}")
    dims = shape.as_list()
    if dims[self._shard_dimension] is None:
      raise ValueError(
          f"Shape {shape.as_list()} must have a fixed size for dimension "
          f"{self._shard_dimension} that is known at construction time.")
    if (dims[self._shard_dimension] % self._number_of_shards) != 0:
      raise ValueError(
          f"Shape {shape.as_list()} cannot be sharded {self._number_of_shards} "
          f"ways along dimension {self._shard_dimension}")
    dims[self._shard_dimension] //= self._number_of_shards
    return tensor_shape.TensorShape(dims)

  def _unshard_shape(self, shape):
    """Returns the unsharded shape that would generate a given sharded shape.

    Args:
      shape: The sharded shape to unshard.

    Returns:
      The unsharded shape.

    Raises:
      ValueError: If shape is unknown or does not contain
        self.shard_dimension.
      TypeError: If shape is not convertible to a TensorShape.
    """
    shape = tensor_shape.as_shape(shape)
    if self._number_of_shards == 1:
      # Don't do anything when there's only one shard.
      return shape
    ndims = shape.ndims
    if ndims is None:
      raise ValueError(f"Shape {shape} must be statically known.")
    if ndims <= self._shard_dimension:
      raise ValueError(f"Shape {shape.as_list()} does not contain "
                       f"shard_dimension {self._shard_dimension}. "
                       f"Rank is too small.")
    dims = shape.as_list()
    dims[self._shard_dimension] *= self._number_of_shards
    return tensor_shape.TensorShape(dims)
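
  # Illustrative sketch (comment only): sharding [8, 16] four ways along
  # dimension 0 yields a per-shard shape of [2, 16]; _unshard_shape inverts
  # the mapping.
  #
  #   policy = ShardingPolicy()
  #   policy.set_number_of_shards(4)
  #   policy.set_shard_dimension(0)
  #   policy.get_sharded_shape([8, 16])  # => TensorShape([2, 16])
  #   policy._unshard_shape([2, 16])     # => TensorShape([8, 16])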

  def get_unsharded_shape(self, shapes):
    """Returns the shape of an unsharded Tensor given a list of shards.

    When given a list of shapes of shards, returns the shape of the
    unsharded Tensor that would generate the shards. Sets defaults for the
    policy if number_of_shards or shard_dimension is None.

    Args:
      shapes: The shapes of the Tensor shards to be combined.

    Returns:
      The shape of the unsharded version of the Tensor.

    Raises:
      ValueError: If shapes is not a list of length
        self.number_of_shards; or any element of shapes is not a valid
        shape consistent with the sharding policy; or the list of
        shapes is not a valid sharding of a full shape.
      TypeError: If an element of shapes is not convertible to a
        TensorShape.
    """
    self._fill_default_values()
    if len(shapes) != self.number_of_shards:
      raise ValueError(
          f"Shapes {shapes} has length {len(shapes)} but must be a list of "
          f"length number_of_shards={self.number_of_shards}")
    unsharded_shapes = [self._unshard_shape(s) for s in shapes]
    for i in xrange(self.number_of_shards - 1):
      if not unsharded_shapes[i].is_compatible_with(
          unsharded_shapes[self.number_of_shards - 1]):
        raise ValueError(
            f"Sharded shapes {shapes} are not consistent shards of a full "
            f"shape sharded {self.number_of_shards} ways along dimension "
            f"{self.shard_dimension}.")
    return unsharded_shapes[0]
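

# A minimal, self-contained usage sketch of the class above. This block is
# illustrative only and not part of the library API; it exercises only
# methods defined in this module.
if __name__ == "__main__":
  policy = ShardingPolicy()
  policy.set_number_of_shards(4)
  policy.set_shard_dimension(0)
  policy.freeze()
  print(policy)  # ShardingPolicy(4 shards dimension 0)
  # Shard a full shape, then recover it from the per-shard shapes.
  sharded = policy.get_sharded_shape([8, 16])
  print(sharded)  # (2, 16)
  print(policy.get_unsharded_shape([sharded] * 4))  # (8, 16)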