/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Ops for operating with sets. They are not checked in
// to TensorFlow because we would first like to demonstrate successful
// end-to-end use of these ops in eval and polish the API a bit, e.g., by
// taking two SparseTensor rather than one dense and one sparse.

#define EIGEN_USE_THREADS

#include <algorithm>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/btree_set.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

using ShapeArray = sparse::SparseTensor::ShapeArray;
using VarDimArray = sparse::SparseTensor::VarDimArray;

// Validate rank >= 2.
void CheckRankAtLeast2(OpKernelContext* ctx, const TensorShape& shape) {
  const auto rank = shape.dims();
  OP_REQUIRES(ctx, rank >= 2,
              errors::InvalidArgument("Invalid rank ", rank, "."));
}

// Return group shape, which is the first n-1 dimensions of shape.
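// For example, an input shape of [2, 3, 4] yields a grouped shape of [2, 3].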
Status GroupShape(const VarDimArray& input_shape, ShapeArray* grouped_shape) {
  if (input_shape.size() < 2) {
    // TODO(irving): Why can't 2 be 1 here?
    return errors::InvalidArgument("Shape [", absl::StrJoin(input_shape, ","),
                                   "] has rank ", input_shape.size(), " < 2");
  }
  // grouped_shape is input_shape[:-1]
  *grouped_shape = ShapeArray(input_shape.begin(), input_shape.end() - 1);
  return OkStatus();
}

// Build `SparseTensor` from indices, values, and shape in inputs
// [base_index, base_index + 3), and validate its rank and indices.
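// The three inputs are expected to be, in order: indices at `base_index`,
// values at `base_index + 1`, and the dense shape at `base_index + 2`.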
Status SparseTensorFromContext(OpKernelContext* ctx, const int32_t base_index,
                               const bool validate_indices,
                               sparse::SparseTensor* tensor) {
  // Assume row-major order.
  TensorShape shape;
  const Tensor& shape_tensor = ctx->input(base_index + 2);
  if (shape_tensor.dims() != 1) {
    return errors::InvalidArgument("Shape must be a 1D tensor.");
  }
  TF_RETURN_IF_ERROR(
      TensorShape::BuildTensorShape(shape_tensor.vec<int64_t>(), &shape));
  CheckRankAtLeast2(ctx, shape);
  std::vector<int64_t> order(shape.dims());
  std::iota(order.begin(), order.end(), 0);

  Status status = sparse::SparseTensor::Create(
      ctx->input(base_index), ctx->input(base_index + 1), shape, order, tensor);

  if (!validate_indices || !status.ok()) return status;
  return tensor->IndicesValid();
}

// TODO(ptucker): CheckGroup is just a sanity check on the result of
// SparseTensor.group, consider removing.
// `sparse_tensor_shape` is the shape of the `SparseTensor` from which group
// was created, and is used to sanity check the indices in `group`.
template <typename T>
void CheckGroup(OpKernelContext* ctx, const sparse::Group& group,
                const VarDimArray& sparse_tensor_shape) {
  const auto& indices = group.indices();
  const auto& values = group.values<T>();

  // Sanity check: group is non-empty, and indices and values are same size.
  const auto num_values = values.dimension(0);
  OP_REQUIRES(ctx, indices.size() > 0, errors::Internal("Empty group."));
  OP_REQUIRES(
      ctx, indices.dimension(0) == num_values,
      errors::Internal("shape[0] of group indices ", indices.dimension(0),
                       " != values ", num_values, "."));

  // Sanity check: valid indices.
  const auto group_rank = indices.dimension(1);
  const auto expected_rank = sparse_tensor_shape.size();
  OP_REQUIRES(ctx, expected_rank == group_rank,
              errors::Internal("Rank expected ", expected_rank, ", got ",
                               group_rank, "."));
  for (int32_t j = 0; j < expected_rank; ++j) {
    const auto dim_size = sparse_tensor_shape[j];
    OP_REQUIRES(
        ctx, dim_size > 0,
        errors::Internal("Invalid dim_size[", j, "] = ", dim_size, "."));
    for (int64_t i = 0; i < num_values; ++i) {
      const auto index = indices(i, j);
      OP_REQUIRES(ctx, dim_size > index,
                  errors::Internal("indices[", i, ", ", j, "] expected < ",
                                   dim_size, ", got ", index, "."));
    }
  }
}

// This lets us calculate the row-major index into flattened output.
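// For example, a shape of [2, 3, 4] yields strides [12, 4, 1], so element
// [i, j, k] lives at flat index i * 12 + j * 4 + k.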
const ShapeArray Strides(const VarDimArray& shape) {
  ShapeArray result(shape.size());
  int64_t product = 1;
  for (int i = shape.size() - 1; i >= 0; --i) {
    result[i] = product;
    product *= shape[i];
  }
  return result;
}

// TODO(ptucker): If memory becomes an issue, consider a 2-pass approach to
// eliminate the intermediate `values` data structure - iterate once to
// determine `num_values`, allocate output tensors, then write results directly
// to output tensors.

// TODO(ptucker): Consider sharding work across multiple threads. See
// SparseCrossOp for an example.

// Output `SparseTensor` of shape `output_shape`. `sets` contains pairs of
// group indices (i.e., values for all but the last dimension of `output_shape`)
// and set values, each of which will occupy the last dimension of
// `output_shape`. `sets` should be sorted in ascending order by group indices.
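// For example, with `output_shape` [2, 3] and `sets`
// {([0], {a, b}), ([1], {c})}, the output indices are
// [[0, 0], [0, 1], [1, 0]] and the output values are [a, b, c].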
template <typename T>
void OutputSparseTensor(
    OpKernelContext* ctx, const TensorShape& output_shape,
    const int64_t num_values,
    const std::vector<std::pair<std::vector<int64_t>, absl::btree_set<T>>>&
        sets) {
  // Allocate 3 output tensors for sparse data.
  Tensor *out_indices_t, *out_values_t, *out_shape_t;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(
                          0, TensorShape({num_values, output_shape.dims()}),
                          &out_indices_t));
  OP_REQUIRES_OK(
      ctx, ctx->allocate_output(1, TensorShape({num_values}), &out_values_t));
  OP_REQUIRES_OK(ctx, ctx->allocate_output(
                          2, TensorShape({output_shape.dims()}), &out_shape_t));
  auto out_indices_mat = out_indices_t->matrix<int64_t>();
  auto out_values_flat = out_values_t->vec<T>();

  // For each set, write its indices and values to output tensors.
  int64_t value_index = 0;
  for (auto it = sets.begin(); it != sets.end(); ++it) {
    const auto& group_indices = it->first;
    OP_REQUIRES(
        ctx, group_indices.size() == output_shape.dims() - 1,
        errors::Internal("Invalid number of indices ", group_indices.size(),
                         ", expected ", output_shape.dims() - 1, "."));
    const auto& set = it->second;

    // For each set item, write its indices and value to output tensors.
    int64_t group_value_index = 0;
    for (auto value = set.begin(); value != set.end();
         ++value, ++value_index, ++group_value_index) {
      // First n-1 dimensions are the group, last dimension is the position in
      // the set.
      for (int32_t i = 0; i < group_indices.size(); ++i) {
        out_indices_mat(value_index, i) = group_indices[i];
      }
      out_indices_mat(value_index, group_indices.size()) = group_value_index;

      out_values_flat(value_index) = *value;
    }
  }

  // Write output shape.
  auto out_shape_flat = out_shape_t->vec<int64_t>();
  for (int32_t i = 0; i < output_shape.dims(); ++i) {
    out_shape_flat(i) = output_shape.dim_size(i);
  }
}

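// Read the `validate_indices` attr if present; default to true when it is
// missing or unreadable.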
bool ValidateIndicesFromContext(OpKernelConstruction* ctx) {
  bool result;
  if (ctx->GetAttr("validate_indices", &result).ok()) {
    return result;
  }
  return true;
}

// Populate `result` set from group in `input_tensor`. "Group" is defined by
// `group_indices`, which are values for the first n-1 dimensions of
// `input_tensor`. `input_strides` is provided to avoid recalculating it
// multiple times, and is used to calculate the flat index into `input_tensor`
// values.
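// For example, for an `input_tensor` of shape [2, 3, 4] (strides [12, 4, 1])
// and `group_indices` [1, 2], the group's values are `input_flat(20)` through
// `input_flat(23)`.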
template <typename T>
void PopulateFromDenseGroup(OpKernelContext* ctx, const Tensor& input_tensor,
                            const VarDimArray& input_strides,
                            const std::vector<int64_t>& group_indices,
                            absl::flat_hash_set<T>* result) {
  OP_REQUIRES(ctx, group_indices.size() == input_strides.size() - 1,
              errors::Internal("group_indices.size ", group_indices.size(),
                               ", !=  input_strides.size-1 ",
                               input_strides.size() - 1, "."));
  result->clear();
  auto input_flat = input_tensor.flat<T>();
  const auto start = std::inner_product(
      group_indices.begin(), group_indices.end(), input_strides.begin(), 0LL);
  const TensorShape& input_shape = input_tensor.shape();
  const auto end = start + input_shape.dim_size(input_shape.dims() - 1);
  for (int64_t i = start; i < end; ++i) {
    result->insert(input_flat(i));
  }
}

// Populate `result` set from `group`. `sparse_tensor_shape` is the shape of the
// `SparseTensor` from which group was created, and is used to sanity check the
// indices in `group`.
template <typename T>
void PopulateFromSparseGroup(OpKernelContext* ctx, const sparse::Group& group,
                             const VarDimArray& sparse_tensor_shape,
                             absl::flat_hash_set<T>* result) {
  CheckGroup<T>(ctx, group, sparse_tensor_shape);
  result->clear();
  const auto& group_values = group.values<T>();
  for (int64_t i = 0; i < group_values.size(); ++i) {
    result->insert(group_values(i));
  }
}

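// Op that computes, for each group defined by the first n-1 dimensions of a
// `SparseTensor` input, the number of unique values in that group. For
// example, a [2, 3, 4] input produces a dense [2, 3] int32 output of
// per-group set sizes.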
template <typename T>
class SetSizeOp : public OpKernel {
 public:
  explicit SetSizeOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), validate_indices_(ValidateIndicesFromContext(ctx)) {}

  void Compute(OpKernelContext* ctx) override;

 private:
  const bool validate_indices_;
};

template <typename T>
void SetSizeOp<T>::Compute(OpKernelContext* ctx) {
  sparse::SparseTensor set_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 0, validate_indices_, &set_st));

  // Output shape is same as input except for last dimension, which reduces
  // to the set size of values along that dimension.
  ShapeArray output_shape;
  OP_REQUIRES_OK(ctx, GroupShape(set_st.shape(), &output_shape));
  const auto output_strides = Strides(output_shape);

  TensorShape output_shape_ts;
  OP_REQUIRES_OK(ctx,
                 TensorShapeUtils::MakeShape(output_shape, &output_shape_ts));
  Tensor* out_t;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape_ts, &out_t));
  auto out = out_t->flat<int32>();
  out.device(ctx->eigen_cpu_device()) = out.constant(static_cast<int32>(0));

  // Group by all but last dimension, create a set of group values, and add set
  // size to output.
  VarDimArray group_ix = set_st.order().subspan(0, set_st.order().size() - 1);
  absl::flat_hash_set<T> group_set;
  for (const auto& group : set_st.group(group_ix)) {
    PopulateFromSparseGroup<T>(ctx, group, set_st.shape(), &group_set);

    const auto group_key = group.group();
    const auto output_index = std::inner_product(
        group_key.begin(), group_key.end(), output_strides.begin(), 0LL);
    out(output_index) = group_set.size();
  }
}

#define _SET_SIZE_REGISTER_KERNEL_BUILDER(T)                     \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("SetSize").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      SetSizeOp<T>);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int8);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int16);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int32);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int64_t);
_SET_SIZE_REGISTER_KERNEL_BUILDER(uint8);
_SET_SIZE_REGISTER_KERNEL_BUILDER(uint16);
_SET_SIZE_REGISTER_KERNEL_BUILDER(tstring);
#undef _SET_SIZE_REGISTER_KERNEL_BUILDER

enum InputTypes {
  DENSE_DENSE = 0,
  DENSE_SPARSE = 1,
  SPARSE_SPARSE = 2,
};

enum SetOperation { A_MINUS_B = 0, B_MINUS_A = 1, INTERSECTION = 2, UNION = 3 };

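// Parse the `set_operation` string attr ("a-b", "b-a", "intersection", or
// "union", case-insensitive) into a `SetOperation`; fails kernel construction
// for any other value.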
SetOperation SetOperationFromContext(OpKernelConstruction* ctx) {
  string set_operation_str;
  if (!ctx->GetAttr("set_operation", &set_operation_str).ok()) {
    ctx->CtxFailure(errors::InvalidArgument("Missing set_operation."));
  } else {
    std::transform(set_operation_str.begin(), set_operation_str.end(),
                   set_operation_str.begin(), ::tolower);
    if ("a-b" == set_operation_str) {
      return A_MINUS_B;
    }
    if ("b-a" == set_operation_str) {
      return B_MINUS_A;
    }
    if ("intersection" == set_operation_str) {
      return INTERSECTION;
    }
    if ("union" != set_operation_str) {
      ctx->CtxFailure(errors::InvalidArgument("Invalid set_operation ",
                                              set_operation_str, "."));
    }
  }
  // NOTE: This is not a default; this function fails if no 'set_operation'
  // attribute is provided.
  return UNION;
}

// Abstract base class for performing set operations across the last dimension
// of 2 input tensors.
template <typename T>
class SetOperationOp : public OpKernel {
 public:
  SetOperationOp(OpKernelConstruction* ctx, InputTypes input_types)
      : OpKernel(ctx),
        set_operation_(SetOperationFromContext(ctx)),
        validate_indices_(ValidateIndicesFromContext(ctx)),
        input_types_(input_types) {}

  void Compute(OpKernelContext* ctx) override;

 private:
  void ApplySetOperation(const absl::flat_hash_set<T>& set1,
                         const absl::flat_hash_set<T>& set2,
                         absl::btree_set<T>* result) const;
  void ComputeDenseToDense(OpKernelContext* ctx) const;
  void ComputeDenseToSparse(OpKernelContext* ctx) const;
  void ComputeSparseToSparse(OpKernelContext* ctx) const;
  const SetOperation set_operation_;
  const bool validate_indices_;
  const InputTypes input_types_;
};

template <typename T>
void SetDifference(const absl::flat_hash_set<T>& set1,
                   const absl::flat_hash_set<T>& set2,
                   absl::btree_set<T>* result) {
  for (const T& elem : set1) {
    if (!set2.contains(elem)) result->insert(elem);
  }
}

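// Intersection iterates over the smaller of the two sets, so the cost is
// O(min(|set1|, |set2|)) hash lookups plus the ordered inserts into `result`.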
template <typename T>
void SetIntersection(const absl::flat_hash_set<T>& set1,
                     const absl::flat_hash_set<T>& set2,
                     absl::btree_set<T>* result) {
  if (set1.size() <= set2.size()) {
    for (const T& elem : set1) {
      if (set2.contains(elem)) result->insert(elem);
    }
  } else {
    for (const T& elem : set2) {
      if (set1.contains(elem)) result->insert(elem);
    }
  }
}

template <typename T>
void SetUnion(const absl::flat_hash_set<T>& set1,
              const absl::flat_hash_set<T>& set2, absl::btree_set<T>* result) {
  result->insert(set1.begin(), set1.end());
  result->insert(set2.begin(), set2.end());
}

template <typename T>
void SetOperationOp<T>::ApplySetOperation(const absl::flat_hash_set<T>& set1,
                                          const absl::flat_hash_set<T>& set2,
                                          absl::btree_set<T>* result) const {
  switch (set_operation_) {
    case A_MINUS_B:
      SetDifference<T>(set1, set2, result);
      break;
    case B_MINUS_A:
      SetDifference<T>(set2, set1, result);
      break;
    case INTERSECTION:
      SetIntersection<T>(set1, set2, result);
      break;
    case UNION:
      SetUnion<T>(set1, set2, result);
      break;
  }
}

// Validate shapes have the same dimensions.
Status CheckShapesMatch(VarDimArray shape1, VarDimArray shape2) {
  if (shape1 != shape2) {
    return errors::InvalidArgument("Mismatched shapes [",
                                   absl::StrJoin(shape1, ","), "] vs [",
                                   absl::StrJoin(shape2, ","), "]");
  }
  return OkStatus();
}

// Validate ranks are the same, and all but last dimension are the same.
// Return GroupShape.
Status GroupShapeFromInputs(VarDimArray shape1, VarDimArray shape2,
                            ShapeArray* group_shape) {
  ShapeArray group_shape_1;
  TF_RETURN_IF_ERROR(GroupShape(shape1, &group_shape_1));
  ShapeArray group_shape_2;
  TF_RETURN_IF_ERROR(GroupShape(shape2, &group_shape_2));
  TF_RETURN_IF_ERROR(CheckShapesMatch(group_shape_1, group_shape_2));
  *group_shape = group_shape_1;
  return OkStatus();
}

// Split `flat_group_index` into separate dimensions based on `group_shape`.
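// For example, with `group_shape` [2, 3], a `flat_group_index` of 4 yields
// group indices [1, 1].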
void PopulateGroupIndices(const int64_t flat_group_index,
                          VarDimArray group_shape,
                          std::vector<int64_t>* group_indices) {
  group_indices->clear();
  int64_t running_flat_group_index = flat_group_index;
  for (int group_dim_index = group_shape.size() - 1; group_dim_index >= 0;
       --group_dim_index) {
    const auto group_dim = group_shape[group_dim_index];
    group_indices->insert(group_indices->begin(),
                          running_flat_group_index % group_dim);
    running_flat_group_index /= group_dim;
  }
}

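// Convert a `TensorShape` to a `ShapeArray` of dimension sizes.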
ShapeArray TensorShapeToArray(const TensorShape& t) {
  ShapeArray vec(t.dims());
  for (int i = 0; i < t.dims(); ++i) vec[i] = t.dim_size(i);
  return vec;
}

// `ctx` contains set1 and set2 dense tensors.
// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each,
// and outputting the result `SparseTensor`. A "group" is a collection of values
// with the same first n-1 dimensions in set1 and set2.
template <typename T>
void SetOperationOp<T>::ComputeDenseToDense(OpKernelContext* ctx) const {
  const Tensor& set1_t = ctx->input(0);
  const Tensor& set2_t = ctx->input(1);
  // The following should stay in sync with `_dense_to_dense_shape` shape
  // assertions in python/ops/set_ops.py, and `SetShapeFn` for
  // `DenseToDenseSetOperation` in ops/set_ops.cc.
  ShapeArray group_shape;
  const auto shape1 = TensorShapeToArray(set1_t.shape());
  const auto shape2 = TensorShapeToArray(set2_t.shape());
  OP_REQUIRES_OK(ctx, GroupShapeFromInputs(shape1, shape2, &group_shape));

  const auto set1_strides = Strides(shape1);
  const auto set2_strides = Strides(shape2);

  std::vector<std::pair<std::vector<int64_t>, absl::btree_set<T>>> group_sets;
  int64_t num_result_values = 0;
  int64_t max_set_size = 0;

  absl::flat_hash_set<T> set1_group_set;
  absl::flat_hash_set<T> set2_group_set;
  std::vector<int64_t> group_indices;
  int64_t num_elements;
  OP_REQUIRES_OK(ctx,
                 TensorShapeUtils::NumElements(group_shape, &num_elements));
  for (int64_t flat_group_index = 0; flat_group_index < num_elements;
       ++flat_group_index) {
    PopulateGroupIndices(flat_group_index, group_shape, &group_indices);
    PopulateFromDenseGroup<T>(ctx, set1_t, set1_strides, group_indices,
                              &set1_group_set);
    PopulateFromDenseGroup<T>(ctx, set2_t, set2_strides, group_indices,
                              &set2_group_set);

    absl::btree_set<T> group_set;
    ApplySetOperation(set1_group_set, set2_group_set, &group_set);
    if (!group_set.empty()) {
      const auto set_size = group_set.size();
      if (set_size > max_set_size) {
        max_set_size = set_size;
      }
      num_result_values += set_size;
      group_sets.push_back({group_indices, std::move(group_set)});
    }
  }

  TensorShape output_shape;
  OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape));
  output_shape.AddDim(max_set_size);
  OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets);
}

// `ctx` contains dense set1 and sparse set2 tensors.
// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each,
// and outputting the result `SparseTensor`. A "group" is a collection of values
// with the same first n-1 dimensions in set1 and set2.
template <typename T>
void SetOperationOp<T>::ComputeDenseToSparse(OpKernelContext* ctx) const {
  const Tensor& set1_t = ctx->input(0);
  sparse::SparseTensor set2_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 1, validate_indices_, &set2_st));
  // The following should stay in sync with `_dense_to_sparse_shape` shape
  // assertions in python/ops/set_ops.py, and `SetShapeFn` for
  // `DenseToSparseSetOperation` in ops/set_ops.cc.
  ShapeArray group_shape;
  OP_REQUIRES_OK(ctx, GroupShapeFromInputs(TensorShapeToArray(set1_t.shape()),
                                           set2_st.shape(), &group_shape));

  const ShapeArray set1_strides = Strides(TensorShapeToArray(set1_t.shape()));

  std::vector<std::pair<std::vector<int64_t>, absl::btree_set<T>>> group_sets;
  int64_t num_result_values = 0;
  int64_t max_set_size = 0;

  absl::flat_hash_set<T> set1_group_set;
  absl::flat_hash_set<T> set2_group_set;
  auto set2_grouper =
      set2_st.group(set2_st.order().subspan(0, set2_st.order().size() - 1));
  auto set2_group_it = set2_grouper.begin();
  std::vector<int64_t> group_indices;
  int64_t num_elements;
  OP_REQUIRES_OK(ctx,
                 TensorShapeUtils::NumElements(group_shape, &num_elements));
  for (int64_t flat_group_index = 0; flat_group_index < num_elements;
       ++flat_group_index) {
    PopulateGroupIndices(flat_group_index, group_shape, &group_indices);

    // Get values from set1.
    PopulateFromDenseGroup<T>(ctx, set1_t, set1_strides, group_indices,
                              &set1_group_set);

    // Get values from set2, if applicable.
    set2_group_set.clear();
    if (set2_group_it != set2_grouper.end()) {
      const auto& group = *set2_group_it;
      const auto set2_group_indices = group.group();
      OP_REQUIRES(
          ctx, set2_group_indices.size() == group_indices.size(),
          errors::InvalidArgument("Invalid number of group indices ",
                                  set2_group_indices.size(), ", expected ",
                                  group_indices.size(), "."));
      bool group_match = true;
      for (int32_t i = 0; group_match && (i < set2_group_indices.size()); ++i) {
        if (set2_group_indices[i] != group_indices[i]) {
          group_match = false;
        }
      }
      if (group_match) {
        PopulateFromSparseGroup<T>(ctx, group, set2_st.shape(),
                                   &set2_group_set);
        ++set2_group_it;
      }
    }

    absl::btree_set<T> group_set;
    ApplySetOperation(set1_group_set, set2_group_set, &group_set);
    if (!group_set.empty()) {
      const auto set_size = group_set.size();
      if (set_size > max_set_size) {
        max_set_size = set_size;
      }
      num_result_values += set_size;
      group_sets.push_back({group_indices, std::move(group_set)});
    }
  }

  TensorShape output_shape;
  OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape));
  output_shape.AddDim(max_set_size);
  OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets);
}

// This is used to determine which group iterator is less than the other, based
// on row-major ordering of indices.
// An empty index list indicates end of iteration, which is interpreted as "max"
// for the purposes of comparison; i.e., non-empty < empty.
// Return 0 if both groups are empty, or both are non-empty with equal indices.
// Return <0 if set1 < set2, or set2 is empty and set1 is not.
// Return >0 if set2 < set1, or set1 is empty and set2 is not.
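// For example, comparing [0, 1] with [0, 2] yields -1, and comparing [1, 0]
// with [0, 5] yields 1.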
void CompareGroups(OpKernelContext* ctx,
                   const std::vector<int64_t>& set1_group_indices,
                   const std::vector<int64_t>& set2_group_indices,
                   int64_t* result) {
  if (set1_group_indices.empty()) {
    *result = set2_group_indices.empty() ? 0 : 1;
    return;
  }
  if (set2_group_indices.empty()) {
    *result = set1_group_indices.empty() ? 0 : -1;
    return;
  }
  OP_REQUIRES(ctx, set1_group_indices.size() == set2_group_indices.size(),
              errors::InvalidArgument("Mismatched group dims ",
                                      set1_group_indices.size(), " vs ",
                                      set2_group_indices.size(), "."));
  for (int32_t i = 0; i < set1_group_indices.size(); ++i) {
    *result = set1_group_indices[i] - set2_group_indices[i];
    if (*result != 0) {
      return;
    }
  }
}

// `ctx` contains set1 and set2 sparse tensors.
// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each,
// and outputting the result `SparseTensor`. A "group" is a collection of values
// with the same first n-1 dimensions in set1 and set2.
template <typename T>
void SetOperationOp<T>::ComputeSparseToSparse(OpKernelContext* ctx) const {
  sparse::SparseTensor set1_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 0, validate_indices_, &set1_st));

  sparse::SparseTensor set2_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 3, validate_indices_, &set2_st));

  // The following should stay in sync with `_sparse_to_sparse_shape` shape
  // assertions in python/ops/set_ops.py, and `SetShapeFn` for
  // `SparseToSparseSetOperation` in ops/set_ops.cc.
  ShapeArray group_shape;
  OP_REQUIRES_OK(ctx, GroupShapeFromInputs(set1_st.shape(), set2_st.shape(),
                                           &group_shape));

  std::vector<std::pair<std::vector<int64_t>, absl::btree_set<T>>> group_sets;
  int64_t num_result_values = 0;
  int64_t max_set_size = 0;

  absl::flat_hash_set<T> set1_group_set;
  absl::flat_hash_set<T> set2_group_set;
  auto set1_grouper =
      set1_st.group(set1_st.order().subspan(0, set1_st.order().size() - 1));
  auto set1_group_it = set1_grouper.begin();
  auto set2_grouper =
      set2_st.group(set2_st.order().subspan(0, set2_st.order().size() - 1));
  auto set2_group_it = set2_grouper.begin();

  // Empty indices vector represents iteration end in `CompareGroups`.
  const std::vector<int64_t> group_iter_end;
  // Group by rows, and iterate over rows of both sets in parallel, creating a
  // set for each row.
  while ((set1_group_it != set1_grouper.end()) ||
         (set2_group_it != set2_grouper.end())) {
    const std::vector<int64_t>& set1_group_indices =
        (set1_group_it == set1_grouper.end()) ? group_iter_end
                                              : (*set1_group_it).group();
    const std::vector<int64_t>& set2_group_indices =
        (set2_group_it == set2_grouper.end()) ? group_iter_end
                                              : (*set2_group_it).group();

    int64_t compare_groups;
    CompareGroups(ctx, set1_group_indices, set2_group_indices, &compare_groups);
    const std::vector<int64_t>* group_indices = nullptr;

    // Get values from set1, if applicable.
    set1_group_set.clear();
    if (compare_groups <= 0) {
      PopulateFromSparseGroup<T>(ctx, *set1_group_it, set1_st.shape(),
                                 &set1_group_set);
      ++set1_group_it;
      group_indices = &set1_group_indices;
    }

    // Get values from set2, if applicable.
    set2_group_set.clear();
    if (compare_groups >= 0) {
      PopulateFromSparseGroup<T>(ctx, *set2_group_it, set2_st.shape(),
                                 &set2_group_set);
      ++set2_group_it;
      group_indices = &set2_group_indices;
    }

    absl::btree_set<T> group_set;
    ApplySetOperation(set1_group_set, set2_group_set, &group_set);
    if (!group_set.empty()) {
      const auto set_size = group_set.size();
      if (set_size > max_set_size) {
        max_set_size = set_size;
      }
      num_result_values += set_size;
      group_sets.push_back({*group_indices, std::move(group_set)});
    }
  }

  TensorShape output_shape;
  OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape));
  output_shape.AddDim(max_set_size);
  OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets);
}

// Given set1 of shape [b, n1] and set2 of shape [b, n2], populate the result
// sparse tensor with [b, n3] values, where each row `i` contains the result of
// the set operation on elements from set1[i] and set2[i]. `n3` is the number
// of elements in that result row.
template <typename T>
void SetOperationOp<T>::Compute(OpKernelContext* ctx) {
  switch (input_types_) {
    case DENSE_DENSE:
      ComputeDenseToDense(ctx);
      break;
    case DENSE_SPARSE:
      ComputeDenseToSparse(ctx);
      break;
    case SPARSE_SPARSE:
      ComputeSparseToSparse(ctx);
      break;
  }
}

template <typename T>
class DenseToDenseSetOperationOp : public SetOperationOp<T> {
 public:
  explicit DenseToDenseSetOperationOp(OpKernelConstruction* ctx)
      : SetOperationOp<T>(ctx, DENSE_DENSE) {}
};

#define _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \
  REGISTER_KERNEL_BUILDER(Name("DenseToDenseSetOperation")       \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<T>("T"),           \
                          DenseToDenseSetOperationOp<T>);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64_t);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(tstring);
#undef _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER

template <typename T>
class DenseToSparseSetOperationOp : public SetOperationOp<T> {
 public:
  explicit DenseToSparseSetOperationOp(OpKernelConstruction* ctx)
      : SetOperationOp<T>(ctx, DENSE_SPARSE) {}
};

#define _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \
  REGISTER_KERNEL_BUILDER(Name("DenseToSparseSetOperation")       \
                              .Device(DEVICE_CPU)                 \
                              .TypeConstraint<T>("T"),            \
                          DenseToSparseSetOperationOp<T>);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64_t);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(tstring);
#undef _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER

template <typename T>
class SparseToSparseSetOperationOp : public SetOperationOp<T> {
 public:
  explicit SparseToSparseSetOperationOp(OpKernelConstruction* ctx)
      : SetOperationOp<T>(ctx, SPARSE_SPARSE) {}
};

#define _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \
  REGISTER_KERNEL_BUILDER(Name("SparseToSparseSetOperation")       \
                              .Device(DEVICE_CPU)                  \
                              .TypeConstraint<T>("T"),             \
                          SparseToSparseSetOperationOp<T>);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64_t);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(tstring);
#undef _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER

}  // namespace tensorflow