• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 #include "tensorflow/core/platform/errors.h"
20 
21 namespace tensorflow {
22 
23 using shape_inference::DimensionHandle;
24 using shape_inference::InferenceContext;
25 using shape_inference::ShapeHandle;
26 
27 namespace {
28 
SparseSparseMinOrMaxShapeFn(InferenceContext * c)29 Status SparseSparseMinOrMaxShapeFn(InferenceContext* c) {
30   ShapeHandle unused;
31   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));  // a_indices
32   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));  // a_values
33   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));  // a_shape
34   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &unused));  // b_indices
35   TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 1, &unused));  // b_values
36   TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &unused));  // b_shape
37   c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
38                              InferenceContext::kUnknownDim));
39   c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
40   return Status::OK();
41 }
42 
43 }  // namespace
44 
45 REGISTER_OP("SparseAddGrad")
46     .Input("backprop_val_grad: T")
47     .Input("a_indices: int64")
48     .Input("b_indices: int64")
49     .Input("sum_indices: int64")
50     .Output("a_val_grad: T")
51     .Output("b_val_grad: T")
52     .Attr("T: numbertype")
__anon9213d3ed0202(InferenceContext* c) 53     .SetShapeFn([](InferenceContext* c) {
54       ShapeHandle a_indices;
55       ShapeHandle b_indices;
56       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &a_indices));
57       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &b_indices));
58       c->set_output(0, c->Vector(c->Dim(a_indices, 0)));
59       c->set_output(1, c->Vector(c->Dim(b_indices, 0)));
60       return Status::OK();
61     });
62 
63 REGISTER_OP("SparseAdd")
64     .Input("a_indices: int64")
65     .Input("a_values: T")
66     .Input("a_shape: int64")
67     .Input("b_indices: int64")
68     .Input("b_values: T")
69     .Input("b_shape: int64")
70     .Input("thresh: Treal")
71     .Output("sum_indices: int64")
72     .Output("sum_values: T")
73     .Output("sum_shape: int64")
74     .Attr("T: numbertype")
75     .Attr("Treal: realnumbertype")
__anon9213d3ed0302(InferenceContext* c) 76     .SetShapeFn([](InferenceContext* c) {
77       ShapeHandle a_shape;
78       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &a_shape));
79       c->set_output(
80           0, c->Matrix(InferenceContext::kUnknownDim, c->Dim(a_shape, 0)));
81       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
82       c->set_output(2, a_shape);
83       return Status::OK();
84     });
85 
86 REGISTER_OP("SparseTensorDenseMatMul")
87     .Input("a_indices: Tindices")
88     .Input("a_values: T")
89     .Input("a_shape: int64")
90     .Input("b: T")
91     .Output("product: T")
92     .Attr("T: type")
93     .Attr("Tindices: {int32,int64} = DT_INT64")
94     .Attr("adjoint_a: bool = false")
95     .Attr("adjoint_b: bool = false")
__anon9213d3ed0402(InferenceContext* c) 96     .SetShapeFn([](InferenceContext* c) {
97       DimensionHandle unused_dim;
98       ShapeHandle unused;
99       ShapeHandle b;
100       ShapeHandle a_shape;
101       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));  // a_indices
102       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));  // a_values
103       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &a_shape));
104       TF_RETURN_IF_ERROR(c->WithRank(a_shape, 2, &a_shape));
105       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &b));
106 
107       bool adjoint_a;
108       bool adjoint_b;
109       TF_RETURN_IF_ERROR(c->GetAttr("adjoint_a", &adjoint_a));
110       TF_RETURN_IF_ERROR(c->GetAttr("adjoint_b", &adjoint_b));
111 
112       DimensionHandle output_right = c->Dim(b, adjoint_b ? 0 : 1);
113       DimensionHandle output_left = c->Dim(a_shape, adjoint_a ? 1 : 0);
114       DimensionHandle inner_left = c->Dim(a_shape, adjoint_a ? 0 : 1);
115       DimensionHandle inner_right = c->Dim(b, adjoint_b ? 1 : 0);
116       TF_RETURN_IF_ERROR(c->Merge(inner_left, inner_right, &unused_dim));
117       c->set_output(0, c->Matrix(output_left, output_right));
118       return Status::OK();
119     });
120 
121 REGISTER_OP("SerializeSparse")
122     .Input("sparse_indices: int64")
123     .Input("sparse_values: T")
124     .Input("sparse_shape: int64")
125     .Attr("T: type")
126     .Output("serialized_sparse: out_type")
127     .Attr("out_type: {string, variant} = DT_STRING")
__anon9213d3ed0502(InferenceContext* c) 128     .SetShapeFn([](InferenceContext* c) {
129       ShapeHandle unused;
130       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
131       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
132       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
133       c->set_output(0, c->Vector(3));
134       return Status::OK();
135     });
136 
137 REGISTER_OP("SerializeManySparse")
138     .Input("sparse_indices: int64")
139     .Input("sparse_values: T")
140     .Input("sparse_shape: int64")
141     .Attr("T: type")
142     .Output("serialized_sparse: out_type")
143     .Attr("out_type: {string, variant} = DT_STRING")
__anon9213d3ed0602(InferenceContext* c) 144     .SetShapeFn([](InferenceContext* c) {
145       ShapeHandle unused;
146       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
147       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
148       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
149       c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 3));
150       return Status::OK();
151     });
152 
153 REGISTER_OP("DeserializeSparse")
154     .Input("serialized_sparse: Tserialized")
155     .Output("sparse_indices: int64")
156     .Output("sparse_values: dtype")
157     .Output("sparse_shape: int64")
158     .Attr("dtype: type")
159     .Attr("Tserialized: {string, variant} = DT_STRING")
__anon9213d3ed0702(InferenceContext* c) 160     .SetShapeFn([](InferenceContext* c) {
161       // serialized sparse is [?, ..., ?, 3] vector.
162       DimensionHandle unused;
163       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), -1), 3, &unused));
164       c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
165                                  InferenceContext::kUnknownDim));
166       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
167       c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
168       return Status::OK();
169     });
170 
171 REGISTER_OP("DeserializeManySparse")
172     .Input("serialized_sparse: string")
173     .Output("sparse_indices: int64")
174     .Output("sparse_values: dtype")
175     .Output("sparse_shape: int64")
176     .Attr("dtype: type")
__anon9213d3ed0802(InferenceContext* c) 177     .SetShapeFn([](InferenceContext* c) {
178       // serialized sparse is [?,3] matrix.
179       ShapeHandle serialized_sparse;
180       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &serialized_sparse));
181       DimensionHandle unused;
182       TF_RETURN_IF_ERROR(
183           c->WithValue(c->Dim(serialized_sparse, 1), 3, &unused));
184 
185       c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
186                                  InferenceContext::kUnknownDim));
187       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
188       c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
189       return Status::OK();
190     });
191 
192 REGISTER_OP("SparseToDense")
193     .Input("sparse_indices: Tindices")
194     .Input("output_shape: Tindices")
195     .Input("sparse_values: T")
196     .Input("default_value: T")
197     .Attr("validate_indices: bool = true")
198     .Attr("T: type")
199     .Output("dense: T")
200     .Attr("Tindices: {int32, int64}")
__anon9213d3ed0902(InferenceContext* c) 201     .SetShapeFn([](InferenceContext* c) {
202       ShapeHandle out;
203       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
204       c->set_output(0, out);
205       return Status::OK();
206     });
207 
208 REGISTER_OP("SparseConcat")
209     .Input("indices: N * int64")
210     .Input("values: N * T")
211     .Input("shapes: N * int64")
212     .Output("output_indices: int64")
213     .Output("output_values: T")
214     .Output("output_shape: int64")
215     .Attr("concat_dim: int")
216     .Attr("N: int >= 2")
217     .Attr("T: type")
__anon9213d3ed0a02(InferenceContext* c) 218     .SetShapeFn([](InferenceContext* c) {
219       // These accumulates the sum.
220       DimensionHandle output_row_count = c->MakeDim(0ll);
221 
222       // These are only merged.
223       DimensionHandle output_ind_cols = c->UnknownDim();
224       ShapeHandle output_shape = c->UnknownShape();
225 
226       const int n = c->num_inputs() / 3;
227       for (int i = 0; i < n; i++) {
228         ShapeHandle ind;
229         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &ind));
230         ShapeHandle val;
231         TF_RETURN_IF_ERROR(c->WithRank(c->input(i + n), 1, &val));
232         ShapeHandle shape;
233         TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 2 * n), 1, &shape));
234 
235         // Add to output_ind_rows.
236         DimensionHandle num_dim;
237         TF_RETURN_IF_ERROR(c->Merge(c->Dim(ind, 0), c->Dim(val, 0), &num_dim));
238         TF_RETURN_IF_ERROR(
239             c->Add(output_row_count, num_dim, &output_row_count));
240 
241         // Merge into output_ind_cols and output_shape.
242         TF_RETURN_IF_ERROR(
243             c->Merge(output_ind_cols, c->Dim(ind, 1), &output_ind_cols));
244         TF_RETURN_IF_ERROR(c->Merge(output_shape, shape, &output_shape));
245       }
246 
247       c->set_output(0, c->Matrix(output_row_count, output_ind_cols));
248       c->set_output(1, c->Vector(output_row_count));
249       c->set_output(2, output_shape);
250       return Status::OK();
251     });
252 
253 REGISTER_OP("SparseCross")
254     .Input("indices: N * int64")
255     .Input("values: sparse_types")
256     .Input("shapes: N * int64")
257     .Input("dense_inputs: dense_types")
258     .Output("output_indices: int64")
259     .Output("output_values: out_type")
260     .Output("output_shape: int64")
261     .Attr("N: int >= 0")
262     .Attr("hashed_output: bool")
263     .Attr("num_buckets: int >= 0")
264     .Attr("hash_key: int")
265     .Attr("sparse_types: list({int64, string}) >= 0")
266     .Attr("dense_types: list({int64, string}) >= 0")
267     .Attr("out_type: {int64, string}")
268     .Attr("internal_type: {int64, string}")
__anon9213d3ed0b02(shape_inference::InferenceContext* c) 269     .SetShapeFn([](shape_inference::InferenceContext* c) {
270       c->set_output(0, c->Matrix(c->UnknownDim(), 2));
271       c->set_output(1, c->Vector(c->UnknownDim()));
272       c->set_output(2, c->Vector(2));
273       return Status::OK();
274     });
275 
276 REGISTER_OP("SparseCrossV2")
277     .Input("indices: N * int64")
278     .Input("values: sparse_types")
279     .Input("shapes: N * int64")
280     .Input("dense_inputs: dense_types")
281     .Input("sep: string")
282     .Output("output_indices: int64")
283     .Output("output_values: string")
284     .Output("output_shape: int64")
285     .Attr("N: int >= 0")
286     .Attr("sparse_types: list({int64, string}) >= 0")
287     .Attr("dense_types: list({int64, string}) >= 0")
__anon9213d3ed0c02(shape_inference::InferenceContext* c) 288     .SetShapeFn([](shape_inference::InferenceContext* c) {
289       c->set_output(0, c->Matrix(c->UnknownDim(), 2));
290       c->set_output(1, c->Vector(c->UnknownDim()));
291       c->set_output(2, c->Vector(2));
292       return Status::OK();
293     });
294 
295 REGISTER_OP("SparseCrossHashed")
296     .Input("indices: N * int64")
297     .Input("values: sparse_types")
298     .Input("shapes: N * int64")
299     .Input("dense_inputs: dense_types")
300     .Input("num_buckets: int64")
301     .Input("strong_hash: bool")
302     .Input("salt: int64")
303     .Output("output_indices: int64")
304     .Output("output_values: int64")
305     .Output("output_shape: int64")
306     .Attr("N: int >= 0")
307     .Attr("sparse_types: list({int64, string}) >= 0")
308     .Attr("dense_types: list({int64, string}) >= 0")
__anon9213d3ed0d02(shape_inference::InferenceContext* c) 309     .SetShapeFn([](shape_inference::InferenceContext* c) {
310       c->set_output(0, c->Matrix(c->UnknownDim(), 2));
311       c->set_output(1, c->Vector(c->UnknownDim()));
312       c->set_output(2, c->Vector(2));
313       return Status::OK();
314     });
315 
316 REGISTER_OP("SparseSplit")
317     .Input("split_dim: int64")
318     .Input("indices: int64")
319     .Input("values: T")
320     .Input("shape: int64")
321     .Output("output_indices: num_split * int64")
322     .Output("output_values:  num_split * T")
323     .Output("output_shape:   num_split * int64")
324     .Attr("num_split: int >= 1")
325     .Attr("T: type")
__anon9213d3ed0e02(InferenceContext* c) 326     .SetShapeFn([](InferenceContext* c) {
327       ShapeHandle input_shape = c->input(3);
328       ShapeHandle output_indices =
329           c->Matrix(InferenceContext::kUnknownDim, c->NumElements(input_shape));
330       ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
331       ShapeHandle output_shape = input_shape;
332 
333       // Copy the outputs into the output ranges.
334       int num_splits = c->num_outputs() / 3;
335       int out_idx = 0;
336       for (int i = 0; i < num_splits; ++i)
337         c->set_output(out_idx++, output_indices);
338       for (int i = 0; i < num_splits; ++i)
339         c->set_output(out_idx++, output_values);
340       for (int i = 0; i < num_splits; ++i)
341         c->set_output(out_idx++, output_shape);
342       return Status::OK();
343     });
344 
345 REGISTER_OP("SparseSliceGrad")
346     .Input("backprop_val_grad: T")
347     .Input("input_indices: int64")
348     .Input("input_start: int64")
349     .Input("output_indices: int64")
350     .Output("val_grad: T")
351     .Attr("T: numbertype")
__anon9213d3ed0f02(InferenceContext* c) 352     .SetShapeFn([](InferenceContext* c) {
353       ShapeHandle indices;
354       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &indices));
355       c->set_output(0, c->Vector(c->Dim(indices, 0)));
356       return Status::OK();
357     });
358 
359 REGISTER_OP("SparseSlice")
360     .Input("indices: int64")
361     .Input("values: T")
362     .Input("shape: int64")
363     .Input("start: int64")
364     .Input("size: int64")
365     .Output("output_indices: int64")
366     .Output("output_values: T")
367     .Output("output_shape: int64")
368     .Attr("T: type")
__anon9213d3ed1002(InferenceContext* c) 369     .SetShapeFn([](InferenceContext* c) {
370       ShapeHandle input_shape = c->input(2);
371       ShapeHandle output_indices =
372           c->Matrix(InferenceContext::kUnknownDim, c->NumElements(input_shape));
373       ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
374       ShapeHandle output_shape = input_shape;
375 
376       c->set_output(0, output_indices);
377       c->set_output(1, output_values);
378       c->set_output(2, output_shape);
379       return Status::OK();
380     });
381 
382 REGISTER_OP("SparseReorder")
383     .Input("input_indices: int64")
384     .Input("input_values: T")
385     .Input("input_shape: int64")
386     .Output("output_indices: int64")
387     .Output("output_values: T")
388     .Attr("T: type")
__anon9213d3ed1102(InferenceContext* c) 389     .SetShapeFn([](InferenceContext* c) {
390       ShapeHandle indices;
391       ShapeHandle values;
392       ShapeHandle unused;
393 
394       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &indices));
395       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &values));
396       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
397 
398       c->set_output(0, indices);
399       c->set_output(1, values);
400       return Status::OK();
401     });
402 
403 REGISTER_OP("SparseReshape")
404     .Input("input_indices: int64")
405     .Input("input_shape: int64")
406     .Input("new_shape: int64")
407     .Output("output_indices: int64")
408     .Output("output_shape: int64")
__anon9213d3ed1202(InferenceContext* c) 409     .SetShapeFn([](InferenceContext* c) {
410       ShapeHandle indices;
411       ShapeHandle unused;
412       ShapeHandle new_shape;
413 
414       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &indices));
415       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
416       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &new_shape));
417 
418       c->set_output(0, c->Matrix(c->Dim(indices, 0), c->Dim(new_shape, 0)));
419       c->set_output(1, new_shape);
420       return Status::OK();
421     });
422 
423 REGISTER_OP("SparseTensorDenseAdd")
424     .Input("a_indices: Tindices")
425     .Input("a_values: T")
426     .Input("a_shape: Tindices")
427     .Input("b: T")
428     .Output("output: T")
429     .Attr("T: numbertype")
430     .Attr("Tindices: {int32, int64}")
__anon9213d3ed1302(InferenceContext* c) 431     .SetShapeFn([](InferenceContext* c) {
432       c->set_output(0, c->input(3));
433       return Status::OK();
434     });
435 
436 REGISTER_OP("SparseReduceMax")
437     .Input("input_indices: int64")
438     .Input("input_values: T")
439     .Input("input_shape: int64")
440     .Input("reduction_axes: int32")
441     .Attr("keep_dims: bool = False")
442     .Output("output: T")
443     .Attr("T: realnumbertype")
444     .SetShapeFn(shape_inference::SparseReduceShapeFn);
445 
446 REGISTER_OP("SparseReduceMaxSparse")
447     .Input("input_indices: int64")
448     .Input("input_values: T")
449     .Input("input_shape: int64")
450     .Input("reduction_axes: int32")
451     .Attr("keep_dims: bool = False")
452     .Output("output_indices: int64")
453     .Output("output_values: T")
454     .Output("output_shape: int64")
455     .Attr("T: realnumbertype")
456     .SetShapeFn(shape_inference::UnknownShape);
457 
458 REGISTER_OP("SparseReduceSum")
459     .Input("input_indices: int64")
460     .Input("input_values: T")
461     .Input("input_shape: int64")
462     .Input("reduction_axes: int32")
463     .Attr("keep_dims: bool = False")
464     .Output("output: T")
465     .Attr("T: numbertype")
466     .SetShapeFn(shape_inference::SparseReduceShapeFn);
467 
468 REGISTER_OP("SparseReduceSumSparse")
469     .Input("input_indices: int64")
470     .Input("input_values: T")
471     .Input("input_shape: int64")
472     .Input("reduction_axes: int32")
473     .Attr("keep_dims: bool = False")
474     .Output("output_indices: int64")
475     .Output("output_values: T")
476     .Output("output_shape: int64")
477     .Attr("T: numbertype")
478     .SetShapeFn(shape_inference::UnknownShape);
479 
480 #define SPARSE_DENSE_CWISE_SIGNATURE()                           \
481   Input("sp_indices: int64")                                     \
482       .Input("sp_values: T")                                     \
483       .Input("sp_shape: int64")                                  \
484       .Input("dense: T")                                         \
485       .Output("output: T")                                       \
486       .Attr("T: numbertype")                                     \
487       .SetShapeFn([](InferenceContext* c) {                      \
488         ShapeHandle input;                                       \
489         TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input)); \
490         c->set_output(0, c->Vector(c->Dim(input, 0)));           \
491         return Status::OK();                                     \
492       })
493 
494 REGISTER_OP("SparseDenseCwiseMul").SPARSE_DENSE_CWISE_SIGNATURE();
495 
496 REGISTER_OP("SparseDenseCwiseDiv").SPARSE_DENSE_CWISE_SIGNATURE();
497 
498 REGISTER_OP("SparseDenseCwiseAdd").SPARSE_DENSE_CWISE_SIGNATURE();
499 
500 #undef SPARSE_DENSE_CWISE_SIGNATURE
501 
502 REGISTER_OP("SparseSoftmax")
503     .Input("sp_indices: int64")
504     .Input("sp_values: T")
505     .Input("sp_shape: int64")
506     .Output("output: T")
507     .Attr("T: {float, double}")
__anon9213d3ed1402(InferenceContext* c) 508     .SetShapeFn([](InferenceContext* c) {
509       ShapeHandle unused;
510       ShapeHandle values;
511       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));  // sp_indices
512       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &values));  // sp_values
513       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
514       c->set_output(0, values);
515       return Status::OK();
516     });
517 
518 REGISTER_OP("SparseSparseMaximum")
519     .Input("a_indices: int64")
520     .Input("a_values: T")
521     .Input("a_shape: int64")
522     .Input("b_indices: int64")
523     .Input("b_values: T")
524     .Input("b_shape: int64")
525     .Output("output_indices: int64")
526     .Output("output_values: T")
527     .Attr("T: realnumbertype")
528     .SetShapeFn(SparseSparseMinOrMaxShapeFn);
529 
530 REGISTER_OP("SparseSparseMinimum")
531     .Input("a_indices: int64")
532     .Input("a_values: T")
533     .Input("a_shape: int64")
534     .Input("b_indices: int64")
535     .Input("b_values: T")
536     .Input("b_shape: int64")
537     .Output("output_indices: int64")
538     .Output("output_values: T")
539     .Attr("T: numbertype")
540     .SetShapeFn(SparseSparseMinOrMaxShapeFn);
541 
542 REGISTER_OP("AddSparseToTensorsMap")
543     .Input("sparse_indices: int64")
544     .Input("sparse_values: T")
545     .Input("sparse_shape: int64")
546     .Output("sparse_handle: int64")
547     .Attr("T: type")
548     .Attr("container: string = ''")
549     .Attr("shared_name: string = ''")
550     .SetIsStateful()
__anon9213d3ed1502(InferenceContext* c) 551     .SetShapeFn([](InferenceContext* c) {
552       ShapeHandle unused;
553       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
554       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
555       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
556       c->set_output(0, c->Scalar());
557       return Status::OK();
558     });
559 
560 REGISTER_OP("AddManySparseToTensorsMap")
561     .Input("sparse_indices: int64")
562     .Input("sparse_values: T")
563     .Input("sparse_shape: int64")
564     .Output("sparse_handles: int64")
565     .Attr("T: type")
566     .Attr("container: string = ''")
567     .Attr("shared_name: string = ''")
568     .SetIsStateful()
__anon9213d3ed1602(InferenceContext* c) 569     .SetShapeFn([](InferenceContext* c) {
570       ShapeHandle unused;
571       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
572       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
573       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
574       c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
575       return Status::OK();
576     });
577 
578 REGISTER_OP("TakeManySparseFromTensorsMap")
579     .Input("sparse_handles: int64")
580     .Output("sparse_indices: int64")
581     .Output("sparse_values: dtype")
582     .Output("sparse_shape: int64")
583     .Attr("dtype: type")
584     .Attr("container: string = ''")
585     .Attr("shared_name: string = ''")
586     .SetIsStateful()
__anon9213d3ed1702(InferenceContext* c) 587     .SetShapeFn([](InferenceContext* c) {
588       // serialized sparse is [?,1] matrix.
589       ShapeHandle sparse_handles;
590       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &sparse_handles));
591 
592       c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
593                                  InferenceContext::kUnknownDim));
594       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
595       c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
596       return Status::OK();
597     });
598 
599 REGISTER_OP("SparseFillEmptyRows")
600     .Input("indices: int64")
601     .Input("values: T")
602     .Input("dense_shape: int64")
603     .Input("default_value: T")
604     .Output("output_indices: int64")
605     .Output("output_values: T")
606     .Output("empty_row_indicator: bool")
607     .Output("reverse_index_map: int64")
608     .Attr("T: type")
__anon9213d3ed1802(InferenceContext* c) 609     .SetShapeFn([](InferenceContext* c) {
610       ShapeHandle input_indices = c->input(0);
611       TF_RETURN_IF_ERROR(c->WithRank(input_indices, 2, &input_indices));
612       ShapeHandle input_values = c->input(1);
613       TF_RETURN_IF_ERROR(c->WithRank(input_values, 1, &input_values));
614       ShapeHandle input_shape = c->input(2);
615       TF_RETURN_IF_ERROR(c->WithRank(input_shape, 1, &input_shape));
616       ShapeHandle default_value = c->input(3);
617       TF_RETURN_IF_ERROR(c->WithRank(default_value, 0, &default_value));
618       DimensionHandle N = c->Dim(input_indices, 0);
619       TF_RETURN_IF_ERROR(c->Merge(N, c->Dim(input_values, 0), &N));
620       DimensionHandle unused_dim;
621       TF_RETURN_IF_ERROR(c->Merge(c->Dim(input_indices, 1),
622                                   c->Dim(input_shape, 0), &unused_dim));
623       if (c->Value(c->NumElements(input_shape)) == 0)
624         return errors::InvalidArgument("dense_shape must not be empty");
625       ShapeHandle output_indices =
626           c->Matrix(InferenceContext::kUnknownDim, c->NumElements(input_shape));
627       ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
628       ShapeHandle constant_input_shape;
629       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &constant_input_shape));
630       ShapeHandle empty_row_indicator =
631           c->Vector(c->Dim(constant_input_shape, 0));
632       ShapeHandle reverse_index_map = c->Vector(N);
633       c->set_output(0, output_indices);
634       c->set_output(1, output_values);
635       c->set_output(2, empty_row_indicator);
636       c->set_output(3, reverse_index_map);
637       return Status::OK();
638     });
639 
640 REGISTER_OP("SparseFillEmptyRowsGrad")
641     .Input("reverse_index_map: int64")
642     .Input("grad_values: T")
643     .Output("d_values: T")
644     .Output("d_default_value: T")
645     .Attr("T: type")
__anon9213d3ed1902(InferenceContext* c) 646     .SetShapeFn([](InferenceContext* c) {
647       ShapeHandle reverse_index_map = c->input(0);
648       TF_RETURN_IF_ERROR(c->WithRank(reverse_index_map, 1, &reverse_index_map));
649       ShapeHandle grad_values = c->input(1);
650       TF_RETURN_IF_ERROR(c->WithRank(grad_values, 1, &grad_values));
651       c->set_output(0, reverse_index_map);
652       c->set_output(1, c->Scalar());
653       return Status::OK();
654     });
655 
656 }  // namespace tensorflow
657