# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ShardedVariable class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import math
from typing import Sequence
import numpy as np

from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import save_context
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


@tf_export('distribute.experimental.partitioners.Partitioner', v1=[])
class Partitioner(object):
  """Partitioner base class: all partitioners inherit from this class.

  Partitioners should implement a `__call__` method with the following
  signature:

  ```python
  def __call__(self, shape, dtype, axis=0):
    # Partitions the given `shape` and returns the partition results.
    # See docstring of `__call__` method for the format of partition results.
  ```
  """

  def __call__(self, shape, dtype, axis=0):
    """Partitions the given `shape` and returns the partition results.

    Example of a partitioner that allocates a fixed number of shards:

    ```python
    partitioner = FixedShardsPartitioner(num_shards=2)
    partitions = partitioner(tf.TensorShape([10, 3]), tf.float32, axis=0)
    print(partitions)  # [2, 1]
    ```

    Args:
      shape: a `tf.TensorShape`, the shape to partition.
      dtype: a `tf.dtypes.Dtype` indicating the type of the partition value.
      axis: The axis to partition along.  Default: outermost axis.

    Returns:
      A list of integers representing the number of partitions on each axis,
      where the i-th value corresponds to the i-th axis.
    """
    raise NotImplementedError


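# A minimal sketch of a custom partitioner, assuming only the `__call__`
# contract documented above. It is illustrative and not part of the public
# API: it splits the partitioned axis into roughly one shard per
# `rows_per_shard` rows (a hypothetical parameter), capped by the axis length.
class _ExampleRowsPerShardPartitioner(Partitioner):
  """Illustrative partitioner: roughly one shard per `rows_per_shard` rows."""

  def __init__(self, rows_per_shard):
    self._rows_per_shard = rows_per_shard

  def __call__(self, shape, dtype, axis=0):
    del dtype  # Unused; this sketch partitions purely by row count.
    result = [1] * len(shape)
    num_rows = shape.dims[axis].value
    result[axis] = max(1, min(num_rows, num_rows // self._rows_per_shard))
    return result

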
@tf_export('distribute.experimental.partitioners.FixedShardsPartitioner', v1=[])
class FixedShardsPartitioner(Partitioner):
  """Partitioner that allocates a fixed number of shards.

  Examples:

  >>> # standalone usage:
  >>> partitioner = FixedShardsPartitioner(num_shards=2)
  >>> partitions = partitioner(tf.TensorShape([10, 3]), tf.float32)
  >>> partitions
  [2, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)

  """

  def __init__(self, num_shards):
    """Creates a new `FixedShardsPartitioner`.

    Args:
      num_shards: `int`, number of shards to partition.
    """
    self._num_shards = num_shards

  def __call__(self, shape, dtype, axis=0):
    del dtype
    result = [1] * len(shape)
    result[axis] = min(self._num_shards, shape.dims[axis].value)
    return result


@tf_export('distribute.experimental.partitioners.MinSizePartitioner', v1=[])
class MinSizePartitioner(Partitioner):
  """Partitioner that allocates a minimum size per shard.

  This partitioner ensures each shard has at least `min_shard_bytes`, and tries
  to allocate as many shards as possible, i.e., keeping shard size as small as
  possible. The maximum number of such shards (upper bound) is given by
  `max_shards`.

  Examples:

  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [2, 1]
  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=10)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [6, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self,
               min_shard_bytes=256 << 10,
               max_shards=1,
               bytes_per_string=16):
    """Creates a new `MinSizePartitioner`.

    Args:
      min_shard_bytes: Minimum bytes of each shard. Defaults to 256K.
      max_shards: Upper bound on the number of shards. Defaults to 1.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.
    """
    if min_shard_bytes < 1:
      raise ValueError('Argument `min_shard_bytes` must be positive. '
                       f'Received: {min_shard_bytes}')
    if max_shards < 1:
      raise ValueError('Argument `max_shards` must be positive. '
                       f'Received: {max_shards}')
    if bytes_per_string < 1:
      raise ValueError('Argument `bytes_per_string` must be positive. '
                       f'Received: {bytes_per_string}')
    self._min_shard_bytes = min_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    return partitioned_variables.min_max_variable_partitioner(
        max_partitions=self._max_shards,
        axis=axis,
        min_slice_size=self._min_shard_bytes,
        bytes_per_string_element=self._bytes_per_string)(shape, dtype)


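# A rough sketch of the arithmetic behind the docstring examples above. The
# real computation is delegated to
# `partitioned_variables.min_max_variable_partitioner`, so treat this as an
# illustration only: for `tf.TensorShape([6, 1])` with float32 (4 bytes per
# element) the variable is 24 bytes; with `min_shard_bytes=4` at most
# 24 // 4 = 6 shards fit, and the result is then capped by `max_shards` and by
# the length of the partitioned axis.
def _approx_min_size_shard_count(shape, bytes_per_element, min_shard_bytes,
                                 max_shards):
  """Illustrative only: approximate shard count along axis 0."""
  total_bytes = bytes_per_element
  for dim in shape:
    total_bytes *= int(dim)
  return max(1, min(max_shards, total_bytes // min_shard_bytes, int(shape[0])))

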
@tf_export('distribute.experimental.partitioners.MaxSizePartitioner', v1=[])
class MaxSizePartitioner(Partitioner):
  """Partitioner that keeps shards below `max_shard_bytes`.

  This partitioner ensures each shard has at most `max_shard_bytes`, and tries
  to allocate as few shards as possible, i.e., keeping shard size as large
  as possible.

  If the partitioner hits the `max_shards` limit, then each shard may end up
  larger than `max_shard_bytes`. By default `max_shards` equals `None` and no
  limit on the number of shards is enforced.

  Examples:

  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [6, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [2, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=1024)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [1, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self, max_shard_bytes, max_shards=None, bytes_per_string=16):
    """Creates a new `MaxSizePartitioner`.

    Args:
      max_shard_bytes: The maximum size any given shard is allowed to be.
      max_shards: `int`, the maximum number of shards to create; takes
        precedence over `max_shard_bytes`.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.
    """
    if max_shard_bytes < 1:
      raise ValueError('Argument `max_shard_bytes` must be positive. '
                       f'Received: {max_shard_bytes}')
    if max_shards and max_shards < 1:
      raise ValueError('Argument `max_shards` must be positive. '
                       f'Received: {max_shards}')
    if bytes_per_string < 1:
      raise ValueError('Argument `bytes_per_string` must be positive. '
                       f'Received: {bytes_per_string}')

    self._max_shard_bytes = max_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    return partitioned_variables.variable_axis_size_partitioner(
        max_shard_bytes=self._max_shard_bytes,
        max_shards=self._max_shards,
        bytes_per_string_element=self._bytes_per_string,
        axis=axis)(shape, dtype)


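# A sketch of how a partition result such as `[3, 1]` maps to concrete shard
# shapes. The actual splitting is done by whichever strategy consumes the
# partitioner (e.g. `ParameterServerStrategy`), not by this module; this
# helper only illustrates the "div"-style layout that the
# `_decompose_indices` method later in this module assumes, where any
# remainder rows go to the leading shards.
def _example_shard_shapes(shape, num_shards):
  """Illustrative only: per-shard shapes for first-axis "div" sharding."""
  base, extra = divmod(int(shape[0]), num_shards)
  rest = [int(d) for d in shape[1:]]
  return [[base + 1] + rest if i < extra else [base] + rest
          for i in range(num_shards)]

# For example, `_example_shard_shapes([10, 3], 3)` yields
# `[[4, 3], [3, 3], [3, 3]]`.

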
class ShardedVariableSpec(type_spec.TypeSpec):
  """Type specification for a `ShardedVariable`."""

  __slots__ = ['_variable_specs']

  value_type = property(lambda self: ShardedVariable)

  def __init__(self, *variable_specs):
    self._variable_specs = tuple(variable_specs)

  def _serialize(self):
    return self._variable_specs

  @property
  def _component_specs(self):
    return self._variable_specs

  def _to_components(self, value):
    return value.variables

  def _from_components(self, variables):
    return ShardedVariable(variables)


class ShardedVariableMixin(trackable.Trackable):
  """Mixin for ShardedVariable."""

  # TODO(b/170877138): Remove this mixin once fixed. This mixin is required
  # since TPUShardedVariable can't be a CompositeTensor.

  def __init__(self,
               variables: Sequence[variables_lib.Variable],
               name='ShardedVariable'):
    """Treats `variables` as shards of a larger Variable.

    Example:

    ```
    variables = [
      tf.Variable(..., shape=(10, 100), dtype=tf.float32),
      tf.Variable(..., shape=(15, 100), dtype=tf.float32),
      tf.Variable(..., shape=(5, 100), dtype=tf.float32)
    ]
    sharded_variable = ShardedVariableMixin(variables)
    assert sharded_variable.shape.as_list() == [30, 100]
    ```

    Args:
      variables: A list of `ResourceVariable`s that comprise this sharded
        variable. Variables should not be shared between different
        `ShardedVariableMixin` objects.
      name: String. Name of this container. Defaults to "ShardedVariable".
    """
    super(ShardedVariableMixin, self).__init__()
    self._variables = variables
    self._name = name

    if not isinstance(variables, Sequence) or not variables or any(
        not isinstance(v, variables_lib.Variable) for v in variables):
      raise TypeError('Argument `variables` should be a non-empty list of '
                      f'`variables.Variable`s. Received {variables}')

    var_dtypes = {v.dtype for v in variables}
    if len(var_dtypes) > 1:
      raise ValueError(
          'All elements in argument `variables` must have the same dtype. '
          f'Received dtypes: {[v.dtype for v in variables]}')

    first_var = variables[0]
    self._dtype = first_var.dtype

    # All variables must have the same shape for axes > 0.
    higher_dim_shapes = {tuple(v.shape.as_list()[1:]) for v in variables}
    if len(higher_dim_shapes) > 1:
      raise ValueError(
          'All elements in argument `variables` must have the same shapes '
          'except for the first axis. '
          f'Received shapes: {[v.shape for v in variables]}')
    first_dim = sum(int(v.shape.as_list()[0]) for v in variables)
    self._shape = tensor_shape.TensorShape([first_dim] +
                                           first_var.shape.as_list()[1:])
    self._var_offsets = [
        [0 for _ in range(len(first_var.shape))] for _ in range(len(variables))
    ]
    for i in range(1, len(variables)):
      # Always partition on the first axis. Offsets on other axes are 0.
      self._var_offsets[i][0] += (
          self._var_offsets[i - 1][0] + variables[i - 1].shape.as_list()[0])

    save_slice_info = [v._get_save_slice_info() for v in variables]  # pylint: disable=protected-access
    if any(slice_info is not None for slice_info in save_slice_info):
      raise ValueError(
          '`SaveSliceInfo` should not be set for any element in argument '
          '`variables`. `ShardedVariable` will infer `SaveSliceInfo` according '
          'to the order of the elements in `variables`. '
          f'Received save slice info: {save_slice_info}')

    # We create an uninitialized saving_variable with the full shape, which can
    # be later captured in signatures so that the signatures can treat this
    # ShardedVariable as one single variable.
    self._saving_variable = resource_variable_ops.UninitializedVariable(
        shape=self._shape, dtype=self._dtype, name=self._name)

  def __iter__(self):
    """Return an iterable for accessing the underlying sharded variables."""
    return iter(self._variables)

  def __getitem__(self, slice_spec):
    """Extracts the specified region as a Tensor from the sharded variable.

    The API contract is identical to `Tensor.__getitem__`. Assignment to the
    sliced range is not yet supported.

    Args:
      slice_spec: The arguments to __getitem__, specifying the global slicing
        of the sharded variable.

    Returns:
      The appropriate slice of tensor based on `slice_spec`.

    Raises:
      IndexError: If a slice index is out of bounds.
      TypeError: If `slice_spec` contains a Tensor.
    """

    # TODO(b/177482728): Support tensor input.
    # TODO(b/177482728): Support slice assign, similar to variable slice assign.

    if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
                                         slice_spec.dtype == dtypes.bool) or
        (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool)):
      tensor = _var_to_tensor(self)
      return array_ops.boolean_mask(tensor=tensor, mask=slice_spec)

    if not isinstance(slice_spec, (list, tuple)):
      slice_spec = (slice_spec,)

    s = slice_spec[0]
    if isinstance(s, slice):
      first_dim_slice_specs = self._decompose_slice_spec(s)
      values = []
      for i, var in enumerate(self._variables):
        if first_dim_slice_specs[i] is not None:
          all_dim_slice_spec = (first_dim_slice_specs[i],) + slice_spec[1:]
          values.append(var[all_dim_slice_spec])
      if s.step is not None and s.step < 0:
        values.reverse()
      if not values:
        return constant_op.constant([],
                                    dtype=self._dtype,
                                    shape=((0,) + self._shape[1:]))
      return array_ops.concat(values, axis=0)
    elif s is Ellipsis:
      return array_ops.concat([var[slice_spec] for var in self._variables],
                              axis=0)
    elif s is array_ops.newaxis:
      return array_ops.concat([var[slice_spec[1:]] for var in self._variables],
                              axis=0)[array_ops.newaxis]
    else:
      if isinstance(s, ops.Tensor):
        raise TypeError(
            'ShardedVariable: using Tensor for indexing is not allowed.')
      if s < 0:
        s += self._shape[0]
      if s < 0 or s >= self._shape[0]:
        raise IndexError(
            f'ShardedVariable: slice index {s} of dimension 0 out of bounds.')
      for i in range(len(self._variables)):
        # Find the shard whose range [offset_i, offset_{i+1}) contains `s`.
        if i == len(self._variables) - 1 or (
            s >= self._var_offsets[i][0] and s < self._var_offsets[i + 1][0]):
          return self._variables[i][(s - self._var_offsets[i][0],) +
                                    slice_spec[1:]]

  def _decompose_slice_spec(self, slice_spec):
    """Decomposes a global slice_spec into a list of per-variable slice_specs.

    `ShardedVariable` only supports first-dimension partitioning, thus
    `slice_spec` must be for the first dimension.

    Args:
      slice_spec: A python `slice` object that specifies the global slicing.

    Returns:
      A list of python `slice` objects or None specifying the local slicing for
      each component variable. None means no slicing.

    For example, given component variables:
      v0 = [0, 1, 2]
      v1 = [3, 4, 5]
      v2 = [6, 7, 8, 9]

    If `slice_spec` is slice(start=None, stop=None, step=None), we will have:
      v0[returned[0]] = [0, 1, 2]
      v1[returned[1]] = [3, 4, 5]
      v2[returned[2]] = [6, 7, 8, 9]
    If `slice_spec` is slice(start=2, stop=8, step=3), we will have:
      v0[returned[0]] = [2]
      v1[returned[1]] = [5]
      returned[2] == None
    If `slice_spec` is slice(start=9, stop=3, step=-2), we will have:
      returned[0] == None
      v1[returned[1]] = [5]
      v2[returned[2]] = [9, 7]
    """
    if isinstance(slice_spec.start, ops.Tensor) or isinstance(
        slice_spec.stop, ops.Tensor) or isinstance(slice_spec.step, ops.Tensor):
      raise TypeError(
          'ShardedVariable: using Tensor in slice_spec is not allowed. Please '
          'file a feature request with the TensorFlow team.')

    result = []
    # Normalize start, stop and step.
    slice_step = slice_spec.step if slice_spec.step is not None else 1
    if slice_step == 0:
      raise ValueError('slice step cannot be zero')
    slice_start = slice_spec.start
    if slice_start is None:
      slice_start = 0 if slice_step > 0 else self._shape[0] - 1
    elif slice_start < 0:
      slice_start += self._shape[0]
    slice_end = slice_spec.stop
    if slice_end is None:
      # After the normalization, we no longer interpret negative index, thus
      # "-1" conceptually refers to the element before the first one, which
      # doesn't exist. This is to ease the decomposition code.
      slice_end = self._shape[0] if slice_step > 0 else -1
    elif slice_end < 0:
      slice_end += self._shape[0]

    # To find the local slice_spec of each component variable, we start from
    # the start of the global slice, and iterate through each variable.
    # When iterating on a variable, we move the cursor (`cur`) to the first
    # index that falls into the variable's range, which becomes the start of
    # the variable's local slice_spec. The end of the local spec is determined
    # by using whatever is smaller between the global slice end and the
    # variable's range end.
    cur = slice_start
    if slice_step > 0:
      for i in range(len(self._var_offsets)):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur < var_start:
          cur += slice_step * int(math.ceil((var_start - cur) / slice_step))
        if cur >= var_end or cur >= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          end = min(slice_end, var_end) - var_start
          result.append(slice(start, end, slice_step))
    else:  # slice_step < 0
      for i in range(len(self._var_offsets) - 1, -1, -1):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur >= var_end:
          cur += slice_step * int(math.ceil((var_end - cur - 1) / slice_step))
        if cur < var_start or cur <= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          if slice_end >= var_start:
            end = slice_end - var_start
          else:
            end = None  # no explicit end: slice until hitting the boundary.
          result.append(slice(start, end, slice_step))

      result.reverse()

    return result

  @property
  def _type_spec(self):
    return ShardedVariableSpec(
        *(resource_variable_ops.VariableSpec(v.shape, v.dtype)
          for v in self._variables))

  @property
  def variables(self):
    """The list of `Variable`s that make up the shards of this object."""
    if save_context.in_save_context():
      return [self._saving_variable]
    return self._variables

  @property
  def name(self):
    """The name of this object. Used for checkpointing."""
    return self._name

  @property
  def dtype(self):
    """The dtype of all `Variable`s in this object."""
    return self._dtype

  @property
  def shape(self):
    """The overall shape, combining all shards along axis `0`."""
    return self._shape

  def assign(self, value, use_locking=None, name=None, read_value=True):
    for i, v in enumerate(self._variables):
      v.assign(array_ops.slice(value, self._var_offsets[i], v.shape.as_list()))
    return self

  def assign_add(self, delta, use_locking=False, name=None, read_value=True):
    for i, v in enumerate(self._variables):
      v.assign_add(
          array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()))
    return self

  def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
    for i, v in enumerate(self._variables):
      v.assign_sub(
          array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()))
    return self

  def _decompose_indices(self, indices):
    """Decomposes global 1D indices into a list of per-variable indices."""
    if indices.shape.rank != 1:
      raise ValueError(
          'ShardedVariable: indices must be a 1D Tensor for sparse operations. '
          f'Received shape: {indices.shape}')

    base = self._shape[0] // len(self._variables)
    extra = self._shape[0] % len(self._variables)

    # Assert that the sharding conforms to "div" sharding.
    expect_first_dim = [base] * len(self._variables)
    for i in range(extra):
      expect_first_dim[i] = expect_first_dim[i] + 1
    actual_first_dim = [v.shape.as_list()[0] for v in self._variables]
    if expect_first_dim != actual_first_dim:
      raise NotImplementedError(
          'scatter_xxx ops are not supported in ShardedVariable that does not '
          'conform to "div" sharding')

    # For an index that falls into a partition with an extra element, the
    # assignment is `index // (base + 1)` (which is no less than
    # `(index - extra) // base`). For an index that falls into a partition
    # without an extra element, the assignment is `(index - extra) // base`
    # (which is no less than `index // (base + 1)`).
    #
    # Example:
    #   base = 10, extra = 2, partitions: [0, 11), [11, 22), [22, 32)
    #   index = 10 -> partition_assignment = 0
    #   index = 22 -> partition_assignment = 2
    partition_assignments = math_ops.maximum(indices // (base + 1),
                                             (indices - extra) // base)
    local_indices = array_ops.where(partition_assignments < extra,
                                    indices % (base + 1),
                                    (indices - extra) % base)
    # `dynamic_partition` only supports int32 partition assignments.
    partition_assignments = math_ops.cast(partition_assignments, dtypes.int32)
    per_var_indices = data_flow_ops.dynamic_partition(local_indices,
                                                      partition_assignments,
                                                      len(self._variables))

    return per_var_indices, partition_assignments

  def _decompose_indexed_slices(self, indexed_slices):
    """Decomposes a global `IndexedSlices` into a list of per-variable ones."""
    per_var_indices, partition_assignments = self._decompose_indices(
        indexed_slices.indices)
    per_var_values = data_flow_ops.dynamic_partition(indexed_slices.values,
                                                     partition_assignments,
                                                     len(self._variables))

    return [
        ops.IndexedSlices(values=per_var_values[i], indices=per_var_indices[i])
        for i in range(len(self._variables))
    ]

  # ==================== scatter ops implementations ======================== #

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_add."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_add(per_var_sparse_delta[i], name=new_name)
    return self

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_div."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_div(per_var_sparse_delta[i], name=new_name)
    return self

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_max."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_max(per_var_sparse_delta[i], name=new_name)
    return self

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_min."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_min(per_var_sparse_delta[i], name=new_name)
    return self

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_mul."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_mul(per_var_sparse_delta[i], name=new_name)
    return self

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_sub."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_sub(per_var_sparse_delta[i], name=new_name)
    return self

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.scatter_update."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.scatter_update(per_var_sparse_delta[i], name=new_name)
    return self

  def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Implements tf.Variable.batch_scatter_update."""
    per_var_sparse_delta = self._decompose_indexed_slices(sparse_delta)
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      v.batch_scatter_update(per_var_sparse_delta[i], name=new_name)
    return self

  # ================== scatter ops implementations END ====================== #

  def sparse_read(self, indices, name=None):
    """Implements tf.Variable.sparse_read."""
    per_var_indices, _ = self._decompose_indices(indices)
    result = []
    for i, v in enumerate(self._variables):
      new_name = None
      if name is not None:
        new_name = '{}/part_{}'.format(name, i)
      result.append(v.sparse_read(per_var_indices[i], name=new_name))
    return array_ops.concat(result, axis=0)

  def _gather_saveables_for_checkpoint(self):
    """Return a `Saveable` for each shard. See `Trackable`."""

    def _saveable_factory(name=self.name):
      """Creates `SaveableObject`s for this `ShardedVariable`."""
      saveables = []
      dims = len(self._variables[0].shape)
      var_offset = [0 for _ in range(dims)]
      for v in self._variables:
        save_slice_info = variables_lib.Variable.SaveSliceInfo(
            full_name=self.name,
            full_shape=self.shape.as_list(),
            var_offset=copy.copy(var_offset),
            var_shape=v.shape.as_list())
        saveables.append(
            saveable_object_util.ResourceVariableSaveable(
                v, save_slice_info.spec, name))
        var_offset[0] += int(v.shape[0])
      return saveables

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

  def _map_resources(self, save_options):
    """For implementing `Trackable`."""
    obj_map, resource_map = {}, {}
    for v in self._variables + [self._saving_variable]:
      v_obj_map, v_resource_map = v._map_resources(save_options)  # pylint:disable=protected-access
      obj_map.update(v_obj_map)
      resource_map.update(v_resource_map)
    obj_map[self] = ShardedVariable([obj_map[self._saving_variable]],
                                    name=self.name)

    return obj_map, resource_map


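# A plain-Python sketch of the "div" partition assignment used by
# `ShardedVariableMixin._decompose_indices` above, for illustration only (the
# real method works on Tensors via `math_ops` and `dynamic_partition`). With
# `base = first_dim // num_shards` and `extra = first_dim % num_shards`, the
# shard of a global index is `max(index // (base + 1), (index - extra) // base)`.
def _example_div_partition_assignment(index, first_dim, num_shards):
  """Illustrative only: shard id and local index for a global index."""
  base, extra = divmod(first_dim, num_shards)
  shard = max(index // (base + 1), (index - extra) // base)
  local = index % (base + 1) if shard < extra else (index - extra) % base
  return shard, local

# For example, with first_dim=32 and num_shards=3 (shard sizes 11, 11, 10),
# `_example_div_partition_assignment(10, 32, 3)` is `(0, 10)` and
# `_example_div_partition_assignment(22, 32, 3)` is `(2, 0)`.

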
class ShardedVariable(ShardedVariableMixin, composite_tensor.CompositeTensor):
  """A container for `Variables` that should be treated as shards.

  Variables that are too large to fit on a single device (e.g., large
  embeddings) may need to be sharded over multiple devices. This class
  maintains a list of smaller variables that can be independently stored on
  separate devices (e.g., multiple parameter servers), and saves and restores
  those variables as if they were a single larger variable.

  Objects of this class can be saved with a given number of shards and then
  restored from a checkpoint into a different number of shards.

  Objects of this class can be saved to SavedModel format using
  `tf.saved_model.save`. The SavedModel can be used by programs like TF serving
  APIs. It is not yet supported to load the SavedModel with
  `tf.saved_model.load`.

  Since `ShardedVariable` can be saved and then restored to a different number
  of shards depending on the restore environment (for example, TF serving APIs
  restore to one shard for serving efficiency), when using `ShardedVariable` in
  a tf.function one should generally not assume it has the same number of
  shards across save and load.

  Sharding is only supported along the first dimension.

  >>> class Model(tf.Module):
  ...   def __init__(self):
  ...     self.sharded_variable = ShardedVariable([
  ...       tf.Variable([3.0], dtype=tf.float32),
  ...       tf.Variable([2.0], dtype=tf.float32)
  ...     ])
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def serve_fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  >>>
  >>> model = Model()
  >>> model.fn(1).numpy()
  2.0
  >>> tf.saved_model.save(model, export_dir='/tmp/saved_model',
  ...   signatures=model.serve_fn)
  """

  @property
  def _type_spec(self):
    return ShardedVariableSpec(
        *(resource_variable_ops.VariableSpec(v.shape, v.dtype)
          for v in self._variables))

  @classmethod
  def _overload_all_operators(cls):
    """Register overloads for all operators."""
    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
      if operator == '__getitem__':
        continue

      cls._overload_operator(operator)

  @classmethod
  def _overload_operator(cls, operator):
    """Delegate an operator overload to `ops.Tensor`."""
    tensor_operator = getattr(ops.Tensor, operator)

    def _operator(v, *args, **kwargs):
      return tensor_operator(_var_to_tensor(v), *args, **kwargs)

    setattr(cls, operator, _operator)


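# An illustrative usage sketch for `ShardedVariable` (never called at import
# time; the function is defined here only as documentation-by-example). It
# assumes eager mode and two float32 shards.
def _example_sharded_variable_usage():
  """Illustrative only: construct a ShardedVariable and use it like a tensor."""
  sv = ShardedVariable([
      variables_lib.Variable([[0.0, 0.0], [1.0, 1.0]]),
      variables_lib.Variable([[2.0, 2.0]]),
  ])
  # The overall shape combines the shards along axis 0: [3, 2].
  assert sv.shape.as_list() == [3, 2]
  # Indexing and operators implicitly concatenate the shards into one tensor
  # (see `_var_to_tensor` and `_overload_all_operators` in this module).
  row = sv[2]          # -> [2.0, 2.0]
  doubled = sv * 2.0   # -> a dense [3, 2] tensor
  # `assign` splits the full value back into per-shard slices.
  sv.assign([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
  return row, doubled

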
def _var_to_tensor(var, dtype=None, name=None, as_ref=False):
  """Converts a `ShardedVariable` to a `Tensor`."""
  del name
  if dtype is not None and not dtype.is_compatible_with(var.dtype):
    raise ValueError(
        'Incompatible type conversion requested to type {!r} for variable '
        'of type {!r}'.format(dtype.name, var.dtype.name))
  if as_ref:
    raise NotImplementedError(
        "ShardedVariable doesn't support being used as a reference.")
  # We use the op dispatch mechanism to override embedding_lookup ops when
  # called with a ShardedVariable. This requires embedding_lookup ops to raise
  # TypeError when called with a ShardedVariable. However, since a
  # ShardedVariable can be converted to a tensor via concat, embedding_lookup
  # ops would silently do the conversion and never raise a TypeError. To be
  # able to properly raise a TypeError, the name scope is used to detect
  # whether this method is called within an embedding_lookup op.
  # NOTE: This doesn't work in eager mode since the op name scope is always
  # cleared in eager. This also breaks if the user sets the name of the
  # embedding_lookup op to something that doesn't contain the string
  # "embedding_lookup".
  #
  # TODO(chenkai): Find a more robust way to do this, which should not rely
  # on the name scope.
  if 'embedding_lookup' in ops.get_name_scope():
    raise TypeError('Converting ShardedVariable to tensor in embedding lookup'
                    ' ops is disallowed.')
  return array_ops.concat(var.variables, axis=0)


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
ops.register_tensor_conversion_function(ShardedVariable, _var_to_tensor)

ShardedVariable._overload_all_operators()  # pylint: disable=protected-access

# Override the behavior of embedding_lookup(sharded_variable, ...)
@dispatch.dispatch_for_types(embedding_ops.embedding_lookup, ShardedVariable)
def embedding_lookup(params,
                     ids,
                     partition_strategy='mod',
                     name=None,
                     validate_indices=True,
                     max_norm=None):
  if isinstance(params, list):
    params = params[0]
  return embedding_ops.embedding_lookup(params.variables, ids,
                                        partition_strategy, name,
                                        validate_indices, max_norm)


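# An illustrative sketch (never called at import time) of what the dispatch
# above buys us: calling embedding_lookup directly on a `ShardedVariable`
# routes to the override, which looks up against the list of shard variables
# instead of first concatenating them into one dense tensor.
def _example_embedding_lookup_dispatch():
  """Illustrative only: embedding_lookup dispatch on a ShardedVariable."""
  sv = ShardedVariable([
      variables_lib.Variable([[0.0], [1.0]]),
      variables_lib.Variable([[2.0], [3.0]]),
  ])
  ids = constant_op.constant([3, 0])
  # Equivalent to `embedding_ops.embedding_lookup(sv.variables, ids)`; with
  # the default 'mod' strategy this returns [[3.0], [0.0]].
  return embedding_ops.embedding_lookup(sv, ids)

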
def _raise_when_load(_):
  # We don't have serialization and deserialization mechanisms for
  # `ShardedVariable` in 2.x style save/load yet.
  raise ValueError(
      'Loading a saved_model containing ShardedVariable via '
      '`tf.saved_model.load` is not supported. If the model is built using '
      'Keras, please use `tf.keras.models.load_model` instead.')


revived_types.register_revived_type(
    '_tf_distribute_sharded_variable',
    lambda obj: isinstance(obj, ShardedVariable),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=_raise_when_load,
            version=0,
            min_producer_version=0,
            min_consumer_version=0)
    ])