1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 #include "tensorflow/core/platform/errors.h"
20
21 namespace tensorflow {
22
23 using shape_inference::DimensionHandle;
24 using shape_inference::InferenceContext;
25 using shape_inference::ShapeHandle;
26
27 namespace {
28
SparseSparseMinOrMaxShapeFn(InferenceContext * c)29 Status SparseSparseMinOrMaxShapeFn(InferenceContext* c) {
30 ShapeHandle unused;
31 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); // a_indices
32 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); // a_values
33 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); // a_shape
34 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &unused)); // b_indices
35 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 1, &unused)); // b_values
36 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &unused)); // b_shape
37 c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
38 InferenceContext::kUnknownDim));
39 c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
40 return Status::OK();
41 }
42
43 } // namespace
44
45 REGISTER_OP("SparseAddGrad")
46 .Input("backprop_val_grad: T")
47 .Input("a_indices: int64")
48 .Input("b_indices: int64")
49 .Input("sum_indices: int64")
50 .Output("a_val_grad: T")
51 .Output("b_val_grad: T")
52 .Attr("T: numbertype")
__anon9213d3ed0202(InferenceContext* c) 53 .SetShapeFn([](InferenceContext* c) {
54 ShapeHandle a_indices;
55 ShapeHandle b_indices;
56 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &a_indices));
57 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &b_indices));
58 c->set_output(0, c->Vector(c->Dim(a_indices, 0)));
59 c->set_output(1, c->Vector(c->Dim(b_indices, 0)));
60 return Status::OK();
61 });
62
63 REGISTER_OP("SparseAdd")
64 .Input("a_indices: int64")
65 .Input("a_values: T")
66 .Input("a_shape: int64")
67 .Input("b_indices: int64")
68 .Input("b_values: T")
69 .Input("b_shape: int64")
70 .Input("thresh: Treal")
71 .Output("sum_indices: int64")
72 .Output("sum_values: T")
73 .Output("sum_shape: int64")
74 .Attr("T: numbertype")
75 .Attr("Treal: realnumbertype")
__anon9213d3ed0302(InferenceContext* c) 76 .SetShapeFn([](InferenceContext* c) {
77 ShapeHandle a_shape;
78 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &a_shape));
79 c->set_output(
80 0, c->Matrix(InferenceContext::kUnknownDim, c->Dim(a_shape, 0)));
81 c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
82 c->set_output(2, a_shape);
83 return Status::OK();
84 });
85
86 REGISTER_OP("SparseTensorDenseMatMul")
87 .Input("a_indices: Tindices")
88 .Input("a_values: T")
89 .Input("a_shape: int64")
90 .Input("b: T")
91 .Output("product: T")
92 .Attr("T: type")
93 .Attr("Tindices: {int32,int64} = DT_INT64")
94 .Attr("adjoint_a: bool = false")
95 .Attr("adjoint_b: bool = false")
__anon9213d3ed0402(InferenceContext* c) 96 .SetShapeFn([](InferenceContext* c) {
97 DimensionHandle unused_dim;
98 ShapeHandle unused;
99 ShapeHandle b;
100 ShapeHandle a_shape;
101 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); // a_indices
102 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); // a_values
103 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &a_shape));
104 TF_RETURN_IF_ERROR(c->WithRank(a_shape, 2, &a_shape));
105 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &b));
106
107 bool adjoint_a;
108 bool adjoint_b;
109 TF_RETURN_IF_ERROR(c->GetAttr("adjoint_a", &adjoint_a));
110 TF_RETURN_IF_ERROR(c->GetAttr("adjoint_b", &adjoint_b));
111
112 DimensionHandle output_right = c->Dim(b, adjoint_b ? 0 : 1);
113 DimensionHandle output_left = c->Dim(a_shape, adjoint_a ? 1 : 0);
114 DimensionHandle inner_left = c->Dim(a_shape, adjoint_a ? 0 : 1);
115 DimensionHandle inner_right = c->Dim(b, adjoint_b ? 1 : 0);
116 TF_RETURN_IF_ERROR(c->Merge(inner_left, inner_right, &unused_dim));
117 c->set_output(0, c->Matrix(output_left, output_right));
118 return Status::OK();
119 });
120
121 REGISTER_OP("SerializeSparse")
122 .Input("sparse_indices: int64")
123 .Input("sparse_values: T")
124 .Input("sparse_shape: int64")
125 .Attr("T: type")
126 .Output("serialized_sparse: out_type")
127 .Attr("out_type: {string, variant} = DT_STRING")
__anon9213d3ed0502(InferenceContext* c) 128 .SetShapeFn([](InferenceContext* c) {
129 ShapeHandle unused;
130 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
131 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
132 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
133 c->set_output(0, c->Vector(3));
134 return Status::OK();
135 });
136
137 REGISTER_OP("SerializeManySparse")
138 .Input("sparse_indices: int64")
139 .Input("sparse_values: T")
140 .Input("sparse_shape: int64")
141 .Attr("T: type")
142 .Output("serialized_sparse: out_type")
143 .Attr("out_type: {string, variant} = DT_STRING")
__anon9213d3ed0602(InferenceContext* c) 144 .SetShapeFn([](InferenceContext* c) {
145 ShapeHandle unused;
146 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
147 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
148 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
149 c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 3));
150 return Status::OK();
151 });
152
153 REGISTER_OP("DeserializeSparse")
154 .Input("serialized_sparse: Tserialized")
155 .Output("sparse_indices: int64")
156 .Output("sparse_values: dtype")
157 .Output("sparse_shape: int64")
158 .Attr("dtype: type")
159 .Attr("Tserialized: {string, variant} = DT_STRING")
__anon9213d3ed0702(InferenceContext* c) 160 .SetShapeFn([](InferenceContext* c) {
161 // serialized sparse is [?, ..., ?, 3] vector.
162 DimensionHandle unused;
163 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), -1), 3, &unused));
164 c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
165 InferenceContext::kUnknownDim));
166 c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
167 c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
168 return Status::OK();
169 });
170
171 REGISTER_OP("DeserializeManySparse")
172 .Input("serialized_sparse: string")
173 .Output("sparse_indices: int64")
174 .Output("sparse_values: dtype")
175 .Output("sparse_shape: int64")
176 .Attr("dtype: type")
__anon9213d3ed0802(InferenceContext* c) 177 .SetShapeFn([](InferenceContext* c) {
178 // serialized sparse is [?,3] matrix.
179 ShapeHandle serialized_sparse;
180 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &serialized_sparse));
181 DimensionHandle unused;
182 TF_RETURN_IF_ERROR(
183 c->WithValue(c->Dim(serialized_sparse, 1), 3, &unused));
184
185 c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
186 InferenceContext::kUnknownDim));
187 c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
188 c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
189 return Status::OK();
190 });
191
192 REGISTER_OP("SparseToDense")
193 .Input("sparse_indices: Tindices")
194 .Input("output_shape: Tindices")
195 .Input("sparse_values: T")
196 .Input("default_value: T")
197 .Attr("validate_indices: bool = true")
198 .Attr("T: type")
199 .Output("dense: T")
200 .Attr("Tindices: {int32, int64}")
__anon9213d3ed0902(InferenceContext* c) 201 .SetShapeFn([](InferenceContext* c) {
202 ShapeHandle out;
203 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
204 c->set_output(0, out);
205 return Status::OK();
206 });
207
208 REGISTER_OP("SparseConcat")
209 .Input("indices: N * int64")
210 .Input("values: N * T")
211 .Input("shapes: N * int64")
212 .Output("output_indices: int64")
213 .Output("output_values: T")
214 .Output("output_shape: int64")
215 .Attr("concat_dim: int")
216 .Attr("N: int >= 2")
217 .Attr("T: type")
__anon9213d3ed0a02(InferenceContext* c) 218 .SetShapeFn([](InferenceContext* c) {
219 // These accumulates the sum.
220 DimensionHandle output_row_count = c->MakeDim(0ll);
221
222 // These are only merged.
223 DimensionHandle output_ind_cols = c->UnknownDim();
224 ShapeHandle output_shape = c->UnknownShape();
225
226 const int n = c->num_inputs() / 3;
227 for (int i = 0; i < n; i++) {
228 ShapeHandle ind;
229 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &ind));
230 ShapeHandle val;
231 TF_RETURN_IF_ERROR(c->WithRank(c->input(i + n), 1, &val));
232 ShapeHandle shape;
233 TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 2 * n), 1, &shape));
234
235 // Add to output_ind_rows.
236 DimensionHandle num_dim;
237 TF_RETURN_IF_ERROR(c->Merge(c->Dim(ind, 0), c->Dim(val, 0), &num_dim));
238 TF_RETURN_IF_ERROR(
239 c->Add(output_row_count, num_dim, &output_row_count));
240
241 // Merge into output_ind_cols and output_shape.
242 TF_RETURN_IF_ERROR(
243 c->Merge(output_ind_cols, c->Dim(ind, 1), &output_ind_cols));
244 TF_RETURN_IF_ERROR(c->Merge(output_shape, shape, &output_shape));
245 }
246
247 c->set_output(0, c->Matrix(output_row_count, output_ind_cols));
248 c->set_output(1, c->Vector(output_row_count));
249 c->set_output(2, output_shape);
250 return Status::OK();
251 });
252
253 REGISTER_OP("SparseCross")
254 .Input("indices: N * int64")
255 .Input("values: sparse_types")
256 .Input("shapes: N * int64")
257 .Input("dense_inputs: dense_types")
258 .Output("output_indices: int64")
259 .Output("output_values: out_type")
260 .Output("output_shape: int64")
261 .Attr("N: int >= 0")
262 .Attr("hashed_output: bool")
263 .Attr("num_buckets: int >= 0")
264 .Attr("hash_key: int")
265 .Attr("sparse_types: list({int64, string}) >= 0")
266 .Attr("dense_types: list({int64, string}) >= 0")
267 .Attr("out_type: {int64, string}")
268 .Attr("internal_type: {int64, string}")
__anon9213d3ed0b02(shape_inference::InferenceContext* c) 269 .SetShapeFn([](shape_inference::InferenceContext* c) {
270 c->set_output(0, c->Matrix(c->UnknownDim(), 2));
271 c->set_output(1, c->Vector(c->UnknownDim()));
272 c->set_output(2, c->Vector(2));
273 return Status::OK();
274 });
275
276 REGISTER_OP("SparseCrossV2")
277 .Input("indices: N * int64")
278 .Input("values: sparse_types")
279 .Input("shapes: N * int64")
280 .Input("dense_inputs: dense_types")
281 .Input("sep: string")
282 .Output("output_indices: int64")
283 .Output("output_values: string")
284 .Output("output_shape: int64")
285 .Attr("N: int >= 0")
286 .Attr("sparse_types: list({int64, string}) >= 0")
287 .Attr("dense_types: list({int64, string}) >= 0")
__anon9213d3ed0c02(shape_inference::InferenceContext* c) 288 .SetShapeFn([](shape_inference::InferenceContext* c) {
289 c->set_output(0, c->Matrix(c->UnknownDim(), 2));
290 c->set_output(1, c->Vector(c->UnknownDim()));
291 c->set_output(2, c->Vector(2));
292 return Status::OK();
293 });
294
295 REGISTER_OP("SparseCrossHashed")
296 .Input("indices: N * int64")
297 .Input("values: sparse_types")
298 .Input("shapes: N * int64")
299 .Input("dense_inputs: dense_types")
300 .Input("num_buckets: int64")
301 .Input("strong_hash: bool")
302 .Input("salt: int64")
303 .Output("output_indices: int64")
304 .Output("output_values: int64")
305 .Output("output_shape: int64")
306 .Attr("N: int >= 0")
307 .Attr("sparse_types: list({int64, string}) >= 0")
308 .Attr("dense_types: list({int64, string}) >= 0")
__anon9213d3ed0d02(shape_inference::InferenceContext* c) 309 .SetShapeFn([](shape_inference::InferenceContext* c) {
310 c->set_output(0, c->Matrix(c->UnknownDim(), 2));
311 c->set_output(1, c->Vector(c->UnknownDim()));
312 c->set_output(2, c->Vector(2));
313 return Status::OK();
314 });
315
316 REGISTER_OP("SparseSplit")
317 .Input("split_dim: int64")
318 .Input("indices: int64")
319 .Input("values: T")
320 .Input("shape: int64")
321 .Output("output_indices: num_split * int64")
322 .Output("output_values: num_split * T")
323 .Output("output_shape: num_split * int64")
324 .Attr("num_split: int >= 1")
325 .Attr("T: type")
__anon9213d3ed0e02(InferenceContext* c) 326 .SetShapeFn([](InferenceContext* c) {
327 ShapeHandle input_shape = c->input(3);
328 ShapeHandle output_indices =
329 c->Matrix(InferenceContext::kUnknownDim, c->NumElements(input_shape));
330 ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
331 ShapeHandle output_shape = input_shape;
332
333 // Copy the outputs into the output ranges.
334 int num_splits = c->num_outputs() / 3;
335 int out_idx = 0;
336 for (int i = 0; i < num_splits; ++i)
337 c->set_output(out_idx++, output_indices);
338 for (int i = 0; i < num_splits; ++i)
339 c->set_output(out_idx++, output_values);
340 for (int i = 0; i < num_splits; ++i)
341 c->set_output(out_idx++, output_shape);
342 return Status::OK();
343 });
344
345 REGISTER_OP("SparseSliceGrad")
346 .Input("backprop_val_grad: T")
347 .Input("input_indices: int64")
348 .Input("input_start: int64")
349 .Input("output_indices: int64")
350 .Output("val_grad: T")
351 .Attr("T: numbertype")
__anon9213d3ed0f02(InferenceContext* c) 352 .SetShapeFn([](InferenceContext* c) {
353 ShapeHandle indices;
354 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &indices));
355 c->set_output(0, c->Vector(c->Dim(indices, 0)));
356 return Status::OK();
357 });
358
359 REGISTER_OP("SparseSlice")
360 .Input("indices: int64")
361 .Input("values: T")
362 .Input("shape: int64")
363 .Input("start: int64")
364 .Input("size: int64")
365 .Output("output_indices: int64")
366 .Output("output_values: T")
367 .Output("output_shape: int64")
368 .Attr("T: type")
__anon9213d3ed1002(InferenceContext* c) 369 .SetShapeFn([](InferenceContext* c) {
370 ShapeHandle input_shape = c->input(2);
371 ShapeHandle output_indices =
372 c->Matrix(InferenceContext::kUnknownDim, c->NumElements(input_shape));
373 ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
374 ShapeHandle output_shape = input_shape;
375
376 c->set_output(0, output_indices);
377 c->set_output(1, output_values);
378 c->set_output(2, output_shape);
379 return Status::OK();
380 });
381
382 REGISTER_OP("SparseReorder")
383 .Input("input_indices: int64")
384 .Input("input_values: T")
385 .Input("input_shape: int64")
386 .Output("output_indices: int64")
387 .Output("output_values: T")
388 .Attr("T: type")
__anon9213d3ed1102(InferenceContext* c) 389 .SetShapeFn([](InferenceContext* c) {
390 ShapeHandle indices;
391 ShapeHandle values;
392 ShapeHandle unused;
393
394 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &indices));
395 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &values));
396 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
397
398 c->set_output(0, indices);
399 c->set_output(1, values);
400 return Status::OK();
401 });
402
403 REGISTER_OP("SparseReshape")
404 .Input("input_indices: int64")
405 .Input("input_shape: int64")
406 .Input("new_shape: int64")
407 .Output("output_indices: int64")
408 .Output("output_shape: int64")
__anon9213d3ed1202(InferenceContext* c) 409 .SetShapeFn([](InferenceContext* c) {
410 ShapeHandle indices;
411 ShapeHandle unused;
412 ShapeHandle new_shape;
413
414 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &indices));
415 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
416 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &new_shape));
417
418 c->set_output(0, c->Matrix(c->Dim(indices, 0), c->Dim(new_shape, 0)));
419 c->set_output(1, new_shape);
420 return Status::OK();
421 });
422
423 REGISTER_OP("SparseTensorDenseAdd")
424 .Input("a_indices: Tindices")
425 .Input("a_values: T")
426 .Input("a_shape: Tindices")
427 .Input("b: T")
428 .Output("output: T")
429 .Attr("T: numbertype")
430 .Attr("Tindices: {int32, int64}")
__anon9213d3ed1302(InferenceContext* c) 431 .SetShapeFn([](InferenceContext* c) {
432 c->set_output(0, c->input(3));
433 return Status::OK();
434 });
435
// SparseReduceMax: takes a sparse tensor (input_indices, input_values,
// input_shape) plus int32 reduction_axes and produces a dense `output`.
// Shape inference is delegated to shape_inference::SparseReduceShapeFn from
// common_shape_fns.h. keep_dims defaults to False.
// T is limited to realnumbertype here, whereas SparseReduceSum below accepts
// numbertype (presumably because max needs an ordering).
REGISTER_OP("SparseReduceMax")
    .Input("input_indices: int64")
    .Input("input_values: T")
    .Input("input_shape: int64")
    .Input("reduction_axes: int32")
    .Attr("keep_dims: bool = False")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .SetShapeFn(shape_inference::SparseReduceShapeFn);

// SparseReduceMaxSparse: same inputs and attrs as SparseReduceMax, but the
// result is returned as a sparse (indices, values, shape) triple. Output
// shapes are left fully unknown (shape_inference::UnknownShape).
REGISTER_OP("SparseReduceMaxSparse")
    .Input("input_indices: int64")
    .Input("input_values: T")
    .Input("input_shape: int64")
    .Input("reduction_axes: int32")
    .Attr("keep_dims: bool = False")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Output("output_shape: int64")
    .Attr("T: realnumbertype")
    .SetShapeFn(shape_inference::UnknownShape);

// SparseReduceSum: dense-output sum. It has the same signature as
// SparseReduceMax, except that T accepts any numbertype.
REGISTER_OP("SparseReduceSum")
    .Input("input_indices: int64")
    .Input("input_values: T")
    .Input("input_shape: int64")
    .Input("reduction_axes: int32")
    .Attr("keep_dims: bool = False")
    .Output("output: T")
    .Attr("T: numbertype")
    .SetShapeFn(shape_inference::SparseReduceShapeFn);

// SparseReduceSumSparse: sparse-output counterpart of SparseReduceSum.
// Output shapes are left fully unknown.
REGISTER_OP("SparseReduceSumSparse")
    .Input("input_indices: int64")
    .Input("input_values: T")
    .Input("input_shape: int64")
    .Input("reduction_axes: int32")
    .Attr("keep_dims: bool = False")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Output("output_shape: int64")
    .Attr("T: numbertype")
    .SetShapeFn(shape_inference::UnknownShape);
479
// Shared signature for the SparseDenseCwise{Mul,Div,Add} ops: a sparse tensor
// (sp_indices, sp_values, sp_shape) combined with a dense tensor of the same
// element type T.
// The shape function requires sp_indices to be a matrix. The output is a
// vector with one entry per row of sp_indices, i.e. one result per sparse
// value.
// NOTE: no comments inside the macro body. A `//` comment on a
// backslash-continued line would swallow the following line.
#define SPARSE_DENSE_CWISE_SIGNATURE()                           \
  Input("sp_indices: int64")                                     \
      .Input("sp_values: T")                                     \
      .Input("sp_shape: int64")                                  \
      .Input("dense: T")                                         \
      .Output("output: T")                                       \
      .Attr("T: numbertype")                                     \
      .SetShapeFn([](InferenceContext* c) {                      \
        ShapeHandle input;                                       \
        TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input)); \
        c->set_output(0, c->Vector(c->Dim(input, 0)));           \
        return Status::OK();                                     \
      })

REGISTER_OP("SparseDenseCwiseMul").SPARSE_DENSE_CWISE_SIGNATURE();

REGISTER_OP("SparseDenseCwiseDiv").SPARSE_DENSE_CWISE_SIGNATURE();

REGISTER_OP("SparseDenseCwiseAdd").SPARSE_DENSE_CWISE_SIGNATURE();

// The macro is only needed for the three registrations above.
#undef SPARSE_DENSE_CWISE_SIGNATURE
501
502 REGISTER_OP("SparseSoftmax")
503 .Input("sp_indices: int64")
504 .Input("sp_values: T")
505 .Input("sp_shape: int64")
506 .Output("output: T")
507 .Attr("T: {float, double}")
__anon9213d3ed1402(InferenceContext* c) 508 .SetShapeFn([](InferenceContext* c) {
509 ShapeHandle unused;
510 ShapeHandle values;
511 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); // sp_indices
512 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &values)); // sp_values
513 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
514 c->set_output(0, values);
515 return Status::OK();
516 });
517
// Element-wise maximum of two sparse tensors A and B, each given as an
// (indices, values, shape) triple. The result has indices and values only.
// Shapes come from SparseSparseMinOrMaxShapeFn, defined earlier in this file.
REGISTER_OP("SparseSparseMaximum")
    .Input("a_indices: int64")
    .Input("a_values: T")
    .Input("a_shape: int64")
    .Input("b_indices: int64")
    .Input("b_values: T")
    .Input("b_shape: int64")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Attr("T: realnumbertype")
    .SetShapeFn(SparseSparseMinOrMaxShapeFn);

// Element-wise minimum counterpart of SparseSparseMaximum (same signature and
// shape function).
// NOTE(review): T here is `numbertype`, while SparseSparseMaximum uses
// `realnumbertype`. Confirm whether this asymmetry is intended. Changing it
// would alter the registered OpDef, so it is left as is.
REGISTER_OP("SparseSparseMinimum")
    .Input("a_indices: int64")
    .Input("a_values: T")
    .Input("a_shape: int64")
    .Input("b_indices: int64")
    .Input("b_values: T")
    .Input("b_shape: int64")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Attr("T: numbertype")
    .SetShapeFn(SparseSparseMinOrMaxShapeFn);
541
542 REGISTER_OP("AddSparseToTensorsMap")
543 .Input("sparse_indices: int64")
544 .Input("sparse_values: T")
545 .Input("sparse_shape: int64")
546 .Output("sparse_handle: int64")
547 .Attr("T: type")
548 .Attr("container: string = ''")
549 .Attr("shared_name: string = ''")
550 .SetIsStateful()
__anon9213d3ed1502(InferenceContext* c) 551 .SetShapeFn([](InferenceContext* c) {
552 ShapeHandle unused;
553 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
554 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
555 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
556 c->set_output(0, c->Scalar());
557 return Status::OK();
558 });
559
560 REGISTER_OP("AddManySparseToTensorsMap")
561 .Input("sparse_indices: int64")
562 .Input("sparse_values: T")
563 .Input("sparse_shape: int64")
564 .Output("sparse_handles: int64")
565 .Attr("T: type")
566 .Attr("container: string = ''")
567 .Attr("shared_name: string = ''")
568 .SetIsStateful()
__anon9213d3ed1602(InferenceContext* c) 569 .SetShapeFn([](InferenceContext* c) {
570 ShapeHandle unused;
571 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
572 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
573 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
574 c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
575 return Status::OK();
576 });
577
578 REGISTER_OP("TakeManySparseFromTensorsMap")
579 .Input("sparse_handles: int64")
580 .Output("sparse_indices: int64")
581 .Output("sparse_values: dtype")
582 .Output("sparse_shape: int64")
583 .Attr("dtype: type")
584 .Attr("container: string = ''")
585 .Attr("shared_name: string = ''")
586 .SetIsStateful()
__anon9213d3ed1702(InferenceContext* c) 587 .SetShapeFn([](InferenceContext* c) {
588 // serialized sparse is [?,1] matrix.
589 ShapeHandle sparse_handles;
590 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &sparse_handles));
591
592 c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
593 InferenceContext::kUnknownDim));
594 c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
595 c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
596 return Status::OK();
597 });
598
599 REGISTER_OP("SparseFillEmptyRows")
600 .Input("indices: int64")
601 .Input("values: T")
602 .Input("dense_shape: int64")
603 .Input("default_value: T")
604 .Output("output_indices: int64")
605 .Output("output_values: T")
606 .Output("empty_row_indicator: bool")
607 .Output("reverse_index_map: int64")
608 .Attr("T: type")
__anon9213d3ed1802(InferenceContext* c) 609 .SetShapeFn([](InferenceContext* c) {
610 ShapeHandle input_indices = c->input(0);
611 TF_RETURN_IF_ERROR(c->WithRank(input_indices, 2, &input_indices));
612 ShapeHandle input_values = c->input(1);
613 TF_RETURN_IF_ERROR(c->WithRank(input_values, 1, &input_values));
614 ShapeHandle input_shape = c->input(2);
615 TF_RETURN_IF_ERROR(c->WithRank(input_shape, 1, &input_shape));
616 ShapeHandle default_value = c->input(3);
617 TF_RETURN_IF_ERROR(c->WithRank(default_value, 0, &default_value));
618 DimensionHandle N = c->Dim(input_indices, 0);
619 TF_RETURN_IF_ERROR(c->Merge(N, c->Dim(input_values, 0), &N));
620 DimensionHandle unused_dim;
621 TF_RETURN_IF_ERROR(c->Merge(c->Dim(input_indices, 1),
622 c->Dim(input_shape, 0), &unused_dim));
623 if (c->Value(c->NumElements(input_shape)) == 0)
624 return errors::InvalidArgument("dense_shape must not be empty");
625 ShapeHandle output_indices =
626 c->Matrix(InferenceContext::kUnknownDim, c->NumElements(input_shape));
627 ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
628 ShapeHandle constant_input_shape;
629 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &constant_input_shape));
630 ShapeHandle empty_row_indicator =
631 c->Vector(c->Dim(constant_input_shape, 0));
632 ShapeHandle reverse_index_map = c->Vector(N);
633 c->set_output(0, output_indices);
634 c->set_output(1, output_values);
635 c->set_output(2, empty_row_indicator);
636 c->set_output(3, reverse_index_map);
637 return Status::OK();
638 });
639
640 REGISTER_OP("SparseFillEmptyRowsGrad")
641 .Input("reverse_index_map: int64")
642 .Input("grad_values: T")
643 .Output("d_values: T")
644 .Output("d_default_value: T")
645 .Attr("T: type")
__anon9213d3ed1902(InferenceContext* c) 646 .SetShapeFn([](InferenceContext* c) {
647 ShapeHandle reverse_index_map = c->input(0);
648 TF_RETURN_IF_ERROR(c->WithRank(reverse_index_map, 1, &reverse_index_map));
649 ShapeHandle grad_values = c->input(1);
650 TF_RETURN_IF_ERROR(c->WithRank(grad_values, 1, &grad_values));
651 c->set_output(0, reverse_index_map);
652 c->set_output(1, c->Scalar());
653 return Status::OK();
654 });
655
656 } // namespace tensorflow
657