1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""ShardedVariable class.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import copy 21import math 22from typing import Sequence 23import numpy as np 24 25from tensorflow.python.framework import composite_tensor 26from tensorflow.python.framework import constant_op 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import ops 29from tensorflow.python.framework import tensor_shape 30from tensorflow.python.framework import type_spec 31from tensorflow.python.ops import array_ops 32from tensorflow.python.ops import data_flow_ops 33from tensorflow.python.ops import embedding_ops 34from tensorflow.python.ops import math_ops 35from tensorflow.python.ops import partitioned_variables 36from tensorflow.python.ops import resource_variable_ops 37from tensorflow.python.ops import variables as variables_lib 38from tensorflow.python.saved_model import revived_types 39from tensorflow.python.saved_model import save_context 40from tensorflow.python.training.saving import saveable_object_util 41from tensorflow.python.training.tracking import base as trackable 42from tensorflow.python.util import dispatch 43from tensorflow.python.util.tf_export import tf_export 44 45 
@tf_export('distribute.experimental.partitioners.Partitioner', v1=[])
class Partitioner(object):
  """Partitioner base class: all partitioners inherit from this class.

  Partitioners should implement a `__call__` method with the following
  signature:

  ```python
  def __call__(self, shape, dtype, axis=0):
    # Partitions the given `shape` and returns the partition results.
    # See docstring of `__call__` method for the format of partition results.
  ```
  """

  def __call__(self, shape, dtype, axis=0):
    """Partitions the given `shape` and returns the partition results.

    Examples of a partitioner that allocates a fixed number of shards:

    ```python
    partitioner = FixedShardsPartitioner(num_shards=2)
    partitions = partitioner(tf.TensorShape([10, 3]), tf.float32, axis=0)
    print(partitions) # [2, 1]
    ```

    Args:
      shape: a `tf.TensorShape`, the shape to partition.
      dtype: a `tf.dtypes.Dtype` indicating the type of the partition value.
      axis: The axis to partition along.  Default: outermost axis.

    Returns:
      A list of integers representing the number of partitions on each axis,
      where the i-th value corresponds to the i-th axis.
    """
    raise NotImplementedError


@tf_export('distribute.experimental.partitioners.FixedShardsPartitioner', v1=[])
class FixedShardsPartitioner(Partitioner):
  """Partitioner that allocates a fixed number of shards.

  Examples:

  >>> # standalone usage:
  >>> partitioner = FixedShardsPartitioner(num_shards=2)
  >>> partitions = partitioner(tf.TensorShape([10, 3]), tf.float32)
  >>> partitions
  [2, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)

  """

  def __init__(self, num_shards):
    """Creates a new `FixedShardsPartitioner`.

    Args:
      num_shards: `int`, number of shards to partition.
    """
    self._num_shards = num_shards

  def __call__(self, shape, dtype, axis=0):
    del dtype  # The shard count does not depend on the element type.
    result = [1] * len(shape)
    # Never request more shards than there are elements along `axis`.
    result[axis] = min(self._num_shards, shape.dims[axis].value)
    return result


@tf_export('distribute.experimental.partitioners.MinSizePartitioner', v1=[])
class MinSizePartitioner(Partitioner):
  """Partitioner that allocates a minimum size per shard.

  This partitioner ensures each shard has at least `min_shard_bytes`, and tries
  to allocate as many shards as possible, i.e., keeping shard size as small as
  possible. The maximum number of such shards (upper bound) is given by
  `max_shards`.

  Examples:

  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [2, 1]
  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=10)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [6, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self,
               min_shard_bytes=256 << 10,
               max_shards=1,
               bytes_per_string=16):
    """Creates a new `MinSizePartitioner`.

    Args:
      min_shard_bytes: Minimum bytes of each shard. Defaults to 256K.
      max_shards: Upper bound on the number of shards. Defaults to 1.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.

    Raises:
      ValueError: If any argument is smaller than 1.
    """
    if min_shard_bytes < 1:
      raise ValueError('Argument `min_shard_bytes` must be positive. '
                       f'Received: {min_shard_bytes}')
    if max_shards < 1:
      raise ValueError('Argument `max_shards` must be positive. '
                       f'Received: {max_shards}')
    if bytes_per_string < 1:
      raise ValueError('Argument `bytes_per_string` must be positive. '
                       f'Received: {bytes_per_string}')
    self._min_shard_bytes = min_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    # `min_max_variable_partitioner` returns a partitioner callable, which is
    # applied immediately to `shape` and `dtype`.
    return partitioned_variables.min_max_variable_partitioner(
        max_partitions=self._max_shards,
        axis=axis,
        min_slice_size=self._min_shard_bytes,
        bytes_per_string_element=self._bytes_per_string)(shape, dtype)


@tf_export('distribute.experimental.partitioners.MaxSizePartitioner', v1=[])
class MaxSizePartitioner(Partitioner):
  """Partitioner that keeps shards below `max_shard_bytes`.

  This partitioner ensures each shard has at most `max_shard_bytes`, and tries
  to allocate as few shards as possible, i.e., keeping shard size as large
  as possible.

  If the partitioner hits the `max_shards` limit, then each shard may end up
  larger than `max_shard_bytes`. By default `max_shards` equals `None` and no
  limit on the number of shards is enforced.

  Examples:

  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [6, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [2, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=1024)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [1, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self, max_shard_bytes, max_shards=None, bytes_per_string=16):
    """Creates a new `MaxSizePartitioner`.

    Args:
      max_shard_bytes: The maximum size any given shard is allowed to be.
      max_shards: The maximum number of shards in `int` created taking
        precedence over `max_shard_bytes`. `None` (the default) means no limit.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.

    Raises:
      ValueError: If `max_shard_bytes` or `bytes_per_string` is smaller than 1,
        or if `max_shards` is given and smaller than 1.
    """
    if max_shard_bytes < 1:
      raise ValueError('Argument `max_shard_bytes` must be positive. '
                       f'Received {max_shard_bytes}')
    # `None` means "no limit"; any explicit value (including 0) must be
    # positive. A plain truthiness test would let `max_shards=0` through.
    if max_shards is not None and max_shards < 1:
      raise ValueError('Argument `max_shards` must be positive. '
                       f'Received {max_shards}')
    if bytes_per_string < 1:
      raise ValueError('Argument `bytes_per_string` must be positive. '
                       f'Received: {bytes_per_string}')

    self._max_shard_bytes = max_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    # `variable_axis_size_partitioner` returns a partitioner callable, which is
    # applied immediately to `shape` and `dtype`.
    return partitioned_variables.variable_axis_size_partitioner(
        max_shard_bytes=self._max_shard_bytes,
        max_shards=self._max_shards,
        bytes_per_string_element=self._bytes_per_string,
        axis=axis)(shape, dtype)


class ShardedVariableSpec(type_spec.TypeSpec):
  """Type specification for a `ShardedVariable`."""

  __slots__ = ['_variable_specs']

  value_type = property(lambda self: ShardedVariable)

  def __init__(self, *variable_specs):
    # One spec per component variable, in shard order.
    self._variable_specs = tuple(variable_specs)

  def _serialize(self):
    return self._variable_specs

  @property
  def _component_specs(self):
    return self._variable_specs

  def _to_components(self, value):
    return value.variables

  def _from_components(self, variables):
    return ShardedVariable(variables)


class ShardedVariableMixin(trackable.Trackable):
  """Mixin for ShardedVariable."""

  # TODO(b/170877138): Remove this mixin once fixed. This mixin is required
  # since TPUShardedVariable can't be a CompositeTensor.

  def __init__(self,
               variables: Sequence[variables_lib.Variable],
               name='ShardedVariable'):
    """Treats `variables` as shards of a larger Variable.

    Example:

    ```
    variables = [
      tf.Variable(..., shape=(10, 100), dtype=tf.float32),
      tf.Variable(..., shape=(15, 100), dtype=tf.float32),
      tf.Variable(..., shape=(5, 100), dtype=tf.float32)
    ]
    sharded_variable = ShardedVariableMixin(variables)
    assert sharded_variable.shape.as_list() == [30, 100]
    ```

    Args:
      variables: A list of `ResourceVariable`s that comprise this sharded
        variable. Variables should not be shared between different
        `ShardedVariableMixin` objects.
      name: String. Name of this container. Defaults to "ShardedVariable".

    Raises:
      TypeError: If `variables` is not a non-empty sequence of `Variable`s.
      ValueError: If the variables have different dtypes, have shapes that
        differ on any axis other than the first, or already have a
        `SaveSliceInfo` set.
    """
    super(ShardedVariableMixin, self).__init__()
    self._variables = variables
    self._name = name

    if not isinstance(variables, Sequence) or not variables or any(
        not isinstance(v, variables_lib.Variable) for v in variables):
      raise TypeError('Argument `variables` should be a non-empty list of '
                      f'`variables.Variable`s. Received {variables}')

    var_dtypes = {v.dtype for v in variables}
    if len(var_dtypes) > 1:
      raise ValueError(
          'All elements in argument `variables` must have the same dtype. '
          f'Received dtypes: {[v.dtype for v in variables]}')

    first_var = variables[0]
    self._dtype = first_var.dtype

    # All variables must have the same shape for axes > 0.
    higher_dim_shapes = {tuple(v.shape.as_list()[1:]) for v in variables}
    if len(higher_dim_shapes) > 1:
      raise ValueError(
          'All elements in argument `variables` must have the same shapes '
          'except for the first axis. '
          f'Received shapes: {[v.shape for v in variables]}')
    first_dim = sum(int(v.shape.as_list()[0]) for v in variables)
    self._shape = tensor_shape.TensorShape([first_dim] +
                                           first_var.shape.as_list()[1:])
    # `_var_offsets[i]` is the full-rank offset of shard `i` inside the
    # combined value; only the first-axis entry is ever non-zero.
    self._var_offsets = [
        [0 for _ in range(len(first_var.shape))] for _ in range(len(variables))
    ]
    for i in range(1, len(variables)):
      # Always partition on the first axis. Offsets on other axes are 0.
      self._var_offsets[i][0] += (
          self._var_offsets[i - 1][0] + variables[i - 1].shape.as_list()[0])

    save_slice_info = [v._get_save_slice_info() for v in variables]  # pylint: disable=protected-access
    if any(slice_info is not None for slice_info in save_slice_info):
      raise ValueError(
          '`SaveSliceInfo` should not be set for all elements in argument '
          '`variables`. `ShardedVariable` will infer `SaveSliceInfo` according '
          'to the order of the elements `variables`. '
          f'Received save slice info {save_slice_info}')

    # We create an uninitialized saving_variable with the full shape, which can
    # be later captured in signatures so that the signatures can treat this
    # ShardedVariable as one single variable.
    self._saving_variable = resource_variable_ops.UninitializedVariable(
        shape=self._shape, dtype=self._dtype, name=self._name)

  def __iter__(self):
    """Return an iterable for accessing the underlying sharded variables."""
    return iter(self._variables)

  def __getitem__(self, slice_spec):
    """Extracts the specified region as a Tensor from the sharded variable.

    The API contract is identical to `Tensor.__getitem__`. Assignment to the
    sliced range is not yet supported.

    Args:
      slice_spec: The arguments to __getitem__, specifying the global slicing of
        the sharded variable.

    Returns:
      The appropriate slice of tensor based on `slice_spec`.

    Raises:
      IndexError: If a slice index is out of bound.
      TypeError: If `slice_spec` contains Tensor.
    """

    # TODO(b/177482728): Support tensor input.
    # TODO(b/177482728): Support slice assign, similar to variable slice assign.

    if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
                                         slice_spec.dtype == dtypes.bool) or
        (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool)):
      tensor = _var_to_tensor(self)
      return array_ops.boolean_mask(tensor=tensor, mask=slice_spec)

    # A list is treated like a tuple (as in `Tensor.__getitem__`). Normalize to
    # a tuple so that the `(...,) + slice_spec[1:]` concatenations below work.
    if isinstance(slice_spec, (list, tuple)):
      slice_spec = tuple(slice_spec)
    else:
      slice_spec = (slice_spec,)

    if not slice_spec:
      # An empty spec, e.g. `v[()]`, selects the entire value.
      return _var_to_tensor(self)

    s = slice_spec[0]
    if isinstance(s, slice):
      first_dim_slice_specs = self._decompose_slice_spec(s)
      values = []
      for i, var in enumerate(self._variables):
        if first_dim_slice_specs[i] is not None:
          all_dim_slice_spec = (first_dim_slice_specs[i],) + slice_spec[1:]
          values.append(var[all_dim_slice_spec])
      # Per-shard pieces are produced in shard order; a negative step visits
      # the shards back to front, so the pieces must be reversed.
      if s.step is not None and s.step < 0:
        values.reverse()
      if not values:
        return constant_op.constant([],
                                    dtype=self._dtype,
                                    shape=((0,) + self._shape[1:]))
      return array_ops.concat(values, axis=0)
    elif s is Ellipsis:
      return array_ops.concat([var[slice_spec] for var in self._variables],
                              axis=0)
    elif s is array_ops.newaxis:
      return array_ops.concat([var[slice_spec[1:]] for var in self._variables],
                              axis=0)[array_ops.newaxis]
    else:
      if isinstance(s, ops.Tensor):
        raise TypeError(
            'ShardedVariable: using Tensor for indexing is not allowed.')
      if s < 0:
        s += self._shape[0]
      if s < 0 or s >= self._shape[0]:
        raise IndexError(
            f'ShardedVariable: slice index {s} of dimension 0 out of bounds.')
      # Shard `i` owns rows [offset_i, offset_{i+1}). The lower bound must be
      # inclusive, otherwise a row equal to a shard's start offset (e.g. row 0)
      # matches no shard and falls through to the last one.
      for i in range(len(self._variables)):
        if i == len(self._variables) - 1 or (s >= self._var_offsets[i][0] and
                                             s < self._var_offsets[i + 1][0]):
          return self._variables[i][(s - self._var_offsets[i][0],) +
                                    slice_spec[1:]]

  def _decompose_slice_spec(self, slice_spec):
    """Decompose a global slice_spec into a list of per-variable slice_spec.

    `ShardedVariable` only supports first dimension partitioning, thus
    `slice_spec` must be for first dimension.

    Args:
      slice_spec: A python `slice` object that specifies the global slicing.

    Returns:
      A list of python `slice` objects or None specifying the local slicing for
      each component variable. None means no slicing.

    Raises:
      TypeError: If any of `start`, `stop` or `step` is a Tensor.
      ValueError: If the step is zero.

    For example, given component variables:
      v0 = [0, 1, 2]
      v1 = [3, 4, 5]
      v2 = [6, 7, 8, 9]

    If `slice_spec` is slice(start=None, stop=None, step=None), we will have:
      v0[returned[0]] = [0, 1, 2]
      v1[returned[1]] = [3, 4, 5]
      v2[returned[2]] = [6, 7, 8, 9]
    If `slice_spec` is slice(start=2, stop=8, step=3), we will have:
      v0[returned[0]] = [2]
      v1[returned[1]] = [5]
      returned[2] == None
    If `slice_spec` is slice(start=9, stop=3, step=-2), we will have:
      returned[0] == None
      v1[returned[1]] = [5]
      v2[returned[2]] = [9, 7]
    """
    if isinstance(slice_spec.start, ops.Tensor) or isinstance(
        slice_spec.stop, ops.Tensor) or isinstance(slice_spec.step, ops.Tensor):
      raise TypeError(
          'ShardedVariable: using Tensor in slice_spec is not allowed. Please '
          'file a feature request with the TensorFlow team.')

    result = []
    # Normalize start, stop and step with Python slicing semantics: negative
    # bounds are made relative to the end and out-of-range bounds are clamped
    # to the valid range. `slice.indices` raises
    # `ValueError('slice step cannot be zero')` for a zero step.
    # After normalization we no longer interpret negative indices, thus a stop
    # of "-1" (only produced for a negative step) conceptually refers to the
    # element before the first one, which doesn't exist. This eases the
    # decomposition code.
    slice_start, slice_end, slice_step = slice_spec.indices(
        int(self._shape[0]))

    # To find the local slice_spec of each component variable, we start from
    # the start of the global slice, and iterate through each variable.
    # When iterating on a variable, we move the cursor (`cur`) to the first
    # index that falls into the variable's range, which becomes the start of
    # the variable's local slice_spec. The end of the local_spec is determined
    # by using whatever is smaller between global slice end and variable range
    # end.
    cur = slice_start
    if slice_step > 0:
      for i in range(len(self._var_offsets)):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur < var_start:
          # Advance by whole steps to the first index >= var_start.
          cur += slice_step * int(math.ceil((var_start - cur) / slice_step))
        if cur >= var_end or cur >= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          end = min(slice_end, var_end) - var_start
          result.append(slice(start, end, slice_step))
    else:  # slice_step < 0
      for i in range(len(self._var_offsets) - 1, -1, -1):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur >= var_end:
          # Move back by whole steps to the first index <= var_end - 1.
          cur += slice_step * int(math.ceil((var_end - cur - 1) / slice_step))
        if cur < var_start or cur <= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          if slice_end >= var_start:
            end = slice_end - var_start
          else:
            end = None  # no explicit end: slice until hitting the boundary.
          result.append(slice(start, end, slice_step))

      # Shards were visited back to front; restore shard order.
      result.reverse()

    return result

  @property
  def _type_spec(self):
    return ShardedVariableSpec(
        *(resource_variable_ops.VariableSpec(v.shape, v.dtype)
          for v in self._variables))

  @property
  def variables(self):
    """The list of `Variable`s that make up the shards of this object."""
    # While saving, expose the single full-shape saving variable instead of
    # the shards.
    if save_context.in_save_context():
      return [self._saving_variable]
    return self._variables

  @property
  def name(self):
    """The name of this object. Used for checkpointing."""
    return self._name

  @property
  def dtype(self):
    """The dtype of all `Variable`s in this object."""
    return self._dtype

  @property
  def shape(self):
    """The overall shape, combining all shards along axis `0`."""
    return self._shape

  def assign(self, value, use_locking=None, name=None, read_value=True):
    """Assigns the matching slice of `value` to every shard; returns `self`.

    `use_locking`, `name` and `read_value` are accepted for API compatibility
    with `tf.Variable.assign` but are not forwarded to the shards.
    """
    for i, v in enumerate(self._variables):
      v.assign(array_ops.slice(value, self._var_offsets[i], v.shape.as_list()))
    return self

  def assign_add(self, delta, use_locking=False, name=None, read_value=True):
    """Adds the matching slice of `delta` to every shard; returns `self`.

    `use_locking`, `name` and `read_value` are accepted for API compatibility
    but are not forwarded to the shards.
    """
    for i, v in enumerate(self._variables):
      v.assign_add(
          array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()))
    return self

  def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
    """Subtracts the matching slice of `delta` from every shard; returns `self`.

    `use_locking`, `name` and `read_value` are accepted for API compatibility
    but are not forwarded to the shards.
    """
    for i, v in enumerate(self._variables):
      v.assign_sub(
          array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()))
    return self

  def _decompose_indices(self, indices):
    """Decompose a global 1D indices into a list of per-variable indices.

    Only "div" sharding is supported: with `base = rows // num_shards` and
    `extra = rows % num_shards`, the first `extra` shards hold `base + 1` rows
    and the rest hold `base` rows.

    Args:
      indices: A rank-1 integer Tensor of global row indices.

    Returns:
      A tuple `(per_var_indices, partition_assignments)`: the list of local
      indices for each shard, and the int32 shard assignment of every input
      index.

    Raises:
      ValueError: If `indices` is not rank 1.
      NotImplementedError: If the shards do not follow "div" sharding.
    """
    if indices.shape.rank != 1:
      raise ValueError(
          'ShardedVariable: indices must be 1D Tensor for sparse operations. '
          f'Received shape: {indices.shape}')

    base = self._shape[0] // len(self._variables)
    extra = self._shape[0] % len(self._variables)

    # Assert that sharding conforms to "div" sharding
    expect_first_dim = [base] * len(self._variables)
    for i in range(extra):
      expect_first_dim[i] = expect_first_dim[i] + 1
    actual_first_dim = [v.shape.as_list()[0] for v in self._variables]
    if expect_first_dim != actual_first_dim:
      raise NotImplementedError(
          'scater_xxx ops are not supported in ShardedVariale that does not '
          'conform to "div" sharding')

    # For an index that falls into a partition that has the extra 1, the
    # assignment is `index // (base + 1)` (no less than
    # `(indices - extra) // base`).
    # For an index that falls into a partition that doesn't have the extra 1,
    # the assignment is `(indices - extra) // base` (no less than
    # `indices // (base + 1)`).
    #
    # Example:
    # base = 10, extra = 2, partitions: [0, 11), [11, 22), [22, 32)
    # index = 10 -> partition_assignment = 0
    # index = 22 -> partition_assignment = 2
    partition_assignments = math_ops.maximum(indices // (base + 1),
                                             (indices - extra) // base)
    local_indices = array_ops.where(partition_assignments < extra,
                                    indices % (base + 1),
                                    (indices - extra) % base)
    # For whatever reason `dynamic_partition` only supports int32
    partition_assignments = math_ops.cast(partition_assignments, dtypes.int32)
    per_var_indices = data_flow_ops.dynamic_partition(local_indices,
                                                      partition_assignments,
                                                      len(self._variables))

    return per_var_indices, partition_assignments

  def _decompose_indexed_slices(self, indexed_slices):
    """Decompose a global `IndexedSlices` into a list of per-variable ones."""
    per_var_indices, partition_assignments = self._decompose_indices(
        indexed_slices.indices)
    per_var_values = data_flow_ops.dynamic_partition(indexed_slices.values,
                                                     partition_assignments,
                                                     len(self._variables))

    return [
        ops.IndexedSlices(values=per_var_values[i], indices=per_var_indices[i])
        for i in range(len(self._variables))
    ]

  @staticmethod
  def _shard_op_name(name, shard_index):
    """Returns `<name>/part_<shard_index>`, or None when `name` is None."""
    if name is None:
      return None
    return '{}/part_{}'.format(name, shard_index)

  def _apply_scatter_to_shards(self, method_name, sparse_delta, name):
    """Splits `sparse_delta` per shard and calls `method_name` on each shard.

    Args:
      method_name: Name of the per-variable scatter method to call, e.g.
        `'scatter_add'`.
      sparse_delta: A global `IndexedSlices` over the first dimension.
      name: Optional op name; shard `i` uses `<name>/part_<i>`.

    Returns:
      `self`.
    """
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      getattr(v, method_name)(
          per_var_sparse_delta[i], name=self._shard_op_name(name, i))
    return self

  # ==================== scatter ops implementations ======================== #
  # `use_locking` is accepted for API compatibility but not forwarded.

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_add."""
    return self._apply_scatter_to_shards('scatter_add', sparse_delta, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_div."""
    return self._apply_scatter_to_shards('scatter_div', sparse_delta, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_max."""
    return self._apply_scatter_to_shards('scatter_max', sparse_delta, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_min."""
    return self._apply_scatter_to_shards('scatter_min', sparse_delta, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_mul."""
    return self._apply_scatter_to_shards('scatter_mul', sparse_delta, name)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_sub."""
    return self._apply_scatter_to_shards('scatter_sub', sparse_delta, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_update."""
    return self._apply_scatter_to_shards('scatter_update', sparse_delta, name)

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.batch_scatter_update."""
    return self._apply_scatter_to_shards('batch_scatter_update', sparse_delta,
                                         name)

  # ================== scatter ops implementations END ====================== #

  def sparse_read(self, indices, name=None):
    """Implements tf.Variable.sparse_read."""
    per_var_indices, _ = self._decompose_indices(indices)
    result = [
        v.sparse_read(per_var_indices[i], name=self._shard_op_name(name, i))
        for i, v in enumerate(self._variables)
    ]
    return array_ops.concat(result, axis=0)

  def _gather_saveables_for_checkpoint(self):
    """Return a `Saveable` for each shard. See `Trackable`."""

    def _saveable_factory(name=self.name):
      """Creates `SaveableObject`s for this `ShardedVariable`."""
      saveables = []
      dims = len(self._variables[0].shape)
      var_offset = [0 for _ in range(dims)]
      for v in self._variables:
        # Each shard is saved as a slice of one full-shape variable, which is
        # what allows restoring into a different number of shards.
        save_slice_info = variables_lib.Variable.SaveSliceInfo(
            full_name=self.name,
            full_shape=self.shape.as_list(),
            var_offset=copy.copy(var_offset),
            var_shape=v.shape.as_list())
        saveables.append(
            saveable_object_util.ResourceVariableSaveable(
                v, save_slice_info.spec, name))
        var_offset[0] += int(v.shape[0])
      return saveables

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

  def _map_resources(self, save_options):
    """For implementing `Trackable`."""
    obj_map, resource_map = {}, {}
    # `self._variables` may be any Sequence (e.g. a tuple), so copy it into a
    # list before appending the saving variable.
    for v in list(self._variables) + [self._saving_variable]:
      v_obj_map, v_resource_map = v._map_resources(save_options)  # pylint:disable=protected-access
      obj_map.update(v_obj_map)
      resource_map.update(v_resource_map)
    obj_map[self] = ShardedVariable([obj_map[self._saving_variable]],
                                    name=self.name)

    return obj_map, resource_map


class ShardedVariable(ShardedVariableMixin, composite_tensor.CompositeTensor):
  """A container for `Variables` that should be treated as shards.

  Variables that are too large to fit on a single device (e.g., large
  embeddings) may need to be sharded over multiple devices. This class
  maintains a list of smaller variables that can be independently stored on
  separate devices (e.g., multiple parameter servers), and saves and restores
  those variables as if they were a single larger variable.

  Objects of this class can be saved with a given number of shards and then
  restored from a checkpoint into a different number of shards.

  Objects of this class can be saved to SavedModel format using
  `tf.saved_model.save`. The SavedModel can be used by programs like TF serving
  APIs. It is not yet supported to load the SavedModel with
  `tf.saved_model.load`.

  Since `ShardedVariable` can be saved and then restored to different number of
  shards depending on the restore environments, for example, TF serving APIs
  would restore to one shard for serving efficiency, when using
  `ShardedVariable` in a tf.function, one should generally not assume it has the
  same number of shards across save and load.

  Sharding is only supported along the first dimension.

  >>> class Model(tf.Module):
  ...   def __init__(self):
  ...     self.sharded_variable = ShardedVariable([
  ...       tf.Variable([3.0], dtype=tf.float32),
  ...       tf.Variable([2.0], dtype=tf.float32)
  ...     ])
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def serve_fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  >>>
  >>> model = Model()
  >>> model.fn(1).numpy()
  2.0
  >>> tf.saved_model.save(model, export_dir='/tmp/saved_model',
  ...   signatures=model.serve_fn)
  """

  @property
  def _type_spec(self):
    return ShardedVariableSpec(
        *(resource_variable_ops.VariableSpec(v.shape, v.dtype)
          for v in self._variables))

  @classmethod
  def _overload_all_operators(cls):
    """Register overloads for all operators."""
    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
      # `__getitem__` has a dedicated implementation in the mixin.
      if operator == '__getitem__':
        continue

      cls._overload_operator(operator)

  @classmethod
  def _overload_operator(cls, operator):
    """Delegate an operator overload to `ops.Tensor`."""
    tensor_operator = getattr(ops.Tensor, operator)

    def _operator(v, *args, **kwargs):
      return tensor_operator(_var_to_tensor(v), *args, **kwargs)

    setattr(cls, operator, _operator)


def _var_to_tensor(var, dtype=None, name=None, as_ref=False):
  """Converts a `ShardedVariable` to a `Tensor`.

  The result is the concatenation of `var.variables` along axis 0.

  Raises:
    ValueError: If `dtype` is incompatible with `var.dtype`.
    NotImplementedError: If `as_ref` is True.
    TypeError: If called inside an `embedding_lookup` name scope.
  """
  del name
  if dtype is not None and not dtype.is_compatible_with(var.dtype):
    raise ValueError(
        'Incompatible type conversion requested to type {!r} for variable '
        'of type {!r}'.format(dtype.name, var.dtype.name))
  if as_ref:
    raise NotImplementedError(
        "ShardedVariable doesn't support being used as a reference.")
  # We use op dispatch mechanism to override embedding_lookup ops when called
  # with ShardedVariable. This requires embedding_lookup ops to raise TypeError
  # when called with ShardedVariable. However since ShardedVariable can be
  # converted to a tensor via concat, embedding_lookup ops would silently
  # do the conversion and never raise a TypeError. To be able to properly
  # raise a TypeError, namescope is used to detect if this method is called
  # within a embedding_lookup op.
  # NOTE: This doesn't work in eager mode since op namescope is always cleared
  # in eager. This also breaks if user sets the name of embedding_lookup op
  # with something that doesn't contain str "embedding_lookup".
  #
  # TODO(chenkai): Find a more robust way to do this, which should not rely
  # on namescope.
  if 'embedding_lookup' in ops.get_name_scope():
    raise TypeError('Converting ShardedVariable to tensor in embedding lookup'
                    ' ops is disallowed.')
  return array_ops.concat(var.variables, axis=0)


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
ops.register_tensor_conversion_function(ShardedVariable, _var_to_tensor)

ShardedVariable._overload_all_operators()  # pylint: disable=protected-access


# Override the behavior of embedding_lookup(sharded_variable, ...)
@dispatch.dispatch_for_types(embedding_ops.embedding_lookup, ShardedVariable)
def embedding_lookup(params,
                     ids,
                     partition_strategy='mod',
                     name=None,
                     validate_indices=True,
                     max_norm=None):
  """Runs `embedding_lookup` over the shards of a `ShardedVariable`."""
  # NOTE(review): when `params` is a list, only its first element is used;
  # this assumes the list wraps a single ShardedVariable — confirm.
  if isinstance(params, list):
    params = params[0]
  return embedding_ops.embedding_lookup(params.variables, ids,
                                        partition_strategy, name,
                                        validate_indices, max_norm)


def _raise_when_load(_):
  """Object factory for the revived type: always raises `ValueError`."""
  # We don't have serialization and deserialization mechanisms for
  # `ShardedVariable` in 2.x style save/load yet.
  raise ValueError(
      'Loading a saved_model containing ShardedVariable via '
      '`tf.saved_model.load` is not supported. If the model is built using '
      'Keras, please use `tf.keras.models.load_model` instead.')


revived_types.register_revived_type(
    '_tf_distribute_sharded_variable',
    lambda obj: isinstance(obj, ShardedVariable),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=_raise_when_load,
            version=0,
            min_producer_version=0,
            min_consumer_version=0)
    ])