# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of tf.sets."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gen_set_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


# Element dtypes accepted by the set functions in this module.
_VALID_DTYPES = {
    dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
    dtypes.uint8, dtypes.uint16, dtypes.string,
}


@tf_export("sets.size", v1=["sets.size", "sets.set_size"])
@dispatch.add_dispatch_support
def set_size(a, validate_indices=True):
  """Compute number of unique elements along last dimension of `a`.

  Args:
    a: `SparseTensor`, with indices sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a`.

  Returns:
    `int32` `Tensor` of set sizes. For `a` ranked `n`, this is a `Tensor` with
    rank `n-1`, and the same 1st `n-1` dimensions as `a`. Each value is the
    number of unique elements in the corresponding `[0...n-1]` dimension of `a`.

  Raises:
    TypeError: If `a` does not convert to a `SparseTensor`, or if the dtype of
      its values is not one of the supported set dtypes.
  """
  a = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a")
  if not isinstance(a, sparse_tensor.SparseTensor):
    raise TypeError("Expected `SparseTensor`, got %s." % a)
  if a.values.dtype.base_dtype not in _VALID_DTYPES:
    raise TypeError("Invalid dtype %s." % a.values.dtype)
  return gen_set_ops.set_size(
      a.indices, a.values, a.dense_shape, validate_indices)


# None of the set ops are registered with a gradient.
ops.NotDifferentiable("SetSize")


ops.NotDifferentiable("DenseToDenseSetOperation")
ops.NotDifferentiable("DenseToSparseSetOperation")
ops.NotDifferentiable("SparseToSparseSetOperation")


def _convert_to_tensors_or_sparse_tensors(a, b):
  """Convert to tensor types, and flip order if necessary.

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`.
    b: `Tensor` or `SparseTensor` of the same type as `a`.

  Returns:
    Tuple of `(a, b, flipped)`, where `a` and `b` have been converted to
    `Tensor` or `SparseTensor`, and `flipped` indicates whether the order has
    been flipped to make it dense,sparse instead of sparse,dense (since the set
    ops do not support the latter).

  Raises:
    TypeError: If the dtype of `a` is not one of the supported set dtypes, or
      if the dtype of `b` differs from the dtype of `a`.
  """
  a = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a")
  if a.dtype.base_dtype not in _VALID_DTYPES:
    raise TypeError("'a' invalid dtype %s." % a.dtype)
  b = sparse_tensor.convert_to_tensor_or_sparse_tensor(b, name="b")
  # `b` is only compared against `a`; a match implies `b` is also valid.
  if b.dtype.base_dtype != a.dtype.base_dtype:
    raise TypeError("Types don't match, %s vs %s." % (a.dtype, b.dtype))
  if (isinstance(a, sparse_tensor.SparseTensor) and
      not isinstance(b, sparse_tensor.SparseTensor)):
    # Sparse,dense has no kernel: hand the operands back as dense,sparse.
    return b, a, True
  return a, b, False


def _set_operation(a, b, set_operation, validate_indices=True):
  """Compute set operation of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. Must be
      `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be
      sorted in row-major order.
    set_operation: String indicating set operation. See
      SetOperationOp::SetOperationFromContext for valid values.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a` and `b`.

  Returns:
    A `SparseTensor` with the same rank as `a` and `b`, and all but the last
    dimension the same. Elements along the last dimension contain the results
    of the set operation.

  Raises:
    TypeError: If inputs are invalid types.
    ValueError: If `a` is sparse and `b` is dense.
  """
  a_is_sparse = isinstance(a, sparse_tensor.SparseTensor)
  b_is_sparse = isinstance(b, sparse_tensor.SparseTensor)

  # Reject the one unsupported combination up front.
  if a_is_sparse and not b_is_sparse:
    raise ValueError("Sparse,Dense is not supported, but Dense,Sparse is. "
                     "Please flip the order of your inputs.")

  # Dispatch to the kernel matching the (dense|sparse, dense|sparse) pairing.
  if a_is_sparse:
    indices, values, shape = gen_set_ops.sparse_to_sparse_set_operation(
        a.indices, a.values, a.dense_shape,
        b.indices, b.values, b.dense_shape,
        set_operation, validate_indices)
  elif b_is_sparse:
    indices, values, shape = gen_set_ops.dense_to_sparse_set_operation(
        a, b.indices, b.values, b.dense_shape, set_operation, validate_indices)
  else:
    indices, values, shape = gen_set_ops.dense_to_dense_set_operation(
        a, b, set_operation, validate_indices)
  return sparse_tensor.SparseTensor(indices, values, shape)


@tf_export(
    "sets.intersection", v1=["sets.intersection", "sets.set_intersection"])
@dispatch.add_dispatch_support
def set_intersection(a, b, validate_indices=True):
  """Compute set intersection of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  If `a` is sparse and `b` is dense, the operands are swapped before the
  underlying op is invoked.

  Example:

  ```python
    import tensorflow as tf
    import collections

    # Represent the following array of sets as a sparse tensor:
    # a = np.array([[{1, 2}, {3}], [{4}, {5, 6}]])
    a = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 2),
        ((0, 1, 0), 3),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
    ])
    a = tf.sparse.SparseTensor(list(a.keys()), list(a.values()),
                               dense_shape=[2, 2, 2])

    # b = np.array([[{1}, {}], [{4}, {5, 6, 7, 8}]])
    b = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
        ((1, 1, 2), 7),
        ((1, 1, 3), 8),
    ])
    b = tf.sparse.SparseTensor(list(b.keys()), list(b.values()),
                               dense_shape=[2, 2, 4])

    # `tf.sets.intersection` is applied to each aligned pair of sets.
    tf.sets.intersection(a, b)

    # The result will be equivalent to either of:
    #
    # np.array([[{1}, {}], [{4}, {5, 6}]])
    #
    # collections.OrderedDict([
    #     ((0, 0, 0), 1),
    #     ((1, 0, 0), 4),
    #     ((1, 1, 0), 5),
    #     ((1, 1, 1), 6),
    # ])
  ```

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices
      must be sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a` and `b`.

  Returns:
    A `SparseTensor` whose shape is the same rank as `a` and `b`, and all but
    the last dimension the same. Elements along the last dimension contain the
    intersections.

  Raises:
    TypeError: If the dtype of `a` is not one of the supported set dtypes, or
      if `a` and `b` have different dtypes.
  """
  a, b, _ = _convert_to_tensors_or_sparse_tensors(a, b)
  return _set_operation(a, b, "intersection", validate_indices)


@tf_export(
    "sets.difference", v1=["sets.difference", "sets.set_difference"])
@dispatch.add_dispatch_support
def set_difference(a, b, aminusb=True, validate_indices=True):
  """Compute set difference of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  If `a` is sparse and `b` is dense, the operands are swapped and `aminusb` is
  inverted before the underlying op is invoked, so the requested direction of
  the difference is preserved.

  Example:

  ```python
    import tensorflow as tf
    import collections

    # Represent the following array of sets as a sparse tensor:
    # a = np.array([[{1, 2}, {3}], [{4}, {5, 6}]])
    a = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 2),
        ((0, 1, 0), 3),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
    ])
    a = tf.sparse.SparseTensor(list(a.keys()), list(a.values()),
                               dense_shape=[2, 2, 2])

    # b = np.array([[{1, 3}, {2}], [{4, 5}, {5, 6, 7, 8}]])
    b = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 3),
        ((0, 1, 0), 2),
        ((1, 0, 0), 4),
        ((1, 0, 1), 5),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
        ((1, 1, 2), 7),
        ((1, 1, 3), 8),
    ])
    b = tf.sparse.SparseTensor(list(b.keys()), list(b.values()),
                               dense_shape=[2, 2, 4])

    # `set_difference` is applied to each aligned pair of sets.
    tf.sets.difference(a, b)

    # The result will be equivalent to either of:
    #
    # np.array([[{2}, {3}], [{}, {}]])
    #
    # collections.OrderedDict([
    #     ((0, 0, 0), 2),
    #     ((0, 1, 0), 3),
    # ])
  ```

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices
      must be sorted in row-major order.
    aminusb: Whether to subtract `b` from `a`, vs vice versa.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a` and `b`.

  Returns:
    A `SparseTensor` whose shape is the same rank as `a` and `b`, and all but
    the last dimension the same. Elements along the last dimension contain the
    differences.

  Raises:
    TypeError: If inputs are invalid types, or if `a` and `b` have
      different types.
    errors_impl.InvalidArgumentError: If the shapes of `a` and `b` do not
      match in any dimension other than the last dimension.
  """
  a, b, flipped = _convert_to_tensors_or_sparse_tensors(a, b)
  if flipped:
    # Operands were swapped, so the direction must be swapped as well.
    aminusb = not aminusb
  return _set_operation(a, b, "a-b" if aminusb else "b-a", validate_indices)


@tf_export(
    "sets.union", v1=["sets.union", "sets.set_union"])
@dispatch.add_dispatch_support
def set_union(a, b, validate_indices=True):
  """Compute set union of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  If `a` is sparse and `b` is dense, the operands are swapped before the
  underlying op is invoked.

  Example:

  ```python
    import tensorflow as tf
    import collections

    # [[{1, 2}, {3}], [{4}, {5, 6}]]
    a = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 2),
        ((0, 1, 0), 3),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
    ])
    a = tf.sparse.SparseTensor(list(a.keys()), list(a.values()),
                               dense_shape=[2, 2, 2])

    # [[{1, 3}, {2}], [{4, 5}, {5, 6, 7, 8}]]
    b = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 3),
        ((0, 1, 0), 2),
        ((1, 0, 0), 4),
        ((1, 0, 1), 5),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
        ((1, 1, 2), 7),
        ((1, 1, 3), 8),
    ])
    b = tf.sparse.SparseTensor(list(b.keys()), list(b.values()),
                               dense_shape=[2, 2, 4])

    # `set_union` is applied to each aligned pair of sets.
    tf.sets.union(a, b)

    # The result will be equivalent to either of:
    #
    # np.array([[{1, 2, 3}, {2, 3}], [{4, 5}, {5, 6, 7, 8}]])
    #
    # collections.OrderedDict([
    #     ((0, 0, 0), 1),
    #     ((0, 0, 1), 2),
    #     ((0, 0, 2), 3),
    #     ((0, 1, 0), 2),
    #     ((0, 1, 1), 3),
    #     ((1, 0, 0), 4),
    #     ((1, 0, 1), 5),
    #     ((1, 1, 0), 5),
    #     ((1, 1, 1), 6),
    #     ((1, 1, 2), 7),
    #     ((1, 1, 3), 8),
    # ])
  ```

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices
      must be sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a` and `b`.

  Returns:
    A `SparseTensor` whose shape is the same rank as `a` and `b`, and all but
    the last dimension the same. Elements along the last dimension contain the
    unions.

  Raises:
    TypeError: If the dtype of `a` is not one of the supported set dtypes, or
      if `a` and `b` have different dtypes.
  """
  a, b, _ = _convert_to_tensors_or_sparse_tensors(a, b)
  return _set_operation(a, b, "union", validate_indices)