/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Ops for operating with sets. They are not checked in
// to TensorFlow because we would first like to demonstrate successful
// end-to-end use of these ops in eval and polish the API a bit, e.g. by taking
// two SparseTensors rather than one dense and one sparse.

#define EIGEN_USE_THREADS

#include <algorithm>
#include <numeric>
// TODO(ptucker): Consider switching back to hash_set - I had trouble getting it
// to work with string values.
#include <set>
#include <string>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

using ShapeArray = sparse::SparseTensor::ShapeArray;
using VarDimArray = sparse::SparseTensor::VarDimArray;

// Validate rank >= 2.
void CheckRankAtLeast2(OpKernelContext* ctx, const TensorShape& shape) {
  const auto rank = shape.dims();
  OP_REQUIRES(ctx, rank >= 2,
              errors::InvalidArgument("Invalid rank ", rank, "."));
}

// Return group shape, which is the 1st n-1 dimensions of shape.
Status GroupShape(const VarDimArray& input_shape, ShapeArray* grouped_shape) {
  if (input_shape.size() < 2) {
    // TODO(irving): Why can't 2 be 1 here?
    return errors::InvalidArgument("Shape [", absl::StrJoin(input_shape, ","),
                                   "] has rank ", input_shape.size(), " < 2");
  }
  // grouped_shape is input_shape[:-1]
  *grouped_shape = ShapeArray(input_shape.begin(), input_shape.end() - 1);
  return Status::OK();
}
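// For example, an input shape of [2, 3, 4] yields a grouped shape of [2, 3];
// the trailing (set) dimension is dropped.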

// Build `SparseTensor` from indices, values, and shape in inputs
// [base_index, base_index + 3), and validate its rank and indices.
Status SparseTensorFromContext(OpKernelContext* ctx, const int32_t base_index,
                               bool validate_indices,
                               sparse::SparseTensor* tensor) {
  // Assume row-major order.
  const TensorShape shape =
      TensorShape(ctx->input(base_index + 2).vec<int64>());
  CheckRankAtLeast2(ctx, shape);
  std::vector<int64> order(shape.dims());
  std::iota(order.begin(), order.end(), 0);

  return sparse::SparseTensor::Create(
      ctx->input(base_index), ctx->input(base_index + 1), shape, order, tensor);
}

// TODO(ptucker): CheckGroup is just a sanity check on the result of
// SparseTensor.group, consider removing.
// `sparse_tensor_shape` is the shape of the `SparseTensor` from which group
// was created, and is used to sanity check the indices in `group`.
template <typename T>
void CheckGroup(OpKernelContext* ctx, const sparse::Group& group,
                const VarDimArray& sparse_tensor_shape) {
  const auto& indices = group.indices();
  const auto& values = group.values<T>();

  // Sanity check: group is non-empty, and indices and values are same size.
  const auto num_values = values.dimension(0);
  OP_REQUIRES(ctx, indices.size() > 0, errors::Internal("Empty group."));
  OP_REQUIRES(
      ctx, indices.dimension(0) == num_values,
      errors::Internal("shape[0] of group indices ", indices.dimension(0),
                       " != values ", num_values, "."));

  // Sanity check: valid indices.
  const auto group_rank = indices.dimension(1);
  const auto expected_rank = sparse_tensor_shape.size();
  OP_REQUIRES(ctx, expected_rank == group_rank,
              errors::Internal("Rank expected ", expected_rank, ", got ",
                               group_rank, "."));
  for (int32_t j = 0; j < expected_rank; ++j) {
    const auto dim_size = sparse_tensor_shape[j];
    OP_REQUIRES(
        ctx, dim_size > 0,
        errors::Internal("Invalid dim_size[", j, "] = ", dim_size, "."));
    for (int64_t i = 0; i < num_values; ++i) {
      const auto index = indices(i, j);
      OP_REQUIRES(ctx, dim_size > index,
                  errors::Internal("indices[", i, ", ", j, "] expected < ",
                                   dim_size, ", got ", index, "."));
    }
  }
}

// This lets us calculate the row-major index into flattened output.
const ShapeArray Strides(const VarDimArray& shape) {
  ShapeArray result(shape.size());
  int64_t product = 1;
  for (int i = shape.size() - 1; i >= 0; --i) {
    result[i] = product;
    product *= shape[i];
  }
  return result;
}
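// For example, Strides([2, 3, 4]) returns [12, 4, 1], so the row-major flat
// index of element (i, j, k) is i*12 + j*4 + k*1.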

// TODO(ptucker): If memory becomes an issue, consider a 2-pass approach to
// eliminate the intermediate `values` data structure - iterate once to
// determine `num_values`, allocate output tensors, then write results directly
// to output tensors.

// TODO(ptucker): Consider sharding work across multiple threads. See
// SparseCrossOp for an example.

// Output `SparseTensor` of shape `output_shape`. `sets` contains a map of
// group indices (i.e., values for all but the last dimension of `output_shape`)
// to set values, each of which will occupy the last dimension of
// `output_shape`.
template <typename T>
void OutputSparseTensor(OpKernelContext* ctx, const TensorShape& output_shape,
                        const int64_t num_values,
                        const std::map<std::vector<int64>, std::set<T>>& sets) {
  // Allocate 3 output tensors for sparse data.
  Tensor *out_indices_t, *out_values_t, *out_shape_t;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(
                          0, TensorShape({num_values, output_shape.dims()}),
                          &out_indices_t));
  OP_REQUIRES_OK(
      ctx, ctx->allocate_output(1, TensorShape({num_values}), &out_values_t));
  OP_REQUIRES_OK(ctx, ctx->allocate_output(
                          2, TensorShape({output_shape.dims()}), &out_shape_t));
  auto out_indices_mat = out_indices_t->matrix<int64>();
  auto out_values_flat = out_values_t->vec<T>();

  // For each set, write its indices and values to output tensors.
  int64_t value_index = 0;
  for (auto it = sets.begin(); it != sets.end(); ++it) {
    const auto& group_indices = it->first;
    OP_REQUIRES(
        ctx, group_indices.size() == output_shape.dims() - 1,
        errors::Internal("Invalid number of indices ", group_indices.size(),
                         ", expected ", output_shape.dims() - 1, "."));
    const auto& set = it->second;

    // For each set item, write its indices and value to output tensors.
    int64_t group_value_index = 0;
    for (auto value = set.begin(); value != set.end();
         ++value, ++value_index, ++group_value_index) {
      // First n-1 dimensions are the group, last dimension is the position in
      // the set.
      for (int32_t i = 0; i < group_indices.size(); ++i) {
        out_indices_mat(value_index, i) = group_indices[i];
      }
      out_indices_mat(value_index, group_indices.size()) = group_value_index;

      out_values_flat(value_index) = *value;
    }
  }

  // Write output shape.
  auto out_shape_flat = out_shape_t->vec<int64>();
  for (int32_t i = 0; i < output_shape.dims(); ++i) {
    out_shape_flat(i) = output_shape.dim_size(i);
  }
}
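// Illustrative example: with output_shape [2, 3] and
// sets = {[0] -> {1, 9}, [1] -> {5}}, the outputs are
//   indices = [[0, 0], [0, 1], [1, 0]], values = [1, 9, 5], shape = [2, 3].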

bool ValidateIndicesFromContext(OpKernelConstruction* ctx) {
  bool result;
  if (ctx->GetAttr("validate_indices", &result).ok()) {
    return result;
  }
  return true;
}

// Populate `result` set from group in `input_tensor`. "Group" is defined by
// `group_indices`, which are values for the first n-1 dimensions of
// `input_tensor`. `input_strides` is provided to avoid recalculating it
// multiple times, and is used to calculate the flat index into `input_tensor`
// values.
template <typename T>
void PopulateFromDenseGroup(OpKernelContext* ctx, const Tensor& input_tensor,
                            const VarDimArray& input_strides,
                            const std::vector<int64>& group_indices,
                            std::set<T>* result) {
  OP_REQUIRES(ctx, group_indices.size() == input_strides.size() - 1,
              errors::Internal("group_indices.size ", group_indices.size(),
                               ", != input_strides.size-1 ",
                               input_strides.size() - 1, "."));
  result->clear();
  auto input_flat = input_tensor.flat<T>();
  const auto start = std::inner_product(
      group_indices.begin(), group_indices.end(), input_strides.begin(), 0LL);
  const TensorShape& input_shape = input_tensor.shape();
  const auto end = start + input_shape.dim_size(input_shape.dims() - 1);
  for (int64_t i = start; i < end; ++i) {
    result->insert(input_flat(i));
  }
}
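// Illustrative example: for a dense input of shape [2, 3] (strides [3, 1]) and
// group_indices = [1], the flat range read is [3, 6), i.e. the second row.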

// Populate `result` set from `group`. `sparse_tensor_shape` is the shape of the
// `SparseTensor` from which group was created, and is used to sanity check the
// indices in `group`.
template <typename T>
void PopulateFromSparseGroup(OpKernelContext* ctx, const sparse::Group& group,
                             const VarDimArray& sparse_tensor_shape,
                             std::set<T>* result) {
  CheckGroup<T>(ctx, group, sparse_tensor_shape);
  result->clear();
  const auto& group_values = group.values<T>();
  for (int64_t i = 0; i < group_values.size(); ++i) {
    result->insert(group_values(i));
  }
}

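// For example, a rank-2 SparseTensor with dense shape [2, 4] whose first row
// holds values {1, 1, 3} and second row holds {2} yields the dense output
// [2, 1] (the number of unique values in each row).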
template <typename T>
class SetSizeOp : public OpKernel {
 public:
  explicit SetSizeOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), validate_indices_(ValidateIndicesFromContext(ctx)) {}

  void Compute(OpKernelContext* ctx) override;

 private:
  const bool validate_indices_;
};

template <typename T>
void SetSizeOp<T>::Compute(OpKernelContext* ctx) {
  sparse::SparseTensor set_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 0, validate_indices_, &set_st));
  OP_REQUIRES_OK(ctx, set_st.IndicesValid());

  // Output shape is same as input except for last dimension, which reduces
  // to the set size of values along that dimension.
  ShapeArray output_shape;
  OP_REQUIRES_OK(ctx, GroupShape(set_st.shape(), &output_shape));
  const auto output_strides = Strides(output_shape);

  TensorShape output_shape_ts;
  OP_REQUIRES_OK(ctx,
                 TensorShapeUtils::MakeShape(output_shape, &output_shape_ts));
  Tensor* out_t;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape_ts, &out_t));
  auto out = out_t->flat<int32>();
  out.device(ctx->eigen_cpu_device()) = out.constant(static_cast<int32>(0));

  // Group by all but last dimension, create a set of group values, and add set
  // size to output.
  VarDimArray group_ix = set_st.order().subspan(0, set_st.order().size() - 1);
  std::set<T> group_set;
  for (const auto& group : set_st.group(group_ix)) {
    PopulateFromSparseGroup<T>(ctx, group, set_st.shape(), &group_set);

    const auto group_key = group.group();
    const auto output_index = std::inner_product(
        group_key.begin(), group_key.end(), output_strides.begin(), 0LL);
    out(output_index) = group_set.size();
  }
}

#define _SET_SIZE_REGISTER_KERNEL_BUILDER(T)                     \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("SetSize").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      SetSizeOp<T>);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int8);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int16);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int32);
_SET_SIZE_REGISTER_KERNEL_BUILDER(int64);
_SET_SIZE_REGISTER_KERNEL_BUILDER(uint8);
_SET_SIZE_REGISTER_KERNEL_BUILDER(uint16);
_SET_SIZE_REGISTER_KERNEL_BUILDER(tstring);
#undef _SET_SIZE_REGISTER_KERNEL_BUILDER

enum InputTypes {
  DENSE_DENSE = 0,
  DENSE_SPARSE = 1,
  SPARSE_SPARSE = 2,
};

enum SetOperation { A_MINUS_B = 0, B_MINUS_A = 1, INTERSECTION = 2, UNION = 3 };

SetOperation SetOperationFromContext(OpKernelConstruction* ctx) {
  string set_operation_str;
  if (!ctx->GetAttr("set_operation", &set_operation_str).ok()) {
    ctx->CtxFailure(errors::InvalidArgument("Missing set_operation."));
  } else {
    std::transform(set_operation_str.begin(), set_operation_str.end(),
                   set_operation_str.begin(), ::tolower);
    if ("a-b" == set_operation_str) {
      return A_MINUS_B;
    }
    if ("b-a" == set_operation_str) {
      return B_MINUS_A;
    }
    if ("intersection" == set_operation_str) {
      return INTERSECTION;
    }
    if ("union" != set_operation_str) {
      ctx->CtxFailure(errors::InvalidArgument("Invalid set_operation ",
                                              set_operation_str, "."));
    }
  }
  // NOTE: This is not a default; this function fails above if no
  // 'set_operation' attribute is provided or its value is unrecognized.
  return UNION;
}

// Abstract base class for performing set operations across the last dimension
// of 2 input tensors.
template <typename T>
class SetOperationOp : public OpKernel {
 public:
  SetOperationOp(OpKernelConstruction* ctx, InputTypes input_types)
      : OpKernel(ctx),
        set_operation_(SetOperationFromContext(ctx)),
        validate_indices_(ValidateIndicesFromContext(ctx)),
        input_types_(input_types) {}

  void Compute(OpKernelContext* ctx) override;

 private:
  void ApplySetOperation(const std::set<T>& set1, const std::set<T>& set2,
                         std::set<T>* result) const;
  void ComputeDenseToDense(OpKernelContext* ctx) const;
  void ComputeDenseToSparse(OpKernelContext* ctx) const;
  void ComputeSparseToSparse(OpKernelContext* ctx) const;
  const SetOperation set_operation_;
  const bool validate_indices_;
  const InputTypes input_types_;
};

template <typename T>
void SetOperationOp<T>::ApplySetOperation(const std::set<T>& set1,
                                          const std::set<T>& set2,
                                          std::set<T>* result) const {
  switch (set_operation_) {
    case A_MINUS_B:
      std::set_difference(set1.begin(), set1.end(), set2.begin(), set2.end(),
                          std::inserter(*result, result->begin()));
      break;
    case B_MINUS_A:
      std::set_difference(set2.begin(), set2.end(), set1.begin(), set1.end(),
                          std::inserter(*result, result->begin()));
      break;
    case INTERSECTION:
      std::set_intersection(set1.begin(), set1.end(), set2.begin(), set2.end(),
                            std::inserter(*result, result->begin()));
      break;
    case UNION:
      std::set_union(set1.begin(), set1.end(), set2.begin(), set2.end(),
                     std::inserter(*result, result->begin()));
      break;
  }
}

// Validate shapes have the same dimensions.
Status CheckShapesMatch(VarDimArray shape1, VarDimArray shape2) {
  if (shape1 != shape2) {
    return errors::InvalidArgument("Mismatched shapes [",
                                   absl::StrJoin(shape1, ","), "] vs [",
                                   absl::StrJoin(shape2, ","), "]");
  }
  return Status::OK();
}

// Validate ranks are the same, and all but last dimension are the same.
// Return GroupShape.
Status GroupShapeFromInputs(VarDimArray shape1, VarDimArray shape2,
                            ShapeArray* group_shape) {
  ShapeArray group_shape_1;
  TF_RETURN_IF_ERROR(GroupShape(shape1, &group_shape_1));
  ShapeArray group_shape_2;
  TF_RETURN_IF_ERROR(GroupShape(shape2, &group_shape_2));
  TF_RETURN_IF_ERROR(CheckShapesMatch(group_shape_1, group_shape_2));
  *group_shape = group_shape_1;
  return Status::OK();
}

// Split `flat_group_index` into separate dimensions based on `group_shape`.
void PopulateGroupIndices(const int64_t flat_group_index,
                          VarDimArray group_shape,
                          std::vector<int64>* group_indices) {
  group_indices->clear();
  int64_t running_flat_group_index = flat_group_index;
  for (int group_dim_index = group_shape.size() - 1; group_dim_index >= 0;
       --group_dim_index) {
    const auto group_dim = group_shape[group_dim_index];
    group_indices->insert(group_indices->begin(),
                          running_flat_group_index % group_dim);
    running_flat_group_index /= group_dim;
  }
}
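// For example, flat_group_index 5 with group_shape [2, 3] yields
// group_indices [1, 2] (row-major order).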

ShapeArray TensorShapeToArray(const TensorShape& t) {
  ShapeArray vec(t.dims());
  for (int i = 0; i < t.dims(); ++i) vec[i] = t.dim_size(i);
  return vec;
}

// `ctx` contains set1 and set2 dense tensors.
// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each,
// and outputting the result `SparseTensor`. A "group" is a collection of values
// with the same first n-1 dimensions in set1 and set2.
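// Illustrative example (set_operation = "intersection"):
//   set1 = [[1, 2], [3, 4]], set2 = [[1, 3], [4, 5]]
// yields a SparseTensor with indices [[0, 0], [1, 0]], values [1, 4], and
// dense shape [2, 1] (the largest per-row result has 1 element).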
template <typename T>
void SetOperationOp<T>::ComputeDenseToDense(OpKernelContext* ctx) const {
  const Tensor& set1_t = ctx->input(0);
  const Tensor& set2_t = ctx->input(1);
  // The following should stay in sync with `_dense_to_dense_shape` shape
  // assertions in python/ops/set_ops.py, and `SetShapeFn` for
  // `DenseToDenseSetOperation` in ops/set_ops.cc.
  ShapeArray group_shape;
  const auto shape1 = TensorShapeToArray(set1_t.shape());
  const auto shape2 = TensorShapeToArray(set2_t.shape());
  OP_REQUIRES_OK(ctx, GroupShapeFromInputs(shape1, shape2, &group_shape));

  const auto set1_strides = Strides(shape1);
  const auto set2_strides = Strides(shape2);

  std::map<std::vector<int64>, std::set<T>> group_sets;
  int64_t num_result_values = 0;
  int64_t max_set_size = 0;

  std::set<T> set1_group_set;
  std::set<T> set2_group_set;
  std::vector<int64> group_indices;
  int64_t num_elements;
  OP_REQUIRES_OK(ctx,
                 TensorShapeUtils::NumElements(group_shape, &num_elements));
  for (int64_t flat_group_index = 0; flat_group_index < num_elements;
       ++flat_group_index) {
    PopulateGroupIndices(flat_group_index, group_shape, &group_indices);
    PopulateFromDenseGroup<T>(ctx, set1_t, set1_strides, group_indices,
                              &set1_group_set);
    PopulateFromDenseGroup<T>(ctx, set2_t, set2_strides, group_indices,
                              &set2_group_set);

    std::set<T> group_set;
    ApplySetOperation(set1_group_set, set2_group_set, &group_set);
    if (!group_set.empty()) {
      group_sets[group_indices] = group_set;
      const auto set_size = group_set.size();
      if (set_size > max_set_size) {
        max_set_size = set_size;
      }
      num_result_values += set_size;
    }
  }

  TensorShape output_shape;
  OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape));
  output_shape.AddDim(max_set_size);
  OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets);
}

// `ctx` contains dense set1 and sparse set2 tensors.
// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each,
// and outputting the result `SparseTensor`. A "group" is a collection of values
// with the same first n-1 dimensions in set1 and set2.
template <typename T>
void SetOperationOp<T>::ComputeDenseToSparse(OpKernelContext* ctx) const {
  const Tensor& set1_t = ctx->input(0);
  sparse::SparseTensor set2_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 1, validate_indices_, &set2_st));
  OP_REQUIRES_OK(ctx, set2_st.IndicesValid());
  // The following should stay in sync with `_dense_to_sparse_shape` shape
  // assertions in python/ops/set_ops.py, and `SetShapeFn` for
  // `DenseToSparseSetOperation` in ops/set_ops.cc.
  ShapeArray group_shape;
  OP_REQUIRES_OK(ctx, GroupShapeFromInputs(TensorShapeToArray(set1_t.shape()),
                                           set2_st.shape(), &group_shape));

  const ShapeArray set1_strides = Strides(TensorShapeToArray(set1_t.shape()));

  std::map<std::vector<int64>, std::set<T>> group_sets;
  int64_t num_result_values = 0;
  int64_t max_set_size = 0;

  std::set<T> set1_group_set;
  std::set<T> set2_group_set;
  auto set2_grouper =
      set2_st.group(set2_st.order().subspan(0, set2_st.order().size() - 1));
  auto set2_group_it = set2_grouper.begin();
  std::vector<int64> group_indices;
  int64_t num_elements;
  OP_REQUIRES_OK(ctx,
                 TensorShapeUtils::NumElements(group_shape, &num_elements));
  for (int64_t flat_group_index = 0; flat_group_index < num_elements;
       ++flat_group_index) {
    PopulateGroupIndices(flat_group_index, group_shape, &group_indices);

    // Get values from set1.
    PopulateFromDenseGroup<T>(ctx, set1_t, set1_strides, group_indices,
                              &set1_group_set);

    // Get values from set2, if applicable.
    set2_group_set.clear();
    if (set2_group_it != set2_grouper.end()) {
      const auto& group = *set2_group_it;
      const auto set2_group_indices = group.group();
      OP_REQUIRES(
          ctx, set2_group_indices.size() == group_indices.size(),
          errors::InvalidArgument("Invalid number of group indices ",
                                  set2_group_indices.size(), ", expected ",
                                  group_indices.size(), "."));
      bool group_match = true;
      for (int32_t i = 0; group_match && (i < set2_group_indices.size()); ++i) {
        if (set2_group_indices[i] != group_indices[i]) {
          group_match = false;
        }
      }
      if (group_match) {
        PopulateFromSparseGroup<T>(ctx, group, set2_st.shape(),
                                   &set2_group_set);
        ++set2_group_it;
      }
    }

    std::set<T> group_set;
    ApplySetOperation(set1_group_set, set2_group_set, &group_set);
    if (!group_set.empty()) {
      group_sets[group_indices] = group_set;
      const auto set_size = group_set.size();
      if (set_size > max_set_size) {
        max_set_size = set_size;
      }
      num_result_values += set_size;
    }
  }

  TensorShape output_shape;
  OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape));
  output_shape.AddDim(max_set_size);
  OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets);
}

// This is used to determine which group iterator is less than the other, based
// on row-major ordering of indices.
// An empty index list indicates end of iteration, which is interpreted as "max"
// for the purposes of comparison; i.e., non-empty < empty.
// Return 0 if both groups are empty, or both are non-empty with the same
// indices.
// Return <0 if set1 < set2, or set2 is empty.
// Return >0 if set2 < set1, or set1 is empty.
void CompareGroups(OpKernelContext* ctx,
                   const std::vector<int64>& set1_group_indices,
                   const std::vector<int64>& set2_group_indices,
                   int64* result) {
  if (set1_group_indices.empty()) {
    *result = set2_group_indices.empty() ? 0 : 1;
    return;
  }
  if (set2_group_indices.empty()) {
    *result = set1_group_indices.empty() ? 0 : -1;
    return;
  }
  OP_REQUIRES(ctx, set1_group_indices.size() == set2_group_indices.size(),
              errors::InvalidArgument("Mismatched group dims ",
                                      set1_group_indices.size(), " vs ",
                                      set2_group_indices.size(), "."));
  for (int32_t i = 0; i < set1_group_indices.size(); ++i) {
    *result = set1_group_indices[i] - set2_group_indices[i];
    if (*result != 0) {
      return;
    }
  }
}
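// For example, comparing [0, 1] vs [0, 2] sets *result to -1, and comparing
// [1, 0] vs the empty end-of-iteration marker also sets *result to -1.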

// Empty indices vector represents iteration end in `CompareGroups`.
const std::vector<int64> GROUP_ITER_END;

// `ctx` contains set1 and set2 sparse tensors.
// Iterate over groups in set1 and set2, applying `ApplySetOperation` to each,
// and outputting the result `SparseTensor`. A "group" is a collection of values
// with the same first n-1 dimensions in set1 and set2.
template <typename T>
void SetOperationOp<T>::ComputeSparseToSparse(OpKernelContext* ctx) const {
  sparse::SparseTensor set1_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 0, validate_indices_, &set1_st));
  OP_REQUIRES_OK(ctx, set1_st.IndicesValid());

  sparse::SparseTensor set2_st;
  OP_REQUIRES_OK(ctx,
                 SparseTensorFromContext(ctx, 3, validate_indices_, &set2_st));
  OP_REQUIRES_OK(ctx, set2_st.IndicesValid());

  // The following should stay in sync with `_sparse_to_sparse_shape` shape
  // assertions in python/ops/set_ops.py, and `SetShapeFn` for
  // `SparseToSparseSetOperation` in ops/set_ops.cc.
  ShapeArray group_shape;
  OP_REQUIRES_OK(ctx, GroupShapeFromInputs(set1_st.shape(), set2_st.shape(),
                                           &group_shape));

  const ShapeArray set1_strides = Strides(set1_st.shape());
  const ShapeArray set2_strides = Strides(set2_st.shape());

  std::map<std::vector<int64>, std::set<T>> group_sets;
  int64_t num_result_values = 0;
  int64_t max_set_size = 0;

  std::set<T> set1_group_set;
  std::set<T> set2_group_set;
  auto set1_grouper =
      set1_st.group(set1_st.order().subspan(0, set1_st.order().size() - 1));
  auto set1_group_it = set1_grouper.begin();
  auto set2_grouper =
      set2_st.group(set2_st.order().subspan(0, set2_st.order().size() - 1));
  auto set2_group_it = set2_grouper.begin();

  // Group by rows, and iterate over rows of both sets in parallel, creating a
  // set for each row.
  while ((set1_group_it != set1_grouper.end()) ||
         (set2_group_it != set2_grouper.end())) {
    const std::vector<int64>& set1_group_indices =
        (set1_group_it == set1_grouper.end()) ? GROUP_ITER_END
                                              : (*set1_group_it).group();
    const std::vector<int64>& set2_group_indices =
        (set2_group_it == set2_grouper.end()) ? GROUP_ITER_END
                                              : (*set2_group_it).group();

    int64_t compare_groups;
    CompareGroups(ctx, set1_group_indices, set2_group_indices, &compare_groups);
    const std::vector<int64>* group_indices = nullptr;

    // Get values from set1, if applicable.
    set1_group_set.clear();
    if (compare_groups <= 0) {
      PopulateFromSparseGroup<T>(ctx, *set1_group_it, set1_st.shape(),
                                 &set1_group_set);
      ++set1_group_it;
      group_indices = &set1_group_indices;
    }

    // Get values from set2, if applicable.
    set2_group_set.clear();
    if (compare_groups >= 0) {
      PopulateFromSparseGroup<T>(ctx, *set2_group_it, set2_st.shape(),
                                 &set2_group_set);
      ++set2_group_it;
      group_indices = &set2_group_indices;
    }

    std::set<T> group_set;
    ApplySetOperation(set1_group_set, set2_group_set, &group_set);
    if (!group_set.empty()) {
      group_sets[*group_indices] = group_set;
      const auto set_size = group_set.size();
      if (set_size > max_set_size) {
        max_set_size = set_size;
      }
      num_result_values += set_size;
    }
  }

  TensorShape output_shape;
  OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(group_shape, &output_shape));
  output_shape.AddDim(max_set_size);
  OutputSparseTensor<T>(ctx, output_shape, num_result_values, group_sets);
}

// Given set1 of shape [b, n1] and set2 of shape [b, n2], populate the result
// sparse tensor with [b, n3] values, where each row `i` contains the result of
// the set operation on elements from set1[i] and set2[i]. `n3` is the maximum
// result row size across the batch.
template <typename T>
void SetOperationOp<T>::Compute(OpKernelContext* ctx) {
  switch (input_types_) {
    case DENSE_DENSE:
      ComputeDenseToDense(ctx);
      break;
    case DENSE_SPARSE:
      ComputeDenseToSparse(ctx);
      break;
    case SPARSE_SPARSE:
      ComputeSparseToSparse(ctx);
      break;
  }
}

template <typename T>
class DenseToDenseSetOperationOp : public SetOperationOp<T> {
 public:
  explicit DenseToDenseSetOperationOp(OpKernelConstruction* ctx)
      : SetOperationOp<T>(ctx, DENSE_DENSE) {}
};

#define _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \
  REGISTER_KERNEL_BUILDER(Name("DenseToDenseSetOperation")       \
                              .Device(DEVICE_CPU)                 \
                              .TypeConstraint<T>("T"),            \
                          DenseToDenseSetOperationOp<T>);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16);
_DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(tstring);
#undef _DENSE_TO_DENSE_SET_OPERATION_REGISTER_KERNEL_BUILDER

template <typename T>
class DenseToSparseSetOperationOp : public SetOperationOp<T> {
 public:
  explicit DenseToSparseSetOperationOp(OpKernelConstruction* ctx)
      : SetOperationOp<T>(ctx, DENSE_SPARSE) {}
};

#define _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \
  REGISTER_KERNEL_BUILDER(Name("DenseToSparseSetOperation")       \
                              .Device(DEVICE_CPU)                  \
                              .TypeConstraint<T>("T"),              \
                          DenseToSparseSetOperationOp<T>);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16);
_DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(tstring);
#undef _DENSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER

template <typename T>
class SparseToSparseSetOperationOp : public SetOperationOp<T> {
 public:
  explicit SparseToSparseSetOperationOp(OpKernelConstruction* ctx)
      : SetOperationOp<T>(ctx, SPARSE_SPARSE) {}
};

#define _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(T) \
  REGISTER_KERNEL_BUILDER(Name("SparseToSparseSetOperation")       \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<T>("T"),               \
                          SparseToSparseSetOperationOp<T>);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int8);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int16);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int32);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(int64);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint8);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(uint16);
_SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER(tstring);
#undef _SPARSE_TO_SPARSE_SET_OPERATION_REGISTER_KERNEL_BUILDER

}  // namespace tensorflow