• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/numeric_op.h"
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/framework/shape_inference.h"
20 #include "tensorflow/core/lib/core/errors.h"
21 
22 namespace tensorflow {
23 
24 using shape_inference::DimensionHandle;
25 using shape_inference::InferenceContext;
26 using shape_inference::ShapeAndType;
27 using shape_inference::ShapeHandle;
28 
GetVariantInput(InferenceContext * c,int index,ShapeAndType * shape_and_type)29 Status GetVariantInput(InferenceContext* c, int index,
30                        ShapeAndType* shape_and_type) {
31   ShapeHandle variant;
32   TF_RETURN_IF_ERROR(c->WithRank(c->input(index), 0, &variant));
33   auto* shapes_and_types = c->input_handle_shapes_and_types(index);
34   if (shapes_and_types == nullptr || shapes_and_types->size() != 1) {
35     return errors::InvalidArgument(
36         "Unable to access shape and type info from variant input ", index);
37   }
38   *shape_and_type = shapes_and_types->at(0);
39   return OkStatus();
40 }
41 
42 // Validates that a shape represents a (rank-2) square matrix or a (rank-3)
43 // batch of square matrices.
ValidateSquareMatrixShape(InferenceContext * c,const ShapeHandle & matrix_shape,DimensionHandle * matrix_dimension)44 Status ValidateSquareMatrixShape(InferenceContext* c,
45                                  const ShapeHandle& matrix_shape,
46                                  DimensionHandle* matrix_dimension) {
47   ShapeHandle out;
48   TF_RETURN_IF_ERROR(c->WithRankAtLeast(matrix_shape, 2, &out));
49   TF_RETURN_IF_ERROR(c->WithRankAtMost(matrix_shape, 3, &out));
50   if (!c->RankKnown(matrix_shape)) {
51     return errors::Internal("Sparse matrix has an unknown rank.");
52   }
53 
54   TF_RETURN_IF_ERROR(c->Merge(c->Dim(matrix_shape, -2),
55                               c->Dim(matrix_shape, -1), matrix_dimension));
56   return OkStatus();
57 }
58 
59 REGISTER_OP("SparseTensorToCSRSparseMatrix")
60     .Input("indices: int64")
61     .Input("values: T")
62     .Input("dense_shape: int64")
63     .Attr("T: {float, double, complex64, complex128}")
64     .Output("sparse_matrix: variant")
__anonc9a48f920102(InferenceContext* c) 65     .SetShapeFn([](InferenceContext* c) {
66       TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
67           c, c->input(0), c->input(1), c->input(2)));
68       auto rank = c->Value(c->Dim(c->input(0), 1));
69       ShapeHandle dense_shape;
70       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &dense_shape));
71       TF_RETURN_IF_ERROR(c->WithRank(dense_shape, rank, &dense_shape));
72       if (!c->RankKnown(dense_shape) || c->Rank(dense_shape) < 2 ||
73           c->Rank(dense_shape) > 3) {
74         return errors::InvalidArgument(
75             "Invalid rank: ", c->Rank(dense_shape),
76             ".  Expected a known rank of either 2 or 3.");
77       }
78 
79       DataType dtype;
80       TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype));
81       c->set_output(0, c->Scalar());
82       c->set_output_handle_shapes_and_types(0,
83                                             {ShapeAndType{dense_shape, dtype}});
84       return OkStatus();
85     });
86 
87 REGISTER_OP("CSRSparseMatrixToSparseTensor")
88     .Input("sparse_matrix: variant")
89     .Output("indices: int64")
90     .Output("values: type")
91     .Output("dense_shape: int64")
92     .Attr("type: {float, double, complex64, complex128}")
__anonc9a48f920202(InferenceContext* c) 93     .SetShapeFn([](InferenceContext* c) {
94       ShapeAndType sparse_matrix_shape_and_type;
95       TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
96       ShapeHandle sparse_matrix = sparse_matrix_shape_and_type.shape;
97       TF_RETURN_IF_ERROR(c->WithRankAtMost(sparse_matrix, 3, &sparse_matrix));
98       if (!c->RankKnown(sparse_matrix)) {
99         return errors::InvalidArgument("sparse_matrix has an unknown rank.");
100       }
101       int rank = c->Rank(sparse_matrix);
102       ShapeHandle indices = c->Matrix(c->UnknownDim(), rank);
103       ShapeHandle values = c->Vector(c->UnknownDim());
104       ShapeHandle dense_shape = c->Vector(rank);
105       c->set_output(0, indices);
106       c->set_output(1, values);
107       c->set_output(2, dense_shape);
108       return OkStatus();
109     });
110 
111 REGISTER_OP("DenseToCSRSparseMatrix")
112     .Input("dense_input: T")
113     .Input("indices: int64")
114     .Attr("T: {float, double, complex64, complex128}")
115     .Output("sparse_output: variant")
__anonc9a48f920302(InferenceContext* c) 116     .SetShapeFn([](InferenceContext* c) {
117       ShapeHandle dense_shape = c->input(0);
118       if (!c->RankKnown(dense_shape) || c->Rank(dense_shape) < 2 ||
119           c->Rank(dense_shape) > 3) {
120         return errors::InvalidArgument(
121             "Invalid rank of dense: ", c->Rank(dense_shape),
122             ".  Expected a known rank of either 2 or 3.");
123       }
124       auto rank = c->Rank(dense_shape);
125 
126       ShapeHandle indices = c->input(1);
127       if (!c->RankKnown(indices) || c->Rank(indices) != 2) {
128         return errors::InvalidArgument(
129             "indices must be a matrix; but its rank is not 2: ",
130             c->Rank(indices));
131       }
132       auto indices_col = c->Dim(indices, 1);
133       if (!c->ValueKnown(indices_col) || c->Value(indices_col) != rank) {
134         return errors::InvalidArgument(
135             "indices.shape[1] must match rank of dense; saw: ",
136             c->Value(indices_col), " vs. ", rank);
137       }
138       ShapeHandle fake_values_vec = c->Vector(c->Dim(indices, 0));
139       ShapeHandle fake_shape_shape = c->Vector(rank);
140       TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
141           c, indices /*indices_shape*/, fake_values_vec /*values_shape*/,
142           fake_shape_shape /*shape_shape*/));
143       DataType dtype;
144       TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype));
145       c->set_output_handle_shapes_and_types(0,
146                                             {ShapeAndType{dense_shape, dtype}});
147       c->set_output(0, c->Scalar());
148       return OkStatus();
149     });
150 
151 REGISTER_OP("CSRSparseMatrixToDense")
152     .Input("sparse_input: variant")
153     .Output("dense_output: type")
154     .Attr("type: {float, double, complex64, complex128}")
__anonc9a48f920402(InferenceContext* c) 155     .SetShapeFn([](InferenceContext* c) {
156       ShapeAndType sparse_matrix_shape_and_type;
157       TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
158       ShapeHandle sparse_matrix = sparse_matrix_shape_and_type.shape;
159       TF_RETURN_IF_ERROR(c->WithRankAtMost(sparse_matrix, 3, &sparse_matrix));
160       if (!c->RankKnown(sparse_matrix)) {
161         return errors::InvalidArgument("sparse_matrix has an unknown rank.");
162       }
163       c->set_output(0, sparse_matrix);
164       return OkStatus();
165     });
166 
167 REGISTER_OP("CSRSparseMatrixComponents")
168     .Input("csr_sparse_matrix: variant")
169     .Input("index: int32")
170     .Output("row_ptrs: int32")
171     .Output("col_inds: int32")
172     .Output("values: type")
173     .Attr("type: {float, double, complex64, complex128}")
__anonc9a48f920502(InferenceContext* c) 174     .SetShapeFn([](InferenceContext* c) {
175       ShapeAndType sparse_matrix_shape_and_type;
176       TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
177       ShapeHandle csr_sparse_matrix = sparse_matrix_shape_and_type.shape;
178       TF_RETURN_IF_ERROR(
179           c->WithRankAtLeast(csr_sparse_matrix, 2, &csr_sparse_matrix));
180       TF_RETURN_IF_ERROR(
181           c->WithRankAtMost(csr_sparse_matrix, 3, &csr_sparse_matrix));
182       ShapeHandle index;
183       if (c->Rank(c->input(1)) != 0) {
184         return errors::InvalidArgument("index must be a scalar.");
185       }
186       if (!c->RankKnown(csr_sparse_matrix)) {
187         return errors::InvalidArgument(
188             "csr_sparse_matrix has an unknown rank.");
189       }
190       auto row_ptrs_dh = c->Dim(csr_sparse_matrix, -2);
191       TF_RETURN_IF_ERROR(c->Add(row_ptrs_dh, 1, &row_ptrs_dh));
192       ShapeHandle row_ptrs = c->Vector(row_ptrs_dh);
193       c->set_output(0, row_ptrs);
194       c->set_output(1, c->Vector(c->UnknownDim()));
195       c->set_output(2, c->Vector(c->UnknownDim()));
196       return OkStatus();
197     });
198 
199 REGISTER_OP("SparseMatrixNNZ")
200     .Input("sparse_matrix: variant")
201     .Output("nnz: int32")
__anonc9a48f920602(InferenceContext* c) 202     .SetShapeFn([](InferenceContext* c) {
203       ShapeAndType sparse_matrix_shape_and_type;
204       TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
205       ShapeHandle sparse_matrix = sparse_matrix_shape_and_type.shape;
206       TF_RETURN_IF_ERROR(c->WithRankAtLeast(sparse_matrix, 2, &sparse_matrix));
207       TF_RETURN_IF_ERROR(c->WithRankAtMost(sparse_matrix, 3, &sparse_matrix));
208       if (!c->RankKnown(sparse_matrix)) {
209         return errors::InvalidArgument("sparse_matrix has an unknown rank.");
210       }
211       ShapeHandle out;
212       if (c->Rank(sparse_matrix) == 3) {
213         out = c->Vector(c->Dim(sparse_matrix, 0));
214       } else {
215         out = c->Scalar();
216       }
217       c->set_output(0, out);
218       return OkStatus();
219     });
220 
221 REGISTER_OP("SparseMatrixMatMul")
222     .Input("a: variant")
223     .Input("b: T")
224     .Attr("T: type")
225     .Attr("transpose_a: bool = false")
226     .Attr("transpose_b: bool = false")
227     .Attr("adjoint_a: bool = false")
228     .Attr("adjoint_b: bool = false")
229     .Attr("transpose_output: bool = false")
230     .Attr("conjugate_output: bool = false")
231     .Output("output: T")
__anonc9a48f920702(InferenceContext* c) 232     .SetShapeFn([](InferenceContext* c) {
233       ShapeAndType sparse_matrix_shape_and_type;
234       TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
235       ShapeHandle a_shape = sparse_matrix_shape_and_type.shape;
236       TF_RETURN_IF_ERROR(c->WithRankAtLeast(a_shape, 2, &a_shape));
237       TF_RETURN_IF_ERROR(c->WithRankAtMost(a_shape, 3, &a_shape));
238       if (!c->RankKnown(a_shape)) {
239         return errors::Internal("a has an unknown rank.");
240       }
241       ShapeHandle b_shape;
242       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape));
243       TF_RETURN_IF_ERROR(c->WithRankAtMost(b_shape, 3, &b_shape));
244 
245       bool transpose_a = false;
246       bool transpose_b = false;
247       bool transpose_output = false;
248 
249       // TODO(ebrevdo): Add transpose support.
250       TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a));
251       TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b));
252       TF_RETURN_IF_ERROR(c->GetAttr("transpose_output", &transpose_output));
253 
254       bool adjoint_a = false;
255       bool adjoint_b = false;
256       TF_RETURN_IF_ERROR(c->GetAttr("adjoint_a", &adjoint_a));
257       TF_RETURN_IF_ERROR(c->GetAttr("adjoint_b", &adjoint_b));
258       if (adjoint_a && transpose_a) {
259         return errors::InvalidArgument(
260             "Only one of adjoint_a and transpose_a may be true.");
261       }
262       if (adjoint_b && transpose_b) {
263         return errors::InvalidArgument(
264             "Only one of adjoint_b and transpose_b may be true.");
265       }
266       transpose_a = transpose_a || adjoint_a;
267       transpose_b = transpose_b || adjoint_b;
268 
269       auto output_rows = c->Dim(a_shape, transpose_a ? -1 : -2);
270       auto output_cols = c->Dim(b_shape, transpose_b ? -2 : -1);
271       if (transpose_output) {
272         std::tie(output_rows, output_cols) =
273             std::make_tuple(output_cols, output_rows);
274       }
275 
276       // Batch dims match between inputs.
277       ShapeHandle a_batch_dims;
278       ShapeHandle b_batch_dims;
279       ShapeHandle batch_dims;
280       TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims));
281       TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims));
282       TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims));
283 
284       // Assert inner dims match.
285       shape_inference::DimensionHandle unused;
286       TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, transpose_a ? -2 : -1),
287                                   c->Dim(b_shape, transpose_b ? -1 : -2),
288                                   &unused));
289 
290       ShapeHandle out;
291       TF_RETURN_IF_ERROR(c->Concatenate(
292           batch_dims, c->Matrix(output_rows, output_cols), &out));
293 
294       c->set_output(0, out);
295       return OkStatus();
296     });
297 
298 REGISTER_OP("SparseMatrixMul")
299     .Input("a: variant")
300     .Input("b: T")
301     .Attr("T: type")
302     .Output("output: variant")
__anonc9a48f920802(InferenceContext* c) 303     .SetShapeFn([](InferenceContext* c) {
304       ShapeAndType sparse_matrix_shape_and_type;
305       TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
306       ShapeHandle a_shape = sparse_matrix_shape_and_type.shape;
307       TF_RETURN_IF_ERROR(c->WithRankAtMost(a_shape, 3, &a_shape));
308       if (!c->RankKnown(a_shape)) {
309         return errors::Internal("a has an unknown rank.");
310       }
311       ShapeHandle b_shape;
312       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 3, &b_shape));
313       if (!c->RankKnown(b_shape)) {
314         return errors::Internal("b has an unknown rank.");
315       }
316       ShapeHandle out;
317       if (c->Rank(b_shape) == 0) {
318         out = a_shape;
319       } else if (c->Rank(b_shape) == 3) {
320         if (c->Rank(a_shape) != 3) {
321           return errors::Unimplemented("rank of b is 3 but rank of a is not.");
322         }
323         if (!(c->Value(c->Dim(b_shape, 1)) == 1 &&
324               c->Value(c->Dim(b_shape, 2)) == 1)) {
325           return errors::Unimplemented(
326               "b must be a scalar or shaped [batch_size, 1, 1]");
327         }
328         DimensionHandle batch_size = c->Dim(a_shape, 0);
329         TF_RETURN_IF_ERROR(
330             c->Merge(batch_size, c->Dim(b_shape, 0), &batch_size));
331         TF_RETURN_IF_ERROR(c->ReplaceDim(b_shape, 0, batch_size, &b_shape));
332         TF_RETURN_IF_ERROR(c->ReplaceDim(a_shape, 0, batch_size, &a_shape));
333         out = a_shape;
334       } else {
335         return errors::Unimplemented(
336             "b must be a scalar or shaped [batch_size, 1, 1]");
337       }
338       c->set_output_handle_shapes_and_types(
339           0, {ShapeAndType{out, sparse_matrix_shape_and_type.dtype}});
340       c->set_output(0, c->Scalar());
341       return OkStatus();
342     });
343 
344 REGISTER_OP("SparseMatrixAdd")
345     .Input("a: variant")
346     .Input("b: variant")
347     .Input("alpha: T")
348     .Input("beta: T")
349     .Attr("T: {float, double, complex64, complex128}")
350     .Output("c: variant")
__anonc9a48f920902(InferenceContext* c) 351     .SetShapeFn([](InferenceContext* c) {
352       // alpha and beta are scalars.
353       ShapeHandle unused_scalar_shape;
354       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_scalar_shape));
355       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_scalar_shape));
356 
357       ShapeAndType sparse_matrix_shape_and_type;
358       TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
359       ShapeHandle a_shape = sparse_matrix_shape_and_type.shape;
360       TF_RETURN_IF_ERROR(c->WithRankAtLeast(a_shape, 2, &a_shape));
361       TF_RETURN_IF_ERROR(c->WithRankAtMost(a_shape, 3, &a_shape));
362       if (!c->RankKnown(a_shape)) {
363         return errors::InvalidArgument("a has an unknown rank.");
364       }
365 
366       TF_RETURN_IF_ERROR(GetVariantInput(c, 1, &sparse_matrix_shape_and_type));
367       ShapeHandle b_shape = sparse_matrix_shape_and_type.shape;
368       TF_RETURN_IF_ERROR(c->WithRankAtLeast(b_shape, 2, &b_shape));
369       TF_RETURN_IF_ERROR(c->WithRankAtMost(b_shape, 3, &b_shape));
370       if (!c->RankKnown(b_shape)) {
371         return errors::InvalidArgument("b has an unknown rank.");
372       }
373       ShapeHandle out;
374       TF_RETURN_IF_ERROR(c->Merge(a_shape, b_shape, &out));
375       c->set_output_handle_shapes_and_types(
376           0, {ShapeAndType{out, sparse_matrix_shape_and_type.dtype}});
377       c->set_output(0, c->Scalar());
378       return OkStatus();
379     });
380 
381 REGISTER_OP("SparseMatrixSparseMatMul")
382     .Input("a: variant")
383     .Input("b: variant")
384     .Attr("type: {float, double, complex64, complex128}")
385     .Attr("transpose_a: bool = false")
386     .Attr("transpose_b: bool = false")
387     .Attr("adjoint_a: bool = false")
388     .Attr("adjoint_b: bool = false")
389     .Output("c: variant")
__anonc9a48f920a02(InferenceContext* c) 390     .SetShapeFn([](InferenceContext* c) {
391       ShapeAndType sparse_matrix_shape_and_type;
392       TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
393       ShapeHandle a_shape = sparse_matrix_shape_and_type.shape;
394       TF_RETURN_IF_ERROR(c->WithRankAtLeast(a_shape, 2, &a_shape));
395       TF_RETURN_IF_ERROR(c->WithRankAtMost(a_shape, 3, &a_shape));
396       if (!c->RankKnown(a_shape)) {
397         return errors::Internal("a has an unknown rank.");
398       }
399 
400       TF_RETURN_IF_ERROR(GetVariantInput(c, 1, &sparse_matrix_shape_and_type));
401       ShapeHandle b_shape = sparse_matrix_shape_and_type.shape;
402       TF_RETURN_IF_ERROR(c->WithRankAtLeast(b_shape, 2, &b_shape));
403       TF_RETURN_IF_ERROR(c->WithRankAtMost(b_shape, 3, &b_shape));
404       if (!c->RankKnown(b_shape)) {
405         return errors::Internal("b has an unknown rank.");
406       }
407 
408       bool transpose_a = false;
409       bool transpose_b = false;
410       TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a));
411       TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b));
412       bool adjoint_a = false;
413       bool adjoint_b = false;
414       TF_RETURN_IF_ERROR(c->GetAttr("adjoint_a", &adjoint_a));
415       TF_RETURN_IF_ERROR(c->GetAttr("adjoint_b", &adjoint_b));
416       if (adjoint_a && transpose_a) {
417         return errors::InvalidArgument(
418             "Only one of adjoint_a and transpose_a may be true.");
419       } else if (adjoint_b && transpose_b) {
420         return errors::InvalidArgument(
421             "Only one of adjoint_b and transpose_b may be true.");
422       }
423       transpose_a = transpose_a || adjoint_a;
424       transpose_b = transpose_b || adjoint_b;
425 
426       auto output_rows = c->Dim(a_shape, transpose_a ? -1 : -2);
427       auto output_cols = c->Dim(b_shape, transpose_b ? -2 : -1);
428 
429       // Batch dims match between inputs.
430       ShapeHandle a_batch_dims;
431       ShapeHandle b_batch_dims;
432       ShapeHandle batch_dims;
433       TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims));
434       TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims));
435       TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims));
436 
437       // Assert inner dims match.
438       shape_inference::DimensionHandle unused;
439       TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, transpose_a ? -2 : -1),
440                                   c->Dim(b_shape, transpose_b ? -1 : -2),
441                                   &unused));
442 
443       ShapeHandle out;
444       TF_RETURN_IF_ERROR(c->Concatenate(
445           batch_dims, c->Matrix(output_rows, output_cols), &out));
446 
447       c->set_output_handle_shapes_and_types(
448           0, {ShapeAndType{out, sparse_matrix_shape_and_type.dtype}});
449       c->set_output(0, c->Scalar());
450       return OkStatus();
451     });
452 
453 REGISTER_OP("SparseMatrixZeros")
454     .Input("dense_shape: int64")
455     .Attr("type: {float, double, complex64, complex128}")
456     .Output("sparse_matrix: variant")
__anonc9a48f920b02(InferenceContext* c) 457     .SetShapeFn([](InferenceContext* c) {
458       auto rank = c->NumElements(c->input(0));
459       ShapeHandle dense_shape;
460       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &dense_shape));
461       TF_RETURN_IF_ERROR(
462           c->WithRank(dense_shape, c->Value(rank), &dense_shape));
463       if (!c->RankKnown(dense_shape) || c->Rank(dense_shape) < 2 ||
464           c->Rank(dense_shape) > 3) {
465         return errors::InvalidArgument(
466             "Invalid rank: ", c->Rank(dense_shape),
467             ".  Expected a known rank of either 2 or 3.");
468       }
469       DataType dtype;
470       TF_RETURN_IF_ERROR(c->GetAttr("type", &dtype));
471       c->set_output_handle_shapes_and_types(0,
472                                             {ShapeAndType{dense_shape, dtype}});
473       c->set_output(0, c->Scalar());
474       return OkStatus();
475     });
476 
477 REGISTER_OP("SparseMatrixTranspose")
478     .Input("input: variant")
479     .Attr("conjugate: bool = false")
480     .Attr("type: {float, double, complex64, complex128}")
481     .Output("output: variant")
__anonc9a48f920c02(InferenceContext* c) 482     .SetShapeFn([](InferenceContext* c) {
483       ShapeAndType sparse_matrix_shape_and_type;
484       TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
485       ShapeHandle input = sparse_matrix_shape_and_type.shape;
486       TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 2, &input));
487       TF_RETURN_IF_ERROR(c->WithRankAtMost(input, 3, &input));
488       if (!c->RankKnown(input)) {
489         return errors::InvalidArgument("input has an unknown rank.");
490       }
491       ShapeHandle output;
492       if (c->Rank(input) == 2) {
493         output = c->Matrix(c->Dim(input, 1), c->Dim(input, 0));
494       } else {
495         output = c->MakeShape(
496             {c->Dim(input, 0), c->Dim(input, 2), c->Dim(input, 1)});
497       }
498       c->set_output_handle_shapes_and_types(
499           0, {ShapeAndType{output, sparse_matrix_shape_and_type.dtype}});
500       c->set_output(0, c->Scalar());
501 
502       return OkStatus();
503     });
504 
505 REGISTER_OP("SparseMatrixSoftmax")
506     .Input("logits: variant")
507     .Attr("type: {float, double}")
508     .Output("softmax: variant")
__anonc9a48f920d02(InferenceContext* c) 509     .SetShapeFn([](InferenceContext* c) {
510       ShapeAndType sparse_matrix_shape_and_type;
511       TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
512       ShapeHandle logits = sparse_matrix_shape_and_type.shape;
513       TF_RETURN_IF_ERROR(c->WithRankAtLeast(logits, 2, &logits));
514       TF_RETURN_IF_ERROR(c->WithRankAtMost(logits, 3, &logits));
515       if (!c->RankKnown(logits)) {
516         return errors::InvalidArgument("logits has an unknown rank.");
517       }
518       c->set_output_handle_shapes_and_types(
519           0, {ShapeAndType{logits, sparse_matrix_shape_and_type.dtype}});
520       c->set_output(0, c->Scalar());
521       return OkStatus();
522     });
523 
524 REGISTER_OP("SparseMatrixSoftmaxGrad")
525     .Input("softmax: variant")
526     .Input("grad_softmax: variant")
527     .Attr("type: {float, double}")
528     .Output("gradient: variant")
__anonc9a48f920e02(InferenceContext* c) 529     .SetShapeFn([](InferenceContext* c) {
530       ShapeAndType sparse_matrix_shape_and_type;
531       TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
532       ShapeHandle softmax = sparse_matrix_shape_and_type.shape;
533       TF_RETURN_IF_ERROR(c->WithRankAtLeast(softmax, 2, &softmax));
534       TF_RETURN_IF_ERROR(c->WithRankAtMost(softmax, 3, &softmax));
535       if (!c->RankKnown(softmax)) {
536         return errors::InvalidArgument("softmax has an unknown rank.");
537       }
538       TF_RETURN_IF_ERROR(GetVariantInput(c, 1, &sparse_matrix_shape_and_type));
539       ShapeHandle grad_softmax = sparse_matrix_shape_and_type.shape;
540       TF_RETURN_IF_ERROR(c->WithRankAtLeast(grad_softmax, 2, &grad_softmax));
541       TF_RETURN_IF_ERROR(c->WithRankAtMost(grad_softmax, 3, &grad_softmax));
542       if (!c->RankKnown(grad_softmax)) {
543         return errors::InvalidArgument("grad_softmax has an unknown rank.");
544       }
545       TF_RETURN_IF_ERROR(c->Merge(softmax, grad_softmax, &softmax));
546       c->set_output_handle_shapes_and_types(
547           0, {ShapeAndType{softmax, sparse_matrix_shape_and_type.dtype}});
548       c->set_output(0, c->Scalar());
549       return OkStatus();
550     });
551 
552 REGISTER_OP("SparseMatrixOrderingAMD")
553     .Input("input: variant")
554     .Output("output: int32")
__anonc9a48f920f02(InferenceContext* c) 555     .SetShapeFn([](InferenceContext* c) {
556       ShapeAndType sparse_matrix_shape_and_type;
557       TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
558       ShapeHandle matrix_shape = sparse_matrix_shape_and_type.shape;
559       DimensionHandle n;
560       TF_RETURN_IF_ERROR(ValidateSquareMatrixShape(c, matrix_shape, &n));
561 
562       ShapeHandle output;
563       if (c->Rank(matrix_shape) == 2) {
564         output = c->Vector(c->Dim(matrix_shape, 0));
565       } else {
566         output = c->Matrix(c->Dim(matrix_shape, 0), c->Dim(matrix_shape, 1));
567       }
568       c->set_output(0, output);
569       return OkStatus();
570     });
571 
572 REGISTER_OP("SparseMatrixSparseCholesky")
573     .Input("input: variant")
574     .Input("permutation: int32")
575     .Attr("type: {float, double, complex64, complex128}")
576     .Output("output: variant")
__anonc9a48f921002(InferenceContext* c) 577     .SetShapeFn([](InferenceContext* c) {
578       ShapeAndType sparse_matrix_shape_and_type;
579       TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
580       ShapeHandle matrix_shape = sparse_matrix_shape_and_type.shape;
581       DimensionHandle n;
582       TF_RETURN_IF_ERROR(ValidateSquareMatrixShape(c, matrix_shape, &n));
583 
584       ShapeHandle perm_shape;
585       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &perm_shape));
586       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 2, &perm_shape));
587       if (!c->RankKnown(perm_shape)) {
588         return errors::Internal("permutation has an unknown rank.");
589       }
590 
591       // Each batch component of permutation must have the same number of
592       // elements as number of rows of sparse_matrix.
593       TF_RETURN_IF_ERROR(c->Merge(n, c->Dim(perm_shape, -1), &n));
594       ShapeHandle matrix_batch_shape;
595       ShapeHandle perm_batch_shape;
596 
597       // Make the common batch subshape.
598       TF_RETURN_IF_ERROR(c->Subshape(matrix_shape, 0, -2, &matrix_batch_shape));
599       TF_RETURN_IF_ERROR(c->Subshape(perm_shape, 0, -1, &perm_shape));
600       // Make sure the batch dimensions match between sparse_matrix and
601       // permutation.
602       TF_RETURN_IF_ERROR(
603           c->Merge(matrix_batch_shape, perm_batch_shape, &matrix_batch_shape));
604 
605       ShapeHandle out = matrix_shape;
606       c->set_output_handle_shapes_and_types(
607           0, {ShapeAndType{out, sparse_matrix_shape_and_type.dtype}});
608       c->set_output(0, c->Scalar());
609 
610       return OkStatus();
611     });
612 
613 }  // namespace tensorflow
614