/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Ops for operating on sets. They are not checked in
// to TensorFlow because we would first like to demonstrate successful
// end-to-end use of these ops in eval and polish the API a bit, e.g., taking
// two `SparseTensor`s rather than one dense and one sparse.

#define EIGEN_USE_THREADS

#include <algorithm>
#include <numeric>
// TODO(ptucker): Consider switching back to hash_set - I had trouble getting it
// to work with string values.
#include <set>
#include <string>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

using ShapeArray = sparse::SparseTensor::ShapeArray;
using VarDimArray = sparse::SparseTensor::VarDimArray;

// Validate rank >= 2.
void CheckRankAtLeast2(OpKernelContext* ctx, const TensorShape& shape) {
  const auto rank = shape.dims();
  OP_REQUIRES(ctx, rank >= 2,
              errors::InvalidArgument("Invalid rank ", rank, "."));
}

// Return group shape, which is the first n-1 dimensions of shape.
Status GroupShape(const VarDimArray& input_shape, ShapeArray* grouped_shape) {
  if (input_shape.size() < 2) {
    // TODO(irving): Why can't 2 be 1 here?
    return errors::InvalidArgument("Shape [", absl::StrJoin(input_shape, ","),
                                   "] has rank ", input_shape.size(), " < 2");
  }
  // grouped_shape is input_shape[:-1]
  *grouped_shape = ShapeArray(input_shape.begin(), input_shape.end() - 1);
  return Status::OK();
}
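
// For example, an `input_shape` of [2, 3, 5] yields a `grouped_shape` of
// [2, 3]: each group is identified by its first n-1 indices, and set values
// occupy the last dimension.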

// Build `SparseTensor` from indices, values, and shape in inputs
// [base_index, base_index + 3), and validate its rank and indices.
Status SparseTensorFromContext(OpKernelContext* ctx, const int32_t base_index,
                               bool validate_indices,
                               sparse::SparseTensor* tensor) {
  // Assume row-major order.
  const TensorShape shape =
      TensorShape(ctx->input(base_index + 2).vec<int64>());
  CheckRankAtLeast2(ctx, shape);
  std::vector<int64> order(shape.dims());
  std::iota(order.begin(), order.end(), 0);

  return sparse::SparseTensor::Create(
      ctx->input(base_index), ctx->input(base_index + 1), shape, order, tensor);
}
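
// The three inputs at [base_index, base_index + 3) are the usual sparse-tensor
// triple: an int64 indices matrix, a values vector, and an int64 shape vector.
// SetSize reads them at inputs 0-2; the set-operation kernels below pass other
// base indices (1 for the sparse operand of dense-to-sparse, 0 and 3 for
// sparse-to-sparse).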

// TODO(ptucker): CheckGroup is just a sanity check on the result of
// SparseTensor.group, consider removing.
// `sparse_tensor_shape` is the shape of the `SparseTensor` from which `group`
// was created, and is used to sanity check the indices in `group`.
template <typename T>
void CheckGroup(OpKernelContext* ctx, const sparse::Group& group,
                const VarDimArray& sparse_tensor_shape) {
  const auto& indices = group.indices();
  const auto& values = group.values<T>();

  // Sanity check: group is non-empty, and indices and values are same size.
  const auto num_values = values.dimension(0);
  OP_REQUIRES(ctx, indices.size() > 0, errors::Internal("Empty group."));
  OP_REQUIRES(
      ctx, indices.dimension(0) == num_values,
      errors::Internal("shape[0] of group indices ", indices.dimension(0),
                       " != values ", num_values, "."));

  // Sanity check: valid indices.
  const auto group_rank = indices.dimension(1);
  const auto expected_rank = sparse_tensor_shape.size();
  OP_REQUIRES(ctx, expected_rank == group_rank,
              errors::Internal("Rank expected ", expected_rank, ", got ",
                               group_rank, "."));
  for (int32_t j = 0; j < expected_rank; ++j) {
    const auto dim_size = sparse_tensor_shape[j];
    OP_REQUIRES(
        ctx, dim_size > 0,
        errors::Internal("Invalid dim_size[", j, "] = ", dim_size, "."));
    for (int64_t i = 0; i < num_values; ++i) {
      const auto index = indices(i, j);
      OP_REQUIRES(ctx, dim_size > index,
                  errors::Internal("indices[", i, ", ", j, "] expected < ",
                                   dim_size, ", got ", index, "."));
    }
  }
}

// This lets us calculate the row-major index into flattened output.
const ShapeArray Strides(const VarDimArray& shape) {
  ShapeArray result(shape.size());
  int64_t product = 1;
  for (int i = shape.size() - 1; i >= 0; --i) {
    result[i] = product;
    product *= shape[i];
  }
  return result;
}
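
// For example, Strides([2, 3, 4]) returns [12, 4, 1], so the flat row-major
// index of element (i0, i1, i2) is i0 * 12 + i1 * 4 + i2.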

// TODO(ptucker): If memory becomes an issue, consider a 2-pass approach to
// eliminate the intermediate `values` data structure - iterate once to
// determine `num_values`, allocate output tensors, then write results directly
// to output tensors.

// TODO(ptucker): Consider sharding work across multiple threads. See
// SparseCrossOp for an example.

// Output `SparseTensor` of shape `output_shape`. `sets` contains a map of
// group indices (i.e., values for all but the last dimension of `output_shape`)
// to set values, each of which will occupy the last dimension of
// `output_shape`.
template <typename T>
void OutputSparseTensor(OpKernelContext* ctx, const TensorShape& output_shape,
                        const int64_t num_values,
                        const std::map<std::vector<int64>, std::set<T>>& sets) {
  // Allocate 3 output tensors for sparse data.
  Tensor *out_indices_t, *out_values_t, *out_shape_t;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(
                          0, TensorShape({num_values, output_shape.dims()}),
                          &out_indices_t));
  OP_REQUIRES_OK(
      ctx, ctx->allocate_output(1, TensorShape({num_values}), &out_values_t));
  OP_REQUIRES_OK(ctx, ctx->allocate_output(
                          2, TensorShape({output_shape.dims()}), &out_shape_t));
  auto out_indices_mat = out_indices_t->matrix<int64>();
  auto out_values_flat = out_values_t->vec<T>();

  // For each set, write its indices and values to output tensors.
  int64_t value_index = 0;
  for (auto it = sets.begin(); it != sets.end(); ++it) {
    const auto& group_indices = it->first;
    OP_REQUIRES(
        ctx, group_indices.size() == output_shape.dims() - 1,
        errors::Internal("Invalid number of indices ", group_indices.size(),
                         ", expected ", output_shape.dims() - 1, "."));
    const auto& set = it->second;

    // For each set item, write its indices and value to output tensors.
    int64_t group_value_index = 0;
    for (auto value = set.begin(); value != set.end();
         ++value, ++value_index, ++group_value_index) {
      // First n-1 dimensions are the group, last dimension is the position in
      // the set.
      for (int32_t i = 0; i < group_indices.size(); ++i) {
        out_indices_mat(value_index, i) = group_indices[i];
      }
      out_indices_mat(value_index, group_indices.size()) = group_value_index;

      out_values_flat(value_index) = *value;
    }
  }

  // Write output shape.
  auto out_shape_flat = out_shape_t->vec<int64>();
  for (int32_t i = 0; i < output_shape.dims(); ++i) {
    out_shape_flat(i) = output_shape.dim_size(i);
  }
}
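
// For example, with output_shape [2, 2] and sets {[0] -> {1, 3}, [1] -> {5}},
// the outputs are indices [[0, 0], [0, 1], [1, 0]], values [1, 3, 5], and
// shape [2, 2].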

bool ValidateIndicesFromContext(OpKernelConstruction* ctx) {
  bool result;
  if (ctx->GetAttr("validate_indices", &result).ok()) {
    return result;
  }
  return true;
}

// Populate `result` set from group in `input_tensor`. "Group" is defined by
// `group_indices`, which are values for the first n-1 dimensions of
// `input_tensor`. `input_strides` is provided to avoid recalculating it
// multiple times, and is used to calculate the flat index into `input_tensor`
// values.
template <typename T>
void PopulateFromDenseGroup(OpKernelContext* ctx, const Tensor& input_tensor,
                            const VarDimArray& input_strides,
                            const std::vector<int64>& group_indices,
                            std::set<T>* result) {
  OP_REQUIRES(ctx, group_indices.size() == input_strides.size() - 1,
              errors::Internal("group_indices.size ", group_indices.size(),
                               ", !=  input_strides.size-1 ",
                               input_strides.size() - 1, "."));
  result->clear();
  auto input_flat = input_tensor.flat<T>();
  const auto start = std::inner_product(
      group_indices.begin(), group_indices.end(), input_strides.begin(), 0LL);
  const TensorShape& input_shape = input_tensor.shape();
  const auto end = start + input_shape.dim_size(input_shape.dims() - 1);
  for (int64_t i = start; i < end; ++i) {
    result->insert(input_flat(i));
  }
}
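
// For example, for a dense input of shape [2, 3, 4] (strides [12, 4, 1]) and
// group_indices [1, 2], `start` is 1 * 12 + 2 * 4 = 20 and `end` is 24, so the
// four values of row [1, 2, :] are inserted into `result`.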

// Populate `result` set from `group`. `sparse_tensor_shape` is the shape of the
// `SparseTensor` from which `group` was created, and is used to sanity check
// the indices in `group`.
template <typename T>
void PopulateFromSparseGroup(OpKernelContext* ctx, const sparse::Group& group,
                             const VarDimArray& sparse_tensor_shape,
                             std::set<T>* result) {
  CheckGroup<T>(ctx, group, sparse_tensor_shape);
  result->clear();
  const auto& group_values = group.values<T>();
  for (int64_t i = 0; i < group_values.size(); ++i) {
    result->insert(group_values(i));
  }
}

template <typename T>
class SetSizeOp : public OpKernel {
 public:
  explicit SetSizeOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), validate_indices_(ValidateIndicesFromContext(ctx)) {}

  void Compute(OpKernelContext* ctx) override;

 private:
  const bool validate_indices_;
};

template <typename T>
void SetSizeOp<T>::Compute(OpKernelContext* ctx) {
  sparse::SparseTensor set_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 0, validate_indices_, &set_st));
  OP_REQUIRES_OK(ctx, set_st.IndicesValid());

  // Output shape is same as input except for last dimension, which reduces
  // to the set size of values along that dimension.
  ShapeArray output_shape;
  OP_REQUIRES_OK(ctx, GroupShape(set_st.shape(), &output_shape));
  const auto output_strides = Strides(output_shape);

  TensorShape output_shape_ts;
  OP_REQUIRES_OK(ctx,
                 TensorShapeUtils::MakeShape(output_shape, &output_shape_ts));
  Tensor* out_t;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape_ts, &out_t));
  auto out = out_t->flat<int32>();
  out.device(ctx->eigen_cpu_device()) = out.constant(static_cast<int32>(0));

  // Group by all but last dimension, create a set of group values, and add set
  // size to output.
  VarDimArray group_ix = set_st.order().subspan(0, set_st.order().size() - 1);
  std::set<T> group_set;
  for (const auto& group : set_st.group(group_ix)) {
    PopulateFromSparseGroup<T>(ctx, group, set_st.shape(), &group_set);

    const auto group_key = group.group();
    const auto output_index = std::inner_product(
        group_key.begin(), group_key.end(), output_strides.begin(), 0LL);
    out(output_index) = group_set.size();
  }
}
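
// For example, for a SparseTensor of dense shape [2, 3] whose first row holds
// values {1, 1, 2} and whose second row holds {3}, the output is a [2] tensor
// containing [2, 1]: the number of distinct values along the last dimension of
// each row.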

#define _SET_SIZE_REGISTER_KERNEL_BUILDER(T)                     \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("SetSize").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      SetSizeOp<T>);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int8);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int16);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int32);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int64);
_SET_SIZE_REGISTER_KERNEL_BUILDER(uint8);
_SET_SIZE_REGISTER_KERNEL_BUILDER(uint16);
_SET_SIZE_REGISTER_KERNEL_BUILDER(tstring);
#undef _SET_SIZE_REGISTER_KERNEL_BUILDER

enum InputTypes {
  DENSE_DENSE = 0,
  DENSE_SPARSE = 1,
  SPARSE_SPARSE = 2,
};

enum SetOperation { A_MINUS_B = 0, B_MINUS_A = 1, INTERSECTION = 2, UNION = 3 };

SetOperation SetOperationFromContext(OpKernelConstruction* ctx) {
  string set_operation_str;
  if (!ctx->GetAttr("set_operation", &set_operation_str).ok()) {
    ctx->CtxFailure(errors::InvalidArgument("Missing set_operation."));
  } else {
    std::transform(set_operation_str.begin(), set_operation_str.end(),
                   set_operation_str.begin(), ::tolower);
    if ("a-b" == set_operation_str) {
      return A_MINUS_B;
    }
    if ("b-a" == set_operation_str) {
      return B_MINUS_A;
    }
    if ("intersection" == set_operation_str) {
      return INTERSECTION;
    }
    if ("union" != set_operation_str) {
      ctx->CtxFailure(errors::InvalidArgument("Invalid set_operation ",
                                              set_operation_str, "."));
    }
  }
  // NOTE: UNION is not a silent default; this function has already failed via
  // CtxFailure above if no valid 'set_operation' attribute was provided.
  return UNION;
}

// Abstract base class for performing set operations across the last dimension
// of 2 input tensors.
template <typename T>
class SetOperationOp : public OpKernel {
 public:
  SetOperationOp(OpKernelConstruction* ctx, InputTypes input_types)
      : OpKernel(ctx),
        set_operation_(SetOperationFromContext(ctx)),
        validate_indices_(ValidateIndicesFromContext(ctx)),
        input_types_(input_types) {}

  void Compute(OpKernelContext* ctx) override;

 private:
  void ApplySetOperation(const std::set<T>& set1, const std::set<T>& set2,
                         std::set<T>* result) const;
  void ComputeDenseToDense(OpKernelContext* ctx) const;
  void ComputeDenseToSparse(OpKernelContext* ctx) const;
  void ComputeSparseToSparse(OpKernelContext* ctx) const;
  const SetOperation set_operation_;
  const bool validate_indices_;
  const InputTypes input_types_;
};

template <typename T>
void SetOperationOp<T>::ApplySetOperation(const std::set<T>& set1,
                                          const std::set<T>& set2,
                                          std::set<T>* result) const {
  switch (set_operation_) {
    case A_MINUS_B:
      std::set_difference(set1.begin(), set1.end(), set2.begin(), set2.end(),
                          std::inserter(*result, result->begin()));
      break;
    case B_MINUS_A:
      std::set_difference(set2.begin(), set2.end(), set1.begin(), set1.end(),
                          std::inserter(*result, result->begin()));
      break;
    case INTERSECTION:
      std::set_intersection(set1.begin(), set1.end(), set2.begin(), set2.end(),
                            std::inserter(*result, result->begin()));
      break;
    case UNION:
      std::set_union(set1.begin(), set1.end(), set2.begin(), set2.end(),
                     std::inserter(*result, result->begin()));
      break;
  }
}
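
// For example, with set1 = {1, 2, 3} and set2 = {2, 4}: "a-b" yields {1, 3},
// "b-a" yields {4}, "intersection" yields {2}, and "union" yields {1, 2, 3, 4}.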

// Validate shapes have the same dimensions.
Status CheckShapesMatch(VarDimArray shape1, VarDimArray shape2) {
  if (shape1 != shape2) {
    return errors::InvalidArgument("Mismatched shapes [",
                                   absl::StrJoin(shape1, ","), "] vs [",
                                   absl::StrJoin(shape2, ","), "]");
  }
  return Status::OK();
}

// Validate ranks are the same, and all but last dimension are the same.
// Return GroupShape.
Status GroupShapeFromInputs(VarDimArray shape1, VarDimArray shape2,
                            ShapeArray* group_shape) {
  ShapeArray group_shape_1;
  TF_RETURN_IF_ERROR(GroupShape(shape1, &group_shape_1));
  ShapeArray group_shape_2;
  TF_RETURN_IF_ERROR(GroupShape(shape2, &group_shape_2));
  TF_RETURN_IF_ERROR(CheckShapesMatch(group_shape_1, group_shape_2));
  *group_shape = group_shape_1;
  return Status::OK();
}

// Split `flat_group_index` into separate dimensions based on `group_shape`.
void PopulateGroupIndices(const int64_t flat_group_index,
                          VarDimArray group_shape,
                          std::vector<int64>* group_indices) {
  group_indices->clear();
  int64_t running_flat_group_index = flat_group_index;
  for (int group_dim_index = group_shape.size() - 1; group_dim_index >= 0;
       --group_dim_index) {
    const auto group_dim = group_shape[group_dim_index];
    group_indices->insert(group_indices->begin(),
                          running_flat_group_index % group_dim);
    running_flat_group_index /= group_dim;
  }
}
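
// For example, with group_shape [2, 3], flat_group_index 5 maps to
// group_indices [1, 2] (row-major order).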

ShapeArray TensorShapeToArray(const TensorShape& t) {
  ShapeArray vec(t.dims());
  for (int i = 0; i < t.dims(); ++i) vec[i] = t.dim_size(i);
  return vec;
}

// `ctx` contains set1 and set2 dense tensors.
// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each,
// and outputting the result `SparseTensor`. A "group" is a collection of
// values with the same first n-1 dimensions in set1 and set2.
template <typename T>
void SetOperationOp<T>::ComputeDenseToDense(OpKernelContext* ctx) const {
  const Tensor& set1_t = ctx->input(0);
  const Tensor& set2_t = ctx->input(1);
  // The following should stay in sync with `_dense_to_dense_shape` shape
  // assertions in python/ops/set_ops.py, and `SetShapeFn` for
  // `DenseToDenseSetOperation` in ops/set_ops.cc.
  ShapeArray group_shape;
  const auto shape1 = TensorShapeToArray(set1_t.shape());
  const auto shape2 = TensorShapeToArray(set2_t.shape());
  OP_REQUIRES_OK(ctx, GroupShapeFromInputs(shape1, shape2, &group_shape));

  const auto set1_strides = Strides(shape1);
  const auto set2_strides = Strides(shape2);

  std::map<std::vector<int64>, std::set<T>> group_sets;
  int64_t num_result_values = 0;
  int64_t max_set_size = 0;

  std::set<T> set1_group_set;
  std::set<T> set2_group_set;
  std::vector<int64> group_indices;
  int64_t num_elements;
  OP_REQUIRES_OK(ctx,
                 TensorShapeUtils::NumElements(group_shape, &num_elements));
  for (int64_t flat_group_index = 0; flat_group_index < num_elements;
       ++flat_group_index) {
    PopulateGroupIndices(flat_group_index, group_shape, &group_indices);
    PopulateFromDenseGroup<T>(ctx, set1_t, set1_strides, group_indices,
                              &set1_group_set);
    PopulateFromDenseGroup<T>(ctx, set2_t, set2_strides, group_indices,
                              &set2_group_set);

    std::set<T> group_set;
    ApplySetOperation(set1_group_set, set2_group_set, &group_set);
    if (!group_set.empty()) {
      group_sets[group_indices] = group_set;
      const auto set_size = group_set.size();
      if (set_size > max_set_size) {
        max_set_size = set_size;
      }
      num_result_values += set_size;
    }
  }

  TensorShape output_shape;
  OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape));
  output_shape.AddDim(max_set_size);
  OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets);
}
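
// For example, an "intersection" of set1 = [[1, 2], [3, 4]] and
// set2 = [[1, 3], [4, 5]] yields row sets {1} and {4}, so the output
// `SparseTensor` has indices [[0, 0], [1, 0]], values [1, 4], and dense shape
// [2, 1] (max_set_size is 1).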

// `ctx` contains dense set1 and sparse set2 tensors.
// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each,
// and outputting the result `SparseTensor`. A "group" is a collection of
// values with the same first n-1 dimensions in set1 and set2.
template <typename T>
void SetOperationOp<T>::ComputeDenseToSparse(OpKernelContext* ctx) const {
  const Tensor& set1_t = ctx->input(0);
  sparse::SparseTensor set2_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 1, validate_indices_, &set2_st));
  OP_REQUIRES_OK(ctx, set2_st.IndicesValid());
  // The following should stay in sync with `_dense_to_sparse_shape` shape
  // assertions in python/ops/set_ops.py, and `SetShapeFn` for
  // `DenseToSparseSetOperation` in ops/set_ops.cc.
  ShapeArray group_shape;
  OP_REQUIRES_OK(ctx, GroupShapeFromInputs(TensorShapeToArray(set1_t.shape()),
                                           set2_st.shape(), &group_shape));

  const ShapeArray set1_strides = Strides(TensorShapeToArray(set1_t.shape()));

  std::map<std::vector<int64>, std::set<T>> group_sets;
  int64_t num_result_values = 0;
  int64_t max_set_size = 0;

  std::set<T> set1_group_set;
  std::set<T> set2_group_set;
  auto set2_grouper =
      set2_st.group(set2_st.order().subspan(0, set2_st.order().size() - 1));
  auto set2_group_it = set2_grouper.begin();
  std::vector<int64> group_indices;
  int64_t num_elements;
  OP_REQUIRES_OK(ctx,
                 TensorShapeUtils::NumElements(group_shape, &num_elements));
  for (int64_t flat_group_index = 0; flat_group_index < num_elements;
       ++flat_group_index) {
    PopulateGroupIndices(flat_group_index, group_shape, &group_indices);

    // Get values from set1.
    PopulateFromDenseGroup<T>(ctx, set1_t, set1_strides, group_indices,
                              &set1_group_set);

    // Get values from set2, if applicable.
    set2_group_set.clear();
    if (set2_group_it != set2_grouper.end()) {
      const auto& group = *set2_group_it;
      const auto set2_group_indices = group.group();
      OP_REQUIRES(
          ctx, set2_group_indices.size() == group_indices.size(),
          errors::InvalidArgument("Invalid number of group indices ",
                                  set2_group_indices.size(), ", expected ",
                                  group_indices.size(), "."));
      bool group_match = true;
      for (int32_t i = 0; group_match && (i < set2_group_indices.size()); ++i) {
        if (set2_group_indices[i] != group_indices[i]) {
          group_match = false;
        }
      }
      if (group_match) {
        PopulateFromSparseGroup<T>(ctx, group, set2_st.shape(),
                                   &set2_group_set);
        ++set2_group_it;
      }
    }

    std::set<T> group_set;
    ApplySetOperation(set1_group_set, set2_group_set, &group_set);
    if (!group_set.empty()) {
      group_sets[group_indices] = group_set;
      const auto set_size = group_set.size();
      if (set_size > max_set_size) {
        max_set_size = set_size;
      }
      num_result_values += set_size;
    }
  }

  TensorShape output_shape;
  OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape));
  output_shape.AddDim(max_set_size);
  OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets);
}

// This is used to determine which group iterator is less than the other, based
// on row-major ordering of indices.
// An empty index list indicates end of iteration, which is interpreted as "max"
// for the purposes of comparison; i.e., non-empty < empty.
// Return 0 if both groups are empty, or both non-empty with the same values.
// Return <0 if set1 < set2, or set2 is empty.
// Return >0 if set2 < set1, or set1 is empty.
void CompareGroups(OpKernelContext* ctx,
                   const std::vector<int64>& set1_group_indices,
                   const std::vector<int64>& set2_group_indices,
                   int64* result) {
  if (set1_group_indices.empty()) {
    *result = set2_group_indices.empty() ? 0 : 1;
    return;
  }
  if (set2_group_indices.empty()) {
    *result = set1_group_indices.empty() ? 0 : -1;
    return;
  }
  OP_REQUIRES(ctx, set1_group_indices.size() == set2_group_indices.size(),
              errors::InvalidArgument("Mismatched group dims ",
                                      set1_group_indices.size(), " vs ",
                                      set2_group_indices.size(), "."));
  for (int32_t i = 0; i < set1_group_indices.size(); ++i) {
    *result = set1_group_indices[i] - set2_group_indices[i];
    if (*result != 0) {
      return;
    }
  }
}
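
// For example, comparing [0, 1] with [0, 2] yields -1 (set1 comes first), and
// comparing a non-empty index list with an empty one yields -1 (non-empty
// groups sort before exhausted iterators).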

// Empty indices vector represents iteration end in `CompareGroups`.
const std::vector<int64> GROUP_ITER_END;

// `ctx` contains set1 and set2 sparse tensors.
// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each,
// and outputting the result `SparseTensor`. A "group" is a collection of
// values with the same first n-1 dimensions in set1 and set2.
template <typename T>
void SetOperationOp<T>::ComputeSparseToSparse(OpKernelContext* ctx) const {
  sparse::SparseTensor set1_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 0, validate_indices_, &set1_st));
  OP_REQUIRES_OK(ctx, set1_st.IndicesValid());

  sparse::SparseTensor set2_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 3, validate_indices_, &set2_st));
  // Validate set2 indices as well, mirroring the set1 check above and the
  // dense-to-sparse path.
  OP_REQUIRES_OK(ctx, set2_st.IndicesValid());

  // The following should stay in sync with `_sparse_to_sparse_shape` shape
  // assertions in python/ops/set_ops.py, and `SetShapeFn` for
  // `SparseToSparseSetOperation` in ops/set_ops.cc.
  ShapeArray group_shape;
  OP_REQUIRES_OK(ctx, GroupShapeFromInputs(set1_st.shape(), set2_st.shape(),
                                           &group_shape));

  const ShapeArray set1_strides = Strides(set1_st.shape());
  const ShapeArray set2_strides = Strides(set2_st.shape());

  std::map<std::vector<int64>, std::set<T>> group_sets;
  int64_t num_result_values = 0;
  int64_t max_set_size = 0;

  std::set<T> set1_group_set;
  std::set<T> set2_group_set;
  auto set1_grouper =
      set1_st.group(set1_st.order().subspan(0, set1_st.order().size() - 1));
  auto set1_group_it = set1_grouper.begin();
  auto set2_grouper =
      set2_st.group(set2_st.order().subspan(0, set2_st.order().size() - 1));
  auto set2_group_it = set2_grouper.begin();

  // Group by rows, and iterate over rows of both sets in parallel, creating a
  // set for each row.
  while ((set1_group_it != set1_grouper.end()) ||
         (set2_group_it != set2_grouper.end())) {
    const std::vector<int64>& set1_group_indices =
        (set1_group_it == set1_grouper.end()) ? GROUP_ITER_END
                                              : (*set1_group_it).group();
    const std::vector<int64>& set2_group_indices =
        (set2_group_it == set2_grouper.end()) ? GROUP_ITER_END
                                              : (*set2_group_it).group();

    int64_t compare_groups;
    CompareGroups(ctx, set1_group_indices, set2_group_indices, &compare_groups);
    const std::vector<int64>* group_indices = nullptr;

    // Get values from set1, if applicable.
    set1_group_set.clear();
    if (compare_groups <= 0) {
      PopulateFromSparseGroup<T>(ctx, *set1_group_it, set1_st.shape(),
                                 &set1_group_set);
      ++set1_group_it;
      group_indices = &set1_group_indices;
    }

    // Get values from set2, if applicable.
    set2_group_set.clear();
    if (compare_groups >= 0) {
      PopulateFromSparseGroup<T>(ctx, *set2_group_it, set2_st.shape(),
                                 &set2_group_set);
      ++set2_group_it;
      group_indices = &set2_group_indices;
    }

    std::set<T> group_set;
    ApplySetOperation(set1_group_set, set2_group_set, &group_set);
    if (!group_set.empty()) {
      group_sets[*group_indices] = group_set;
      const auto set_size = group_set.size();
      if (set_size > max_set_size) {
        max_set_size = set_size;
      }
      num_result_values += set_size;
    }
  }

  TensorShape output_shape;
  OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape));
  output_shape.AddDim(max_set_size);
  OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets);
}

// Given set1 of shape [b, n1] and set2 of shape [b, n2], populate the result
// sparse tensor with [b, n3] values, where each row `i` contains the result of
// the set operation on elements from set1[i] and set2[i]. `n3` is the maximum
// number of elements in any result row.
template <typename T>
void SetOperationOp<T>::Compute(OpKernelContext* ctx) {
  switch (input_types_) {
    case DENSE_DENSE:
      ComputeDenseToDense(ctx);
      break;
    case DENSE_SPARSE:
      ComputeDenseToSparse(ctx);
      break;
    case SPARSE_SPARSE:
      ComputeSparseToSparse(ctx);
      break;
  }
}

template <typename T>
class DenseToDenseSetOperationOp : public SetOperationOp<T> {
 public:
  explicit DenseToDenseSetOperationOp(OpKernelConstruction* ctx)
      : SetOperationOp<T>(ctx, DENSE_DENSE) {}
};

#define _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \
  REGISTER_KERNEL_BUILDER(Name("DenseToDenseSetOperation")       \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<T>("T"),           \
                          DenseToDenseSetOperationOp<T>);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(tstring);
#undef _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER

template <typename T>
class DenseToSparseSetOperationOp : public SetOperationOp<T> {
 public:
  explicit DenseToSparseSetOperationOp(OpKernelConstruction* ctx)
      : SetOperationOp<T>(ctx, DENSE_SPARSE) {}
};

#define _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \
  REGISTER_KERNEL_BUILDER(Name("DenseToSparseSetOperation")       \
                              .Device(DEVICE_CPU)                 \
                              .TypeConstraint<T>("T"),            \
                          DenseToSparseSetOperationOp<T>);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(tstring);
#undef _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER

template <typename T>
class SparseToSparseSetOperationOp : public SetOperationOp<T> {
 public:
  explicit SparseToSparseSetOperationOp(OpKernelConstruction* ctx)
      : SetOperationOp<T>(ctx, SPARSE_SPARSE) {}
};

#define _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \
  REGISTER_KERNEL_BUILDER(Name("SparseToSparseSetOperation")       \
                              .Device(DEVICE_CPU)                  \
                              .TypeConstraint<T>("T"),             \
                          SparseToSparseSetOperationOp<T>);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(tstring);
#undef _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER

}  // namespace tensorflow