# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sharded mutable dense hash table (deprecated).

This module and all its submodules are deprecated. To UPDATE or USE linear
optimizers, please check its latest version in core:
tensorflow_estimator/python/estimator/canned/linear_optimizer/.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import range

from tensorflow.contrib import lookup
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation


# TODO(rohanj): This should subclass Trackable and implement
# _gather_saveables_for_checkpoint.
class ShardedMutableDenseHashTable(object):
  """A sharded version of MutableDenseHashTable.

  It is designed to be interface compatible with LookupInterface and
  MutableDenseHashTable, with the exception of the export method, which is
  replaced by an export_sharded method.

  The ShardedMutableDenseHashTable keeps `num_shards` MutableDenseHashTable
  instances internally. The shard for each key is computed via the modulo
  operation on the key.
  """
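
  # A minimal usage sketch (illustrative values, not part of the original
  # module; assumes a TF1-style graph where `keys` is an int64 vector that
  # avoids `empty_key` and `deleted_key`, and `values` is a matching matrix):
  #
  #   table = ShardedMutableDenseHashTable(
  #       key_dtype=dtypes.int64, value_dtype=dtypes.float32,
  #       default_value=[-1.0], empty_key=0, deleted_key=-1, num_shards=2)
  #   insert_op = table.insert(keys, values)  # values shape: [num_keys, 1]
  #   lookup_op = table.lookup(keys)          # shape: [num_keys, 1]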

  # TODO(andreasst): consider moving this to lookup module

  @deprecation.deprecated(
      None, 'This class is deprecated. To UPDATE or USE linear optimizers, '
      'please check its latest version in core: '
      'tensorflow_estimator/python/estimator/canned/linear_optimizer/.')
  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               empty_key,
               deleted_key,
               num_shards=1,
               checkpoint=True,
               name='ShardedMutableHashTable'):
    self._key_dtype = key_dtype
    self._value_dtype = value_dtype
    with ops.name_scope(name, 'sharded_mutable_hash_table') as scope:
      self._table_name = scope
      table_shards = []
      for i in range(num_shards):
        table_shards.append(
            lookup.MutableDenseHashTable(
                key_dtype=key_dtype,
                value_dtype=value_dtype,
                default_value=default_value,
                empty_key=empty_key,
                deleted_key=deleted_key,
                checkpoint=checkpoint,
                name='%s-%d-of-%d' % (name, i + 1, num_shards)))
      self._table_shards = table_shards
      # TODO(andreasst): add a value_shape() method to LookupInterface
      # pylint: disable=protected-access
      self._value_shape = self._table_shards[0]._value_shape
      # pylint: enable=protected-access

  @property
  def name(self):
    return self._table_name

  @property
  def _num_shards(self):
    return len(self._table_shards)

  @property
  def table_shards(self):
    return self._table_shards

  def size(self, name=None):
    with ops.name_scope(name, 'sharded_mutable_hash_table_size'):
      sizes = [
          self._table_shards[i].size() for i in range(self._num_shards)
      ]
      return math_ops.add_n(sizes)

  def _shard_indices(self, keys):
    key_shape = keys.get_shape()
    if key_shape.ndims > 1:
      # If keys are a matrix (i.e. a single key is a vector), we use the first
      # element of each key vector to determine the shard.
      keys = array_ops.slice(keys, [0, 0], [key_shape.dims[0].value, 1])
      keys = array_ops.reshape(keys, [-1])
    indices = math_ops.mod(math_ops.abs(keys), self._num_shards)
    return math_ops.cast(indices, dtypes.int32)
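
  # Worked example of the sharding rule above (illustrative values): with
  # num_shards=4, the int64 keys [3, 7, 10] land in shards
  # [3 % 4, 7 % 4, 10 % 4] == [3, 3, 2].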

  def _check_keys(self, keys):
    if not keys.get_shape().is_fully_defined():
      raise ValueError('Key shape must be fully defined, got %s.' %
                       keys.get_shape())
    if keys.get_shape().ndims != 1 and keys.get_shape().ndims != 2:
      raise ValueError('Expected a vector or matrix for keys, got %s.' %
                       keys.get_shape())

  def lookup(self, keys, name=None):
    """Looks up `keys` in the table and outputs the corresponding values."""
    if keys.dtype.base_dtype != self._key_dtype:
      raise TypeError('Signature mismatch. Keys must be dtype %s, got %s.' %
                      (self._key_dtype, keys.dtype))
    self._check_keys(keys)
    num_shards = self._num_shards
    if num_shards == 1:
      return self._table_shards[0].lookup(keys, name=name)

    # Partition the keys by shard and look each partition up in its shard.
    shard_indices = self._shard_indices(keys)
    # TODO(andreasst): support 'keys' that are not vectors
    key_shards = data_flow_ops.dynamic_partition(keys, shard_indices,
                                                 num_shards)
    value_shards = [
        self._table_shards[i].lookup(key_shards[i], name=name)
        for i in range(num_shards)
    ]

    # Stitch the per-shard results back together in the original key order.
    num_keys = keys.get_shape().dims[0]
    original_indices = math_ops.range(num_keys)
    partitioned_indices = data_flow_ops.dynamic_partition(original_indices,
                                                          shard_indices,
                                                          num_shards)
    result = data_flow_ops.dynamic_stitch(partitioned_indices, value_shards)
    result.set_shape(
        tensor_shape.TensorShape([num_keys]).concatenate(self._value_shape))
    return result

  def insert(self, keys, values, name=None):
    """Inserts `keys` and their corresponding `values` into the table."""
    self._check_keys(keys)
    num_shards = self._num_shards
    if num_shards == 1:
      return self._table_shards[0].insert(keys, values, name=name)

    # Partition keys and values consistently, then insert into each shard.
    shard_indices = self._shard_indices(keys)
    # TODO(andreasst): support 'keys' that are not vectors
    key_shards = data_flow_ops.dynamic_partition(keys, shard_indices,
                                                 num_shards)
    value_shards = data_flow_ops.dynamic_partition(values, shard_indices,
                                                   num_shards)
    return_values = [
        self._table_shards[i].insert(key_shards[i], value_shards[i], name=name)
        for i in range(num_shards)
    ]

    return control_flow_ops.group(*return_values)

  def export_sharded(self, name=None):
    """Returns lists of the keys and values tensors in the sharded table.

    Args:
      name: name of the table.

    Returns:
      A pair of lists with the first list containing the key tensors and the
      second list containing the value tensors from each shard.
    """
    keys_list = []
    values_list = []
    for table_shard in self._table_shards:
      exported_keys, exported_values = table_shard.export(name=name)
      keys_list.append(exported_keys)
      values_list.append(exported_values)
    return keys_list, values_list
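
# A minimal sketch of inspecting the sharded contents (illustrative, not part
# of the original module; `table` is a ShardedMutableDenseHashTable as above).
# Note that MutableDenseHashTable.export returns the backing buckets, so the
# concatenated tensors may include empty or deleted entries that callers
# typically filter out by `empty_key`/`deleted_key`:
#
#   key_shards, value_shards = table.export_sharded()
#   all_keys = array_ops.concat(key_shards, axis=0)
#   all_values = array_ops.concat(value_shards, axis=0)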