# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sharded mutable dense hash table (deprecated).

This module and all its submodules are deprecated. To update or use linear
optimizers, please check the latest version in core:
tensorflow_estimator/python/estimator/canned/linear_optimizer/.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import range

from tensorflow.contrib import lookup
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation


# TODO(rohanj): This should subclass Trackable and implement
# _gather_saveables_for_checkpoint.
class ShardedMutableDenseHashTable(object):
  """A sharded version of MutableDenseHashTable.

  It is designed to be interface compatible with LookupInterface and
  MutableDenseHashTable, with the exception of the export method, which is
  replaced by an export_sharded method.

  The ShardedMutableDenseHashTable keeps `num_shards` MutableDenseHashTable
  instances internally. The shard for a key is computed via the modulo
  operation on the key.
  """
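
  # Usage sketch (illustrative only, not part of the original file; assumes
  # a TF 1.x graph/session environment and int64 keys):
  #
  #   table = ShardedMutableDenseHashTable(
  #       key_dtype=dtypes.int64, value_dtype=dtypes.float32,
  #       default_value=[0.0], empty_key=-1, deleted_key=-2, num_shards=2)
  #   insert_op = table.insert(keys, values)
  #   with ops.control_dependencies([insert_op]):
  #     looked_up = table.lookup(keys)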

  # TODO(andreasst): consider moving this to lookup module

  @deprecation.deprecated(
      None, 'This class is deprecated. To update or use linear optimizers, '
      'please check the latest version in core: '
      'tensorflow_estimator/python/estimator/canned/linear_optimizer/.')
  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               empty_key,
               deleted_key,
               num_shards=1,
               checkpoint=True,
               name='ShardedMutableHashTable'):
    self._key_dtype = key_dtype
    self._value_dtype = value_dtype
    with ops.name_scope(name, 'sharded_mutable_hash_table') as scope:
      self._table_name = scope
      table_shards = []
      for i in range(num_shards):
        table_shards.append(
            lookup.MutableDenseHashTable(
                key_dtype=key_dtype,
                value_dtype=value_dtype,
                default_value=default_value,
                empty_key=empty_key,
                deleted_key=deleted_key,
                checkpoint=checkpoint,
                name='%s-%d-of-%d' % (name, i + 1, num_shards)))
      self._table_shards = table_shards
      # TODO(andreasst): add a value_shape() method to LookupInterface
      # pylint: disable=protected-access
      self._value_shape = self._table_shards[0]._value_shape
      # pylint: enable=protected-access

  @property
  def name(self):
    return self._table_name

  @property
  def _num_shards(self):
    return len(self._table_shards)

  @property
  def table_shards(self):
    return self._table_shards

  def size(self, name=None):
    with ops.name_scope(name, 'sharded_mutable_hash_table_size'):
      sizes = [
          self._table_shards[i].size() for i in range(self._num_shards)
      ]
      return math_ops.add_n(sizes)

  def _shard_indices(self, keys):
    key_shape = keys.get_shape()
    if key_shape.ndims > 1:
      # If keys are a matrix (i.e. a single key is a vector), we use the first
      # element of each key vector to determine the shard.
      keys = array_ops.slice(keys, [0, 0], [key_shape.dims[0].value, 1])
      keys = array_ops.reshape(keys, [-1])
    indices = math_ops.mod(math_ops.abs(keys), self._num_shards)
    return math_ops.cast(indices, dtypes.int32)

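  # Worked example of the routing above (illustrative only): with
  # num_shards=2 and keys = [3, 10, -7], abs(keys) % 2 = [1, 0, 1], so key 10
  # goes to shard 0 and keys 3 and -7 go to shard 1. Matrix keys are routed
  # by the first element of each row only.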
  def _check_keys(self, keys):
    if not keys.get_shape().is_fully_defined():
      raise ValueError('Key shape must be fully defined, got %s.' %
                       keys.get_shape())
    if keys.get_shape().ndims != 1 and keys.get_shape().ndims != 2:
      raise ValueError('Expected a vector or matrix for keys, got %s.' %
                       keys.get_shape())

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values."""
    if keys.dtype.base_dtype != self._key_dtype:
      raise TypeError('Signature mismatch. Keys must be dtype %s, got %s.' %
                      (self._key_dtype, keys.dtype))
    self._check_keys(keys)
    num_shards = self._num_shards
    if num_shards == 1:
      return self._table_shards[0].lookup(keys, name=name)

    shard_indices = self._shard_indices(keys)
    # TODO(andreasst): support 'keys' that are not vectors
    key_shards = data_flow_ops.dynamic_partition(keys, shard_indices,
                                                 num_shards)
    value_shards = [
        self._table_shards[i].lookup(key_shards[i], name=name)
        for i in range(num_shards)
    ]

    num_keys = keys.get_shape().dims[0]
    original_indices = math_ops.range(num_keys)
    partitioned_indices = data_flow_ops.dynamic_partition(original_indices,
                                                          shard_indices,
                                                          num_shards)
    result = data_flow_ops.dynamic_stitch(partitioned_indices, value_shards)
    result.set_shape(
        tensor_shape.TensorShape([num_keys]).concatenate(self._value_shape))
    return result

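  # How the partition/stitch round-trip in lookup() preserves order
  # (illustrative walk-through, not from the original file): with
  # num_shards=2, keys = [3, 4, 7] yields shard_indices = [1, 0, 1], so
  # key_shards = [[4], [3, 7]]. Partitioning original_indices = [0, 1, 2]
  # the same way gives [[1], [0, 2]], and dynamic_stitch scatters each
  # shard's values back to those positions, restoring the input order.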
  def insert(self, keys, values, name=None):
    """Inserts `keys` with the corresponding `values` into the table."""
    self._check_keys(keys)
    num_shards = self._num_shards
    if num_shards == 1:
      return self._table_shards[0].insert(keys, values, name=name)

    shard_indices = self._shard_indices(keys)
    # TODO(andreasst): support 'keys' that are not vectors
    key_shards = data_flow_ops.dynamic_partition(keys, shard_indices,
                                                 num_shards)
    value_shards = data_flow_ops.dynamic_partition(values, shard_indices,
                                                   num_shards)
    return_values = [
        self._table_shards[i].insert(key_shards[i], value_shards[i], name=name)
        for i in range(num_shards)
    ]

    return control_flow_ops.group(*return_values)

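  # Note (editorial, not from the original file): in insert() above, keys and
  # values are partitioned with the same shard_indices, so each (key, value)
  # pair lands in the same shard, and the returned group() op completes only
  # once every per-shard insert has run.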
  def export_sharded(self, name=None):
    """Returns lists of the keys and values tensors in the sharded table.

    Args:
      name: Optional name for the export ops.

    Returns:
      A pair of lists with the first list containing the key tensors and the
        second list containing the value tensors from each shard.
    """
    keys_list = []
    values_list = []
    for table_shard in self._table_shards:
      exported_keys, exported_values = table_shard.export(name=name)
      keys_list.append(exported_keys)
      values_list.append(exported_values)
    return keys_list, values_list
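
# Recombining a sharded export (illustrative sketch, not part of the
# original module; shard-local ordering of exported entries is unspecified):
#
#   keys_list, values_list = table.export_sharded()
#   all_keys = array_ops.concat(keys_list, axis=0)
#   all_values = array_ops.concat(values_list, axis=0)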