• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implementation of tf.sets."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import sparse_tensor
25from tensorflow.python.ops import gen_set_ops
26from tensorflow.python.util import dispatch
27from tensorflow.python.util.tf_export import tf_export
28
29
# Element dtypes accepted by the set ops in this module; inputs whose base
# dtype is not listed here are rejected with a `TypeError`.
_VALID_DTYPES = {
    dtypes.int8,
    dtypes.int16,
    dtypes.int32,
    dtypes.int64,
    dtypes.uint8,
    dtypes.uint16,
    dtypes.string,
}
33
34
@tf_export("sets.size", v1=["sets.size", "sets.set_size"])
@dispatch.add_dispatch_support
def set_size(a, validate_indices=True):
  """Count the unique elements along the last dimension of `a`.

  Args:
    a: `SparseTensor`, with indices sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a`.

  Returns:
    `int32` `Tensor` of set sizes. For `a` ranked `n`, this is a `Tensor` with
    rank `n-1`, and the same 1st `n-1` dimensions as `a`. Each value is the
    number of unique elements in the corresponding `[0...n-1]` dimension of `a`.

  Raises:
    TypeError: If `a` does not convert to a `SparseTensor`, or if the dtype of
      its values is not one of `_VALID_DTYPES`.
  """
  sparse_input = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a")
  # Only sparse inputs are supported; a dense result of the conversion is
  # rejected here.
  if not isinstance(sparse_input, sparse_tensor.SparseTensor):
    raise TypeError("Expected `SparseTensor`, got %s." % sparse_input)
  values_dtype = sparse_input.values.dtype
  if values_dtype.base_dtype not in _VALID_DTYPES:
    raise TypeError("Invalid dtype %s." % values_dtype)
  return gen_set_ops.set_size(
      sparse_input.indices,
      sparse_input.values,
      sparse_input.dense_shape,
      validate_indices)
61
# Mark every set op used by this module as not differentiable.
for _set_op_type in (
    "SetSize",
    "DenseToDenseSetOperation",
    "DenseToSparseSetOperation",
    "SparseToSparseSetOperation",
):
  ops.NotDifferentiable(_set_op_type)
# Keep the loop variable out of the module namespace.
del _set_op_type
68
69
def _convert_to_tensors_or_sparse_tensors(a, b):
  """Convert both operands to tensor types, swapping them if required.

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`.
    b: `Tensor` or `SparseTensor` of the same type as `a`.

  Returns:
    Tuple of `(a, b, flipped)`, where `a` and `b` have been converted to
    `Tensor` or `SparseTensor`. `flipped` is `True` when the inputs arrived as
    sparse,dense and were swapped to dense,sparse (the set ops do not support
    the former order); in that case the returned pair is `(b, a)`.

  Raises:
    TypeError: If the dtype of `a` is not one of `_VALID_DTYPES`, or if the
      base dtypes of `a` and `b` differ.
  """
  # `a` is converted and validated before `b` is touched, so a bad `a` is
  # reported first.
  lhs = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a")
  lhs_base_dtype = lhs.dtype.base_dtype
  if lhs_base_dtype not in _VALID_DTYPES:
    raise TypeError("'a' invalid dtype %s." % lhs.dtype)

  rhs = sparse_tensor.convert_to_tensor_or_sparse_tensor(b, name="b")
  if rhs.dtype.base_dtype != lhs_base_dtype:
    raise TypeError("Types don't match, %s vs %s." % (lhs.dtype, rhs.dtype))

  sparse_then_dense = (
      isinstance(lhs, sparse_tensor.SparseTensor) and
      not isinstance(rhs, sparse_tensor.SparseTensor))
  if sparse_then_dense:
    return rhs, lhs, True
  return lhs, rhs, False
93
94
def _set_operation(a, b, set_operation, validate_indices=True):
  """Apply a set operation to the last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. Must be
      `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be
      sorted in row-major order.
    set_operation: String indicating set operation. See
      SetOperationOp::SetOperationFromContext for valid values.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a` and `b`.

  Returns:
    A `SparseTensor` with the same rank as `a` and `b`, and all but the last
    dimension the same. Elements along the last dimension contain the results
    of the set operation.

  Raises:
    ValueError: If `a` is sparse and `b` is dense.
  """
  a_is_sparse = isinstance(a, sparse_tensor.SparseTensor)
  b_is_sparse = isinstance(b, sparse_tensor.SparseTensor)

  # The sparse,dense combination is the only one without a matching op below.
  if a_is_sparse and not b_is_sparse:
    raise ValueError("Sparse,Dense is not supported, but Dense,Sparse is. "
                     "Please flip the order of your inputs.")

  if a_is_sparse:
    result = gen_set_ops.sparse_to_sparse_set_operation(
        a.indices, a.values, a.dense_shape,
        b.indices, b.values, b.dense_shape,
        set_operation, validate_indices)
  elif b_is_sparse:
    result = gen_set_ops.dense_to_sparse_set_operation(
        a, b.indices, b.values, b.dense_shape, set_operation, validate_indices)
  else:
    result = gen_set_ops.dense_to_dense_set_operation(
        a, b, set_operation, validate_indices)

  # Each op yields the three components of the sparse result.
  indices, values, shape = result
  return sparse_tensor.SparseTensor(indices, values, shape)
136
137
@tf_export(
    "sets.intersection", v1=["sets.intersection", "sets.set_intersection"])
@dispatch.add_dispatch_support
def set_intersection(a, b, validate_indices=True):
  """Compute set intersection of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Example:

  ```python
    import tensorflow as tf
    import collections

    # Represent the following array of sets as a sparse tensor:
    # a = np.array([[{1, 2}, {3}], [{4}, {5, 6}]])
    a = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 2),
        ((0, 1, 0), 3),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
    ])
    a = tf.sparse.SparseTensor(list(a.keys()), list(a.values()),
                               dense_shape=[2, 2, 2])

    # b = np.array([[{1}, {}], [{4}, {5, 6, 7, 8}]])
    b = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
        ((1, 1, 2), 7),
        ((1, 1, 3), 8),
    ])
    b = tf.sparse.SparseTensor(list(b.keys()), list(b.values()),
                               dense_shape=[2, 2, 4])

    # `tf.sets.intersection` is applied to each aligned pair of sets.
    tf.sets.intersection(a, b)

    # The result will be equivalent to either of:
    #
    # np.array([[{1}, {}], [{4}, {5, 6}]])
    #
    # collections.OrderedDict([
    #     ((0, 0, 0), 1),
    #     ((1, 0, 0), 4),
    #     ((1, 1, 0), 5),
    #     ((1, 1, 1), 6),
    # ])
  ```

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices
      must be sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a` and `b`.

  Returns:
    A `SparseTensor` whose shape is the same rank as `a` and `b`, and all but
    the last dimension the same. Elements along the last dimension contain the
    intersections.

  Raises:
    TypeError: If `a` has an unsupported dtype, or if `a` and `b` have
      different dtypes.
  """
  # A sparse,dense pair comes back swapped to dense,sparse. Operand order does
  # not change an intersection, so the flip flag is ignored.
  lhs, rhs, _ = _convert_to_tensors_or_sparse_tensors(a, b)
  return _set_operation(
      lhs, rhs, "intersection", validate_indices=validate_indices)
207
208
@tf_export(
    "sets.difference", v1=["sets.difference", "sets.set_difference"])
@dispatch.add_dispatch_support
def set_difference(a, b, aminusb=True, validate_indices=True):
  """Compute set difference of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Example:

  ```python
    import tensorflow as tf
    import collections

    # Represent the following array of sets as a sparse tensor:
    # a = np.array([[{1, 2}, {3}], [{4}, {5, 6}]])
    a = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 2),
        ((0, 1, 0), 3),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
    ])
    a = tf.sparse.SparseTensor(list(a.keys()), list(a.values()),
                               dense_shape=[2, 2, 2])

    # b = np.array([[{1, 3}, {2}], [{4, 5}, {5, 6, 7, 8}]])
    b = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 3),
        ((0, 1, 0), 2),
        ((1, 0, 0), 4),
        ((1, 0, 1), 5),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
        ((1, 1, 2), 7),
        ((1, 1, 3), 8),
    ])
    b = tf.sparse.SparseTensor(list(b.keys()), list(b.values()),
                               dense_shape=[2, 2, 4])

    # `set_difference` is applied to each aligned pair of sets.
    tf.sets.difference(a, b)

    # The result will be equivalent to either of:
    #
    # np.array([[{2}, {3}], [{}, {}]])
    #
    # collections.OrderedDict([
    #     ((0, 0, 0), 2),
    #     ((0, 1, 0), 3),
    # ])
  ```

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices
      must be sorted in row-major order.
    aminusb: Whether to subtract `b` from `a`, vs vice versa.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a` and `b`.

  Returns:
    A `SparseTensor` whose shape is the same rank as `a` and `b`, and all but
    the last dimension the same. Elements along the last dimension contain the
    differences.

  Raises:
    TypeError: If inputs are invalid types, or if `a` and `b` have
      different types.
    errors_impl.InvalidArgumentError: If the shapes of `a` and `b` do not
      match in any dimension other than the last dimension.

  Note:
    A sparse `a` combined with a dense `b` is accepted: the operands are
    swapped internally and the subtraction direction is inverted to match.
  """
  lhs, rhs, flipped = _convert_to_tensors_or_sparse_tensors(a, b)
  # When the operands were swapped, the caller's `a` is now `rhs`, so the
  # direction must be inverted to keep the requested result.
  if flipped:
    direction = "b-a" if aminusb else "a-b"
  else:
    direction = "a-b" if aminusb else "b-a"
  return _set_operation(lhs, rhs, direction, validate_indices=validate_indices)
289
290
@tf_export(
    "sets.union", v1=["sets.union", "sets.set_union"])
@dispatch.add_dispatch_support
def set_union(a, b, validate_indices=True):
  """Compute set union of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Example:

  ```python
    import tensorflow as tf
    import collections

    # [[{1, 2}, {3}], [{4}, {5, 6}]]
    a = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 2),
        ((0, 1, 0), 3),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
    ])
    a = tf.sparse.SparseTensor(list(a.keys()), list(a.values()),
                               dense_shape=[2, 2, 2])

    # [[{1, 3}, {2}], [{4, 5}, {5, 6, 7, 8}]]
    b = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 3),
        ((0, 1, 0), 2),
        ((1, 0, 0), 4),
        ((1, 0, 1), 5),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
        ((1, 1, 2), 7),
        ((1, 1, 3), 8),
    ])
    b = tf.sparse.SparseTensor(list(b.keys()), list(b.values()),
                               dense_shape=[2, 2, 4])

    # `set_union` is applied to each aligned pair of sets.
    tf.sets.union(a, b)

    # The result will be equivalent to either of:
    #
    # np.array([[{1, 2, 3}, {2, 3}], [{4, 5}, {5, 6, 7, 8}]])
    #
    # collections.OrderedDict([
    #     ((0, 0, 0), 1),
    #     ((0, 0, 1), 2),
    #     ((0, 0, 2), 3),
    #     ((0, 1, 0), 2),
    #     ((0, 1, 1), 3),
    #     ((1, 0, 0), 4),
    #     ((1, 0, 1), 5),
    #     ((1, 1, 0), 5),
    #     ((1, 1, 1), 6),
    #     ((1, 1, 2), 7),
    #     ((1, 1, 3), 8),
    # ])
  ```

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
      must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices
      must be sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
      in `a` and `b`.

  Returns:
    A `SparseTensor` whose shape is the same rank as `a` and `b`, and all but
    the last dimension the same. Elements along the last dimension contain the
    unions.

  Raises:
    TypeError: If `a` has an unsupported dtype, or if `a` and `b` have
      different dtypes.
  """
  # A sparse,dense pair comes back swapped to dense,sparse. Operand order does
  # not change a union, so the flip flag is ignored.
  lhs, rhs, _ = _convert_to_tensors_or_sparse_tensors(a, b)
  return _set_operation(lhs, rhs, "union", validate_indices=validate_indices)
369