• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 
20 namespace tensorflow {
21 
22 using shape_inference::DimensionHandle;
23 using shape_inference::InferenceContext;
24 using shape_inference::ShapeHandle;
25 
26 namespace {
27 
28 // Return in <out> the result of making the end of <s> a square matrix.
MakeBatchSquareMatrix(InferenceContext * c,ShapeHandle input,ShapeHandle * out)29 Status MakeBatchSquareMatrix(InferenceContext* c, ShapeHandle input,
30                              ShapeHandle* out) {
31   ShapeHandle s;
32   TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 2, &s));
33 
34   DimensionHandle d;
35   TF_RETURN_IF_ERROR(c->Merge(c->Dim(s, -2), c->Dim(s, -1), &d));
36 
37   ShapeHandle batch_shape;
38   TF_RETURN_IF_ERROR(c->Subshape(s, 0, -2, &batch_shape));
39   TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(d, d), out));
40   return OkStatus();
41 }
42 
BatchUnchangedSquareShapeFn(InferenceContext * c)43 Status BatchUnchangedSquareShapeFn(InferenceContext* c) {
44   ShapeHandle out;
45   TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &out));
46   c->set_output(0, out);
47   return OkStatus();
48 }
49 
50 // The first input is [...,K,M] and second input is [...,M,N].
BandedTriangularSolveShapeFn(InferenceContext * c)51 Status BandedTriangularSolveShapeFn(InferenceContext* c) {
52   ShapeHandle lhs;
53   ShapeHandle rhs;
54 
55   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs));
56   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));
57 
58   // Check K > 0.
59   DimensionHandle num_bands = c->Dim(lhs, -2);
60   DimensionHandle m = c->Dim(lhs, -1);
61   if (c->ValueKnown(num_bands) && c->Value(num_bands) <= 0) {
62     return errors::InvalidArgument("Number of bands must be positive, but is ",
63                                    c->Value(num_bands));
64   }
65   if (c->ValueKnown(num_bands) && c->ValueKnown(m) &&
66       c->Value(num_bands) > c->Value(m)) {
67     return errors::InvalidArgument("Number of bands ", c->Value(num_bands),
68                                    " cannot exceed the size of the matrix ",
69                                    c->Value(m));
70   }
71 
72   ShapeHandle lhs_batch_shape;
73   ShapeHandle rhs_batch_shape;
74   ShapeHandle output_batch_shape;
75   // Make the common batch subshape.
76   TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
77   TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
78   TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
79       c, lhs_batch_shape, rhs_batch_shape, true, &output_batch_shape));
80 
81   // lhs and rhs have the same value for M to be compatible.
82   TF_RETURN_IF_ERROR(c->Merge(m, c->Dim(rhs, -2), &m));
83 
84   // Build final shape (batch_shape + m + n) in <out>.
85   ShapeHandle out;
86   TF_RETURN_IF_ERROR(
87       c->Concatenate(output_batch_shape, c->Matrix(m, c->Dim(rhs, -1)), &out));
88 
89   c->set_output(0, out);
90   return OkStatus();
91 }
92 
93 // The first input is [...,M,N] and second input is either [...,M,K] or [...,M].
94 // Output is [...,N,K] or [...,N]. If <square>, then input is [...,M,M].
MatrixSolveShapeFn(InferenceContext * c,bool square)95 Status MatrixSolveShapeFn(InferenceContext* c, bool square) {
96   ShapeHandle lhs;
97   ShapeHandle rhs;
98   if (square) {
99     TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &lhs));
100   } else {
101     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs));
102   }
103   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));
104 
105   ShapeHandle lhs_batch_shape;
106   ShapeHandle rhs_batch_shape;
107   // Make the common batch subshape.
108   TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
109   TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
110   // Make sure the batch dimensions match between lhs and rhs.
111   TF_RETURN_IF_ERROR(
112       c->Merge(lhs_batch_shape, rhs_batch_shape, &lhs_batch_shape));
113 
114   DimensionHandle m;
115   // lhs and rhs have the same value for m to be compatible.
116   TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, -2), c->Dim(rhs, -2), &m));
117   DimensionHandle n = c->Dim(lhs, -1);
118   if (square) {
119     TF_RETURN_IF_ERROR(c->Merge(m, n, &n));
120   }
121 
122   ShapeHandle out;
123   // Build final shape (batch_shape + n + k) in <out>.
124   TF_RETURN_IF_ERROR(c->Concatenate(lhs_batch_shape, c->Vector(n), &out));
125   TF_RETURN_IF_ERROR(c->Concatenate(out, c->Vector(c->Dim(rhs, -1)), &out));
126   c->set_output(0, out);
127   return OkStatus();
128 }
129 
130 // The first input is [...,M,M] and second input is [...,M,N].
131 // Output is [...,M,N].
MatrixTriangularSolveShapeFn(InferenceContext * c)132 Status MatrixTriangularSolveShapeFn(InferenceContext* c) {
133   ShapeHandle lhs;
134   ShapeHandle rhs;
135   TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &lhs));
136   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));
137 
138   ShapeHandle lhs_batch_shape;
139   ShapeHandle rhs_batch_shape;
140   ShapeHandle output_batch_shape;
141   // Make the common batch subshape.
142   TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
143   TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
144   TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
145       c, lhs_batch_shape, rhs_batch_shape, true, &output_batch_shape));
146   DimensionHandle m;
147   // lhs and rhs have the same value for m to be compatible.
148   TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, -1), c->Dim(rhs, -2), &m));
149 
150   ShapeHandle out;
151   // Build final shape (batch_shape + m + n) in <out>.
152   TF_RETURN_IF_ERROR(
153       c->Concatenate(output_batch_shape, c->Matrix(m, c->Dim(rhs, -1)), &out));
154   c->set_output(0, out);
155   return OkStatus();
156 }
157 
158 // Input is [...,N,N]. Outputs are:
159 //   [...,N];[0], if compute_v is false,
160 //   [...,N];[...,N,N], if compute_v is true.
SelfAdjointEigV2ShapeFn(InferenceContext * c)161 Status SelfAdjointEigV2ShapeFn(InferenceContext* c) {
162   ShapeHandle input;
163   TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &input));
164   DimensionHandle n;
165   TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -2), c->Dim(input, -1), &n));
166   ShapeHandle batch_shape;
167   TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
168   ShapeHandle e_shape;
169   TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &e_shape));
170   c->set_output(0, e_shape);
171   bool compute_v;
172   TF_RETURN_IF_ERROR(c->GetAttr("compute_v", &compute_v));
173   if (compute_v) {
174     ShapeHandle v_shape;
175     TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(n, n), &v_shape));
176     c->set_output(1, v_shape);
177   } else {
178     c->set_output(1, c->Vector(0ll));
179   }
180   return OkStatus();
181 }
182 
183 // Input is [...,N,N].
184 // First and second outputs are:
185 //   [...,N,N]; [...,N].
LuShapeFn(InferenceContext * c)186 Status LuShapeFn(InferenceContext* c) {
187   ShapeHandle input;
188   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
189 
190   DimensionHandle n;
191   TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -2), c->Dim(input, -1), &n));
192 
193   ShapeHandle batch_shape;
194   TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
195 
196   ShapeHandle lu_shape;
197   ShapeHandle p_shape;
198 
199   TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(n, n), &lu_shape));
200   TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &p_shape));
201 
202   c->set_output(0, lu_shape);
203   c->set_output(1, p_shape);
204   return OkStatus();
205 }
206 
207 // Input is [...,M,N].
208 // First and second outputs are:
209 //   [...,M,M]; [...,M,N], if full_matrices is true,
210 //   [...,M,P]; [...,P,N], if full_matrices is false,
211 // where P = min(M,N).
QrShapeFn(InferenceContext * c)212 Status QrShapeFn(InferenceContext* c) {
213   ShapeHandle input;
214   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
215   DimensionHandle m = c->Dim(input, -2);
216   DimensionHandle n = c->Dim(input, -1);
217   DimensionHandle p;
218   TF_RETURN_IF_ERROR(c->Min(m, n, &p));
219   ShapeHandle batch_shape;
220   TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
221   ShapeHandle q_shape;
222   ShapeHandle r_shape;
223   bool full_matrices;
224   TF_RETURN_IF_ERROR(c->GetAttr("full_matrices", &full_matrices));
225   if (full_matrices) {
226     TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, m), &q_shape));
227     TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, n), &r_shape));
228   } else {
229     TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, p), &q_shape));
230     TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(p, n), &r_shape));
231   }
232   c->set_output(0, q_shape);
233   c->set_output(1, r_shape);
234   return OkStatus();
235 }
236 
237 // Input is [...,M,N].  First output is [...,min(M,N)].
238 // Second and third outputs are:
239 //   [0]; [0], if compute_uv is false.
240 //   [...,M,M]; [...,N,N], if compute_uv is true and full_matrices is true,
241 //   [...,M,P]; [...,N,P], if compute_uv is true and full_matrices is false,
242 // where P = min(M,N).
SvdShapeFn(InferenceContext * c)243 Status SvdShapeFn(InferenceContext* c) {
244   ShapeHandle input;
245   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
246   DimensionHandle m = c->Dim(input, -2);
247   DimensionHandle n = c->Dim(input, -1);
248   DimensionHandle p;
249   TF_RETURN_IF_ERROR(c->Min(m, n, &p));
250   ShapeHandle batch_shape;
251   TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
252   ShapeHandle e_shape;
253   TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(p), &e_shape));
254   c->set_output(0, e_shape);
255   bool compute_uv;
256   TF_RETURN_IF_ERROR(c->GetAttr("compute_uv", &compute_uv));
257   if (compute_uv) {
258     ShapeHandle u_shape;
259     ShapeHandle v_shape;
260     bool full_matrices;
261     TF_RETURN_IF_ERROR(c->GetAttr("full_matrices", &full_matrices));
262     if (full_matrices) {
263       TF_RETURN_IF_ERROR(
264           c->Concatenate(batch_shape, c->Matrix(m, m), &u_shape));
265       TF_RETURN_IF_ERROR(
266           c->Concatenate(batch_shape, c->Matrix(n, n), &v_shape));
267     } else {
268       TF_RETURN_IF_ERROR(
269           c->Concatenate(batch_shape, c->Matrix(m, p), &u_shape));
270       TF_RETURN_IF_ERROR(
271           c->Concatenate(batch_shape, c->Matrix(n, p), &v_shape));
272     }
273     c->set_output(1, u_shape);
274     c->set_output(2, v_shape);
275   } else {
276     c->set_output(1, c->Vector(0ll));
277     c->set_output(2, c->Vector(0ll));
278   }
279   return OkStatus();
280 }
281 
282 // Inputs: [...,1,M], [...,1,M], [...,1,M],[...,M,N].
283 // Output is [...,M,N].
TridiagonalMatMulShapeFn(InferenceContext * c)284 Status TridiagonalMatMulShapeFn(InferenceContext* c) {
285   ShapeHandle superdiag;
286   ShapeHandle maindiag;
287   ShapeHandle subdiag;
288   ShapeHandle rhs;
289 
290   // Check that rank is at least 2.
291   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &superdiag));
292   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &maindiag));
293   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 2, &subdiag));
294   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 2, &rhs));
295 
296   // Extract batch dimensions and check they are the same.
297   ShapeHandle superdiag_batch_shape;
298   ShapeHandle maindiag_batch_shape;
299   ShapeHandle subdiag_batch_shape;
300   ShapeHandle rhs_batch_shape;
301   TF_RETURN_IF_ERROR(c->Subshape(superdiag, 0, -2, &superdiag_batch_shape));
302   TF_RETURN_IF_ERROR(c->Subshape(maindiag, 0, -2, &maindiag_batch_shape));
303   TF_RETURN_IF_ERROR(c->Subshape(subdiag, 0, -2, &subdiag_batch_shape));
304   TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
305   TF_RETURN_IF_ERROR(c->Merge(superdiag, maindiag, &superdiag));
306   TF_RETURN_IF_ERROR(
307       c->Merge(maindiag_batch_shape, rhs_batch_shape, &rhs_batch_shape));
308   TF_RETURN_IF_ERROR(
309       c->Merge(subdiag_batch_shape, rhs_batch_shape, &rhs_batch_shape));
310 
311   // Check that diagonals have the same shape.
312   TF_RETURN_IF_ERROR(c->Merge(superdiag, maindiag, &maindiag));
313   TF_RETURN_IF_ERROR(c->Merge(subdiag, maindiag, &maindiag));
314 
315   // Check that size of tri-diagonal matrix is the same as height of matrix on
316   // the right.
317   DimensionHandle m_lhs = c->Dim(maindiag, -1);
318   DimensionHandle m_rhs = c->Dim(rhs, -2);
319   TF_RETURN_IF_ERROR(c->Merge(m_lhs, m_rhs, &m_lhs));
320 
321   // Check that next-to-last dimension of diagonals is 1.
322   DimensionHandle unused;
323   TF_RETURN_IF_ERROR(c->WithValue(c->Dim(maindiag, -2), 1, &unused));
324 
325   // The output shape is the same as rhs shape.
326   c->set_output(0, rhs);
327   return OkStatus();
328 }
329 
330 // The first input is [...,3,M] and second input is [...,M,K].
331 // Output is [...,M,K].
TridiagonalSolveShapeFn(InferenceContext * c)332 Status TridiagonalSolveShapeFn(InferenceContext* c) {
333   ShapeHandle lhs;
334   ShapeHandle rhs;
335   // Check that rank is at least 2.
336   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs));
337   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));
338 
339   // Extract batch dimensions and check they are the same.
340   ShapeHandle lhs_batch_shape;
341   ShapeHandle rhs_batch_shape;
342   TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
343   TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
344   TF_RETURN_IF_ERROR(
345       c->Merge(lhs_batch_shape, rhs_batch_shape, &lhs_batch_shape));
346 
347   // Check that "M" is the same in both inputs.
348   DimensionHandle m_lhs = c->Dim(lhs, -1);
349   DimensionHandle m_rhs = c->Dim(rhs, -2);
350   TF_RETURN_IF_ERROR(c->Merge(m_lhs, m_rhs, &m_lhs));
351 
352   // Check that next-to-last dimension of the first input is 3.
353   TF_RETURN_IF_ERROR(c->WithValue(c->Dim(lhs, -2), 3, &m_lhs));
354 
355   // The output shape is the same as rhs shape.
356   c->set_output(0, rhs);
357   return OkStatus();
358 }
359 
360 }  // namespace
361 
362 REGISTER_OP("MatrixDeterminant")
363     .Input("input: T")
364     .Output("output: T")
365     .Attr("T: {half, float, double, complex64, complex128}")
__anon552764600202(InferenceContext* c) 366     .SetShapeFn([](InferenceContext* c) {
367       ShapeHandle input;
368       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
369 
370       DimensionHandle unused;
371       TF_RETURN_IF_ERROR(
372           c->Merge(c->Dim(input, -1), c->Dim(input, -2), &unused));
373 
374       ShapeHandle out;
375       TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &out));
376       c->set_output(0, out);
377       return OkStatus();
378     });
379 
380 REGISTER_OP("LogMatrixDeterminant")
381     .Input("input: T")
382     .Output("sign: T")
383     .Output("log_abs_determinant: T")
384     .Attr("T: {half, float, double, complex64, complex128}")
__anon552764600302(InferenceContext* c) 385     .SetShapeFn([](InferenceContext* c) {
386       ShapeHandle input;
387       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
388 
389       DimensionHandle unused;
390       TF_RETURN_IF_ERROR(
391           c->Merge(c->Dim(input, -1), c->Dim(input, -2), &unused));
392 
393       ShapeHandle s;
394       TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &s));
395       c->set_output(0, s);
396 
397       ShapeHandle out;
398       TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &out));
399       c->set_output(1, out);
400       return OkStatus();
401     });
402 
// The ops below use BatchUnchangedSquareShapeFn. Input 0 must have rank >= 2
// with compatible innermost two dimensions, and output 0 is that shape with
// those two dimensions merged ([..., M, M]).
// Builder call order (inputs, outputs, attrs) defines the registered OpDef.
REGISTER_OP("MatrixInverse")
    .Input("input: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

// Marked deprecated (version 27); the message points users at the Python
// implementation tf.linalg.matrix_exponential.
REGISTER_OP("MatrixExponential")
    .Deprecated(
        27, "Use Python implementation tf.linalg.matrix_exponential instead.")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

// Registered for complex element types only.
REGISTER_OP("MatrixLogarithm")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("Cholesky")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

// NOTE(review): BatchUnchangedSquareShapeFn reads only input 0 (`l`); the
// shape of `grad` is not validated by this shape function.
REGISTER_OP("CholeskyGrad")
    .Input("l: T")
    .Input("grad: T")
    .Output("output: T")
    .Attr("T: {half, float, double}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);
436 
437 REGISTER_OP("SelfAdjointEig")
438     .Input("input: T")
439     .Output("output: T")
440     .Attr("T: {double, float, half}")
441     .Deprecated(11, "Use SelfAdjointEigV2 instead.")
__anon552764600402(InferenceContext* c) 442     .SetShapeFn([](InferenceContext* c) {
443       ShapeHandle input;
444       TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &input));
445 
446       DimensionHandle d = c->Dim(input, -1);
447       DimensionHandle d_plus_1;
448       TF_RETURN_IF_ERROR(c->Add(d, 1, &d_plus_1));
449 
450       ShapeHandle s;
451       TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &s));
452       TF_RETURN_IF_ERROR(c->Concatenate(s, c->Matrix(d_plus_1, d), &s));
453       c->set_output(0, s);
454       return OkStatus();
455     });
456 
// Eig shares SelfAdjointEigV2's shape function. Its outputs use a separate
// type attr `Tout` (complex64 or complex128), independent of the input type T.
REGISTER_OP("Eig")
    .Input("input: T")
    .Output("e: Tout")
    .Output("v: Tout")
    .Attr("compute_v: bool = True")
    .Attr("T: {float, double, complex64, complex128}")
    .Attr("Tout: {complex64, complex128}")
    .SetShapeFn(SelfAdjointEigV2ShapeFn);

// Input [..., N, N]. Output `e` is [..., N]. Output `v` is [..., N, N] when
// `compute_v` is true, otherwise [0].
REGISTER_OP("SelfAdjointEigV2")
    .Input("input: T")
    .Output("e: T")
    .Output("v: T")
    .Attr("compute_v: bool = True")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(SelfAdjointEigV2ShapeFn);

// Input [..., N, N]. Output `lu` is [..., N, N]. Output `p` is [..., N], with
// integer type `output_idx_type` (int32 unless overridden).
REGISTER_OP("Lu")
    .Input("input: T")
    .Output("lu: T")
    .Output("p: output_idx_type")
    .Attr("T: {double, float, half, complex64, complex128}")
    .Attr("output_idx_type: {int32, int64} = DT_INT32")
    .SetShapeFn(LuShapeFn);
481 
482 REGISTER_OP("MatrixSolve")
483     .Input("matrix: T")
484     .Input("rhs: T")
485     .Output("output: T")
486     .Attr("adjoint: bool = False")
487     .Attr("T: {double, float, half, complex64, complex128}")
__anon552764600502(InferenceContext* c) 488     .SetShapeFn([](InferenceContext* c) {
489       return MatrixSolveShapeFn(c, true /* square (*/);
490     });
491 
492 REGISTER_OP("BandedTriangularSolve")
493     .Input("matrix: T")
494     .Input("rhs: T")
495     .Output("output: T")
496     .Attr("lower: bool = True")
497     .Attr("adjoint: bool = False")
498     .Attr("T: {double, float, half, complex64, complex128}")
__anon552764600602(InferenceContext* c) 499     .SetShapeFn([](InferenceContext* c) {
500       return BandedTriangularSolveShapeFn(c);
501     });
502 
503 REGISTER_OP("MatrixTriangularSolve")
504     .Input("matrix: T")
505     .Input("rhs: T")
506     .Output("output: T")
507     .Attr("lower: bool = True")
508     .Attr("adjoint: bool = False")
509     .Attr("T: {bfloat16, double, float, half, complex64, complex128}")
__anon552764600702(InferenceContext* c) 510     .SetShapeFn([](InferenceContext* c) {
511       return MatrixTriangularSolveShapeFn(c);
512     });
513 
514 REGISTER_OP("MatrixSolveLs")
515     .Input("matrix: T")
516     .Input("rhs: T")
517     .Input("l2_regularizer: double")
518     .Output("output: T")
519     .Attr("T: {double, float, half, complex64, complex128}")
520     .Attr("fast: bool = True")
__anon552764600802(InferenceContext* c) 521     .SetShapeFn([](InferenceContext* c) {
522       ShapeHandle l2_regularizer;
523       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &l2_regularizer));
524       return MatrixSolveShapeFn(c, false /* square */);
525     });
526 
// Input [..., M, M] -> output of the same (square-merged) shape.
REGISTER_OP("MatrixSquareRoot")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

// Output shapes depend on `full_matrices` (reduced by default); see QrShapeFn.
REGISTER_OP("Qr")
    .Input("input: T")
    .Output("q: T")
    .Output("r: T")
    .Attr("full_matrices: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(QrShapeFn);

// Output shapes depend on `compute_uv` and `full_matrices`; see SvdShapeFn.
REGISTER_OP("Svd")
    .Input("input: T")
    .Output("s: T")
    .Output("u: T")
    .Output("v: T")
    .Attr("compute_uv: bool = True")
    .Attr("full_matrices: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(SvdShapeFn);

// Three diagonals [..., 1, M] each, plus rhs [..., M, N]; output takes rhs's
// shape.
REGISTER_OP("TridiagonalMatMul")
    .Input("superdiag: T")
    .Input("maindiag: T")
    .Input("subdiag: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("T: {double, float, complex64, complex128}")
    .SetShapeFn(TridiagonalMatMulShapeFn);

// Diagonals packed as [..., 3, M], plus rhs [..., M, K]; output takes rhs's
// shape. The `partial_pivoting` and `perturb_singular` attrs are not read by
// the shape function.
REGISTER_OP("TridiagonalSolve")
    .Input("diagonals: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("partial_pivoting: bool = True")
    .Attr("perturb_singular: bool = False")
    .Attr("T: {double, float, complex64, complex128}")
    .SetShapeFn(TridiagonalSolveShapeFn);

// Variadic (N >= 1) inputs of any type T. Shape inference is delegated to
// shape_inference::EinsumShape, driven by the `equation` attr.
REGISTER_OP("Einsum")
    .Input("inputs: N * T")
    .Output("output: T")
    .Attr("equation: string")
    .Attr("N: int >= 1")
    .Attr("T: type")
    .SetShapeFn(shape_inference::EinsumShape);
576 
// Deprecated op registrations:
//
// Legacy Batch* op names. Each one is marked Deprecated with a message naming
// its replacement op, and uses shape_inference::UnknownShape rather than the
// linalg shape functions defined above.

// Can be deleted after 3feb2017.
REGISTER_OP("BatchSelfAdjointEig")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float}")
    .Deprecated(11, "Use SelfAdjointEigV2 instead.")
    .SetShapeFn(shape_inference::UnknownShape);

// Can all be deleted after 9mar2017.
REGISTER_OP("BatchMatrixDeterminant")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {float, double, complex64, complex128}")
    .Deprecated(13, "Use MatrixDeterminant instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixInverse")
    .Input("input: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use MatrixInverse instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchCholesky")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use Cholesky instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchCholeskyGrad")
    .Input("l: T")
    .Input("grad: T")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Deprecated(13, "Use CholeskyGrad instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchSelfAdjointEigV2")
    .Input("input: T")
    .Output("e: T")
    .Output("v: T")
    .Attr("compute_v: bool = True")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use SelfAdjointEigV2 instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use MatrixSolve instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixTriangularSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("lower: bool = True")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use MatrixTriangularSolve instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixSolveLs")
    .Input("matrix: T")
    .Input("rhs: T")
    .Input("l2_regularizer: double")
    .Output("output: T")
    .Attr("T: {double, float}")
    .Attr("fast: bool = True")
    .Deprecated(13, "Use MatrixSolveLs instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchSvd")
    .Input("input: T")
    .Output("s: T")
    .Output("u: T")
    .Output("v: T")
    .Attr("compute_uv: bool = True")
    .Attr("full_matrices: bool = False")
    .Attr("T: {double, float, complex64, complex128}")
    .Deprecated(13, "Use Svd instead.")
    .SetShapeFn(shape_inference::UnknownShape);
666 
667 }  // namespace tensorflow
668