# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of tf.sets."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gen_set_ops
from tensorflow.python.util.tf_export import tf_export


# Element dtypes accepted by the Python-level validation in this module.
_VALID_DTYPES = {
    dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
    dtypes.uint8, dtypes.uint16, dtypes.string}


@tf_export("sets.size", v1=["sets.size", "sets.set_size"])
def set_size(a, validate_indices=True):
  """Compute number of unique elements along last dimension of `a`.

  Args:
    a: `SparseTensor`, with indices sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a`.

  Returns:
    `int32` `Tensor` of set sizes. For `a` ranked `n`, this is a `Tensor` with
    rank `n-1`, and the same 1st `n-1` dimensions as `a`. Each value is the
    number of unique elements in the corresponding `[0...n-1]` dimension of `a`.

  Raises:
    TypeError: If `a` is not a `SparseTensor`, or if the dtype of its values is
      not one of the supported set dtypes.
  """
  a = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a")
  # Only sparse input is accepted here; a dense `Tensor` is rejected.
  if not isinstance(a, sparse_tensor.SparseTensor):
    raise TypeError("Expected `SparseTensor`, got %s." % a)
  if a.values.dtype.base_dtype not in _VALID_DTYPES:
    raise TypeError("Invalid dtype %s." % a.values.dtype)
  return gen_set_ops.set_size(
      a.indices, a.values, a.dense_shape, validate_indices)

ops.NotDifferentiable("SetSize")


ops.NotDifferentiable("DenseToDenseSetOperation")
ops.NotDifferentiable("DenseToSparseSetOperation")
ops.NotDifferentiable("SparseToSparseSetOperation")


def _convert_to_tensors_or_sparse_tensors(a, b):
  """Convert to tensor types, and flip order if necessary.

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`.
    b: `Tensor` or `SparseTensor` of the same type as `a`.

  Returns:
    Tuple of `(a, b, flipped)`, where `a` and `b` have been converted to
    `Tensor` or `SparseTensor`, and `flipped` indicates whether the order has
    been flipped to make it dense,sparse instead of sparse,dense (since the set
    ops do not support the latter).

  Raises:
    TypeError: If the dtype of `a` is not one of the supported set dtypes, or
      if the base dtypes of `a` and `b` differ.
  """
  a = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a")
  if a.dtype.base_dtype not in _VALID_DTYPES:
    raise TypeError("'a' invalid dtype %s." % a.dtype)
  b = sparse_tensor.convert_to_tensor_or_sparse_tensor(b, name="b")
  # `b` only needs to match `a`; `a` was already checked against the valid set.
  if b.dtype.base_dtype != a.dtype.base_dtype:
    raise TypeError("Types don't match, %s vs %s." % (a.dtype, b.dtype))
  # Sparse,dense is the one unsupported combination, so swap it to
  # dense,sparse and tell the caller that the operands were exchanged.
  if (isinstance(a, sparse_tensor.SparseTensor) and
      not isinstance(b, sparse_tensor.SparseTensor)):
    return b, a, True
  return a, b, False


def _set_operation(a, b, set_operation, validate_indices=True):
  """Compute set operation of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. Must be
      `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be
      sorted in row-major order.
    set_operation: String indicating set operation. See
      SetOperationOp::SetOperationFromContext for valid values.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a` and `b`.

  Returns:
    A `SparseTensor` with the same rank as `a` and `b`, and all but the last
    dimension the same. Elements along the last dimension contain the results
    of the set operation.

  Raises:
    TypeError: If inputs are invalid types.
    ValueError: If `a` is sparse and `b` is dense.
  """
  a_is_sparse = isinstance(a, sparse_tensor.SparseTensor)
  b_is_sparse = isinstance(b, sparse_tensor.SparseTensor)

  # Reject the single unsupported layout up front; every remaining combination
  # maps onto exactly one generated op.
  if a_is_sparse and not b_is_sparse:
    raise ValueError("Sparse,Dense is not supported, but Dense,Sparse is. "
                     "Please flip the order of your inputs.")

  if a_is_sparse:
    indices, values, shape = gen_set_ops.sparse_to_sparse_set_operation(
        a.indices, a.values, a.dense_shape,
        b.indices, b.values, b.dense_shape,
        set_operation, validate_indices)
  elif b_is_sparse:
    indices, values, shape = gen_set_ops.dense_to_sparse_set_operation(
        a, b.indices, b.values, b.dense_shape, set_operation, validate_indices)
  else:
    indices, values, shape = gen_set_ops.dense_to_dense_set_operation(
        a, b, set_operation, validate_indices)
  return sparse_tensor.SparseTensor(indices, values, shape)


@tf_export(
    "sets.intersection", v1=["sets.intersection", "sets.set_intersection"])
def set_intersection(a, b, validate_indices=True):
  """Compute set intersection of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Example:

  ```python
    import tensorflow as tf
    import collections

    # Represent the following array of sets as a sparse tensor:
    # a = np.array([[{1, 2}, {3}], [{4}, {5, 6}]])
    a = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 2),
        ((0, 1, 0), 3),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
    ])
    a = tf.SparseTensor(list(a.keys()), list(a.values()), dense_shape=[2, 2, 2])

    # b = np.array([[{1}, {}], [{4}, {5, 6, 7, 8}]])
    b = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
        ((1, 1, 2), 7),
        ((1, 1, 3), 8),
    ])
    b = tf.SparseTensor(list(b.keys()), list(b.values()), dense_shape=[2, 2, 4])

    # `tf.sets.intersection` is applied to each aligned pair of sets.
    tf.sets.intersection(a, b)

    # The result will be equivalent to either of:
    #
    # np.array([[{1}, {}], [{4}, {5, 6}]])
    #
    # collections.OrderedDict([
    #     ((0, 0, 0), 1),
    #     ((1, 0, 0), 4),
    #     ((1, 1, 0), 5),
    #     ((1, 1, 1), 6),
    # ])
  ```

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices
      must be sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a` and `b`.

  Returns:
    A `SparseTensor` whose shape is the same rank as `a` and `b`, and all but
    the last dimension the same. Elements along the last dimension contain the
    intersections.

  Raises:
    TypeError: If the dtype of `a` is not one of the supported set dtypes, or
      if the base dtypes of `a` and `b` differ.
  """
  # Intersection is symmetric, so an operand swap needs no compensation.
  a, b, _ = _convert_to_tensors_or_sparse_tensors(a, b)
  return _set_operation(a, b, "intersection", validate_indices)


@tf_export(
    "sets.difference", v1=["sets.difference", "sets.set_difference"])
def set_difference(a, b, aminusb=True, validate_indices=True):
  """Compute set difference of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Example:

  ```python
    import tensorflow as tf
    import collections

    # Represent the following array of sets as a sparse tensor:
    # a = np.array([[{1, 2}, {3}], [{4}, {5, 6}]])
    a = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 2),
        ((0, 1, 0), 3),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
    ])
    a = tf.SparseTensor(list(a.keys()), list(a.values()), dense_shape=[2, 2, 2])

    # b = np.array([[{1, 3}, {2}], [{4, 5}, {5, 6, 7, 8}]])
    b = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 3),
        ((0, 1, 0), 2),
        ((1, 0, 0), 4),
        ((1, 0, 1), 5),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
        ((1, 1, 2), 7),
        ((1, 1, 3), 8),
    ])
    b = tf.SparseTensor(list(b.keys()), list(b.values()), dense_shape=[2, 2, 4])

    # `tf.sets.difference` is applied to each aligned pair of sets.
    tf.sets.difference(a, b)

    # The result will be equivalent to either of:
    #
    # np.array([[{2}, {3}], [{}, {}]])
    #
    # collections.OrderedDict([
    #     ((0, 0, 0), 2),
    #     ((0, 1, 0), 3),
    # ])
  ```

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices
      must be sorted in row-major order.
    aminusb: Whether to subtract `b` from `a`, vs vice versa.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a` and `b`.

  Returns:
    A `SparseTensor` whose shape is the same rank as `a` and `b`, and all but
    the last dimension the same. Elements along the last dimension contain the
    differences.

  Raises:
    TypeError: If the dtype of `a` is not one of the supported set dtypes, or
      if the base dtypes of `a` and `b` differ.
  """
  a, b, flipped = _convert_to_tensors_or_sparse_tensors(a, b)
  # Difference is not symmetric: if the operands were swapped to reach the
  # supported dense,sparse layout, invert the direction so that the result
  # still reflects the caller's original `a` and `b`.
  if flipped:
    aminusb = not aminusb
  return _set_operation(a, b, "a-b" if aminusb else "b-a", validate_indices)


@tf_export(
    "sets.union", v1=["sets.union", "sets.set_union"])
def set_union(a, b, validate_indices=True):
  """Compute set union of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Example:

  ```python
    import tensorflow as tf
    import collections

    # a = np.array([[{1, 2}, {3}], [{4}, {5, 6}]])
    a = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 2),
        ((0, 1, 0), 3),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
    ])
    a = tf.SparseTensor(list(a.keys()), list(a.values()), dense_shape=[2, 2, 2])

    # b = np.array([[{1, 3}, {2}], [{4, 5}, {5, 6, 7, 8}]])
    b = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 3),
        ((0, 1, 0), 2),
        ((1, 0, 0), 4),
        ((1, 0, 1), 5),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
        ((1, 1, 2), 7),
        ((1, 1, 3), 8),
    ])
    b = tf.SparseTensor(list(b.keys()), list(b.values()), dense_shape=[2, 2, 4])

    # `tf.sets.union` is applied to each aligned pair of sets.
    tf.sets.union(a, b)

    # The result will be equivalent to either of:
    #
    # np.array([[{1, 2, 3}, {2, 3}], [{4, 5}, {5, 6, 7, 8}]])
    #
    # collections.OrderedDict([
    #     ((0, 0, 0), 1),
    #     ((0, 0, 1), 2),
    #     ((0, 0, 2), 3),
    #     ((0, 1, 0), 2),
    #     ((0, 1, 1), 3),
    #     ((1, 0, 0), 4),
    #     ((1, 0, 1), 5),
    #     ((1, 1, 0), 5),
    #     ((1, 1, 1), 6),
    #     ((1, 1, 2), 7),
    #     ((1, 1, 3), 8),
    # ])
  ```

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices
      must be sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a` and `b`.

  Returns:
    A `SparseTensor` whose shape is the same rank as `a` and `b`, and all but
    the last dimension the same. Elements along the last dimension contain the
    unions.

  Raises:
    TypeError: If the dtype of `a` is not one of the supported set dtypes, or
      if the base dtypes of `a` and `b` differ.
  """
  # Union is symmetric, so an operand swap needs no compensation.
  a, b, _ = _convert_to_tensors_or_sparse_tensors(a, b)
  return _set_operation(a, b, "union", validate_indices)