• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 
20 namespace tensorflow {
21 
22 using shape_inference::DimensionHandle;
23 using shape_inference::InferenceContext;
24 using shape_inference::ShapeHandle;
25 
26 namespace {
27 
28 // Return in <out> the result of making the end of <s> a square matrix.
MakeBatchSquareMatrix(InferenceContext * c,ShapeHandle input,ShapeHandle * out)29 Status MakeBatchSquareMatrix(InferenceContext* c, ShapeHandle input,
30                              ShapeHandle* out) {
31   ShapeHandle s;
32   TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 2, &s));
33 
34   DimensionHandle d;
35   TF_RETURN_IF_ERROR(c->Merge(c->Dim(s, -2), c->Dim(s, -1), &d));
36 
37   ShapeHandle batch_shape;
38   TF_RETURN_IF_ERROR(c->Subshape(s, 0, -2, &batch_shape));
39   TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(d, d), out));
40   return Status::OK();
41 }
42 
BatchUnchangedSquareShapeFn(InferenceContext * c)43 Status BatchUnchangedSquareShapeFn(InferenceContext* c) {
44   ShapeHandle out;
45   TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &out));
46   c->set_output(0, out);
47   return Status::OK();
48 }
49 
50 // The first input is [...,M,N] and second input is either [...,M,K] or [...,M].
51 // Output is [...,N,K] or [...,N]. If <square>, then input is [...,M,M].
MatrixSolveShapeFn(InferenceContext * c,bool square)52 Status MatrixSolveShapeFn(InferenceContext* c, bool square) {
53   ShapeHandle lhs;
54   ShapeHandle rhs;
55   if (square) {
56     TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &lhs));
57   } else {
58     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs));
59   }
60   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));
61 
62   ShapeHandle lhs_batch_shape;
63   ShapeHandle rhs_batch_shape;
64   // Make the common batch subshape.
65   TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
66   TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
67   // Make sure the batch dimensions match between lhs and rhs.
68   TF_RETURN_IF_ERROR(
69       c->Merge(lhs_batch_shape, rhs_batch_shape, &lhs_batch_shape));
70 
71   DimensionHandle m;
72   // lhs and rhs have the same value for m to be compatible.
73   TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, -2), c->Dim(rhs, -2), &m));
74   DimensionHandle n = c->Dim(lhs, -1);
75   if (square) {
76     TF_RETURN_IF_ERROR(c->Merge(m, n, &n));
77   }
78 
79   ShapeHandle out;
80   // Build final shape (batch_shape + n + k) in <out>.
81   TF_RETURN_IF_ERROR(c->Concatenate(lhs_batch_shape, c->Vector(n), &out));
82   TF_RETURN_IF_ERROR(c->Concatenate(out, c->Vector(c->Dim(rhs, -1)), &out));
83   c->set_output(0, out);
84   return Status::OK();
85 }
86 
87 // The first input is [...,M,M] and second input is [...,M,N].
88 // Output is [...,M,N].
MatrixTriangularSolveShapeFn(InferenceContext * c)89 Status MatrixTriangularSolveShapeFn(InferenceContext* c) {
90   ShapeHandle lhs;
91   ShapeHandle rhs;
92   TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &lhs));
93   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));
94 
95   ShapeHandle lhs_batch_shape;
96   ShapeHandle rhs_batch_shape;
97   ShapeHandle output_batch_shape;
98   // Make the common batch subshape.
99   TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
100   TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
101   TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
102       c, lhs_batch_shape, rhs_batch_shape, true, &output_batch_shape));
103   DimensionHandle m;
104   // lhs and rhs have the same value for m to be compatible.
105   TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, -1), c->Dim(rhs, -2), &m));
106 
107   ShapeHandle out;
108   // Build final shape (batch_shape + m + n) in <out>.
109   TF_RETURN_IF_ERROR(
110       c->Concatenate(output_batch_shape, c->Matrix(m, c->Dim(rhs, -1)), &out));
111   c->set_output(0, out);
112   return Status::OK();
113 }
114 
115 // Input is [...,N,N]. Outputs are:
116 //   [...,N];[0], if compute_v is false,
117 //   [...,N];[...,N,N], if compute_v is true.
SelfAdjointEigV2ShapeFn(InferenceContext * c)118 Status SelfAdjointEigV2ShapeFn(InferenceContext* c) {
119   ShapeHandle input;
120   TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &input));
121   DimensionHandle n;
122   TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -2), c->Dim(input, -1), &n));
123   ShapeHandle batch_shape;
124   TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
125   ShapeHandle e_shape;
126   TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &e_shape));
127   c->set_output(0, e_shape);
128   bool compute_v;
129   TF_RETURN_IF_ERROR(c->GetAttr("compute_v", &compute_v));
130   if (compute_v) {
131     ShapeHandle v_shape;
132     TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(n, n), &v_shape));
133     c->set_output(1, v_shape);
134   } else {
135     c->set_output(1, c->Vector(0ll));
136   }
137   return Status::OK();
138 }
139 
140 // Input is [...,N,N].
141 // First and second outputs are:
142 //   [...,N,N]; [...,N].
LuShapeFn(InferenceContext * c)143 Status LuShapeFn(InferenceContext* c) {
144   ShapeHandle input;
145   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
146 
147   DimensionHandle n;
148   TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -2), c->Dim(input, -1), &n));
149 
150   ShapeHandle batch_shape;
151   TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
152 
153   ShapeHandle lu_shape;
154   ShapeHandle p_shape;
155 
156   TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(n, n), &lu_shape));
157   TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &p_shape));
158 
159   c->set_output(0, lu_shape);
160   c->set_output(1, p_shape);
161   return Status::OK();
162 }
163 
164 // Input is [...,M,N].
165 // First and second outputs are:
166 //   [...,M,M]; [...,M,N], if full_matrices is true,
167 //   [...,M,P]; [...,P,N], if full_matrices is false,
168 // where P = min(M,N).
QrShapeFn(InferenceContext * c)169 Status QrShapeFn(InferenceContext* c) {
170   ShapeHandle input;
171   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
172   DimensionHandle m = c->Dim(input, -2);
173   DimensionHandle n = c->Dim(input, -1);
174   DimensionHandle p;
175   TF_RETURN_IF_ERROR(c->Min(m, n, &p));
176   ShapeHandle batch_shape;
177   TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
178   ShapeHandle q_shape;
179   ShapeHandle r_shape;
180   bool full_matrices;
181   TF_RETURN_IF_ERROR(c->GetAttr("full_matrices", &full_matrices));
182   if (full_matrices) {
183     TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, m), &q_shape));
184     TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, n), &r_shape));
185   } else {
186     TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, p), &q_shape));
187     TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(p, n), &r_shape));
188   }
189   c->set_output(0, q_shape);
190   c->set_output(1, r_shape);
191   return Status::OK();
192 }
193 
194 // Input is [...,M,N].  First output is [...,min(M,N)].
195 // Second and third outputs are:
196 //   [0]; [0], if compute_uv is false.
197 //   [...,M,M]; [...,N,N], if compute_uv is true and full_matrices is true,
198 //   [...,M,P]; [...,N,P], if compute_uv is true and full_matrices is false,
199 // where P = min(M,N).
SvdShapeFn(InferenceContext * c)200 Status SvdShapeFn(InferenceContext* c) {
201   ShapeHandle input;
202   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
203   DimensionHandle m = c->Dim(input, -2);
204   DimensionHandle n = c->Dim(input, -1);
205   DimensionHandle p;
206   TF_RETURN_IF_ERROR(c->Min(m, n, &p));
207   ShapeHandle batch_shape;
208   TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
209   ShapeHandle e_shape;
210   TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(p), &e_shape));
211   c->set_output(0, e_shape);
212   bool compute_uv;
213   TF_RETURN_IF_ERROR(c->GetAttr("compute_uv", &compute_uv));
214   if (compute_uv) {
215     ShapeHandle u_shape;
216     ShapeHandle v_shape;
217     bool full_matrices;
218     TF_RETURN_IF_ERROR(c->GetAttr("full_matrices", &full_matrices));
219     if (full_matrices) {
220       TF_RETURN_IF_ERROR(
221           c->Concatenate(batch_shape, c->Matrix(m, m), &u_shape));
222       TF_RETURN_IF_ERROR(
223           c->Concatenate(batch_shape, c->Matrix(n, n), &v_shape));
224     } else {
225       TF_RETURN_IF_ERROR(
226           c->Concatenate(batch_shape, c->Matrix(m, p), &u_shape));
227       TF_RETURN_IF_ERROR(
228           c->Concatenate(batch_shape, c->Matrix(n, p), &v_shape));
229     }
230     c->set_output(1, u_shape);
231     c->set_output(2, v_shape);
232   } else {
233     c->set_output(1, c->Vector(0ll));
234     c->set_output(2, c->Vector(0ll));
235   }
236   return Status::OK();
237 }
238 
239 // Inputs: [...,1,M], [...,1,M], [...,1,M],[...,M,N].
240 // Output is [...,M,N].
TridiagonalMatMulShapeFn(InferenceContext * c)241 Status TridiagonalMatMulShapeFn(InferenceContext* c) {
242   ShapeHandle superdiag;
243   ShapeHandle maindiag;
244   ShapeHandle subdiag;
245   ShapeHandle rhs;
246 
247   // Check that rank is at least 2.
248   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &superdiag));
249   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &maindiag));
250   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 2, &subdiag));
251   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 2, &rhs));
252 
253   // Extract batch dimensions and check they are the same.
254   ShapeHandle superdiag_batch_shape;
255   ShapeHandle maindiag_batch_shape;
256   ShapeHandle subdiag_batch_shape;
257   ShapeHandle rhs_batch_shape;
258   TF_RETURN_IF_ERROR(c->Subshape(superdiag, 0, -2, &superdiag_batch_shape));
259   TF_RETURN_IF_ERROR(c->Subshape(maindiag, 0, -2, &maindiag_batch_shape));
260   TF_RETURN_IF_ERROR(c->Subshape(subdiag, 0, -2, &subdiag_batch_shape));
261   TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
262   TF_RETURN_IF_ERROR(c->Merge(superdiag, maindiag, &superdiag));
263   TF_RETURN_IF_ERROR(
264       c->Merge(maindiag_batch_shape, rhs_batch_shape, &rhs_batch_shape));
265   TF_RETURN_IF_ERROR(
266       c->Merge(subdiag_batch_shape, rhs_batch_shape, &rhs_batch_shape));
267 
268   // Check that diagonals have the same shape.
269   TF_RETURN_IF_ERROR(c->Merge(superdiag, maindiag, &maindiag));
270   TF_RETURN_IF_ERROR(c->Merge(subdiag, maindiag, &maindiag));
271 
272   // Check that size of tri-diagonal matrix is the same as height of matrix on
273   // the right.
274   DimensionHandle m_lhs = c->Dim(maindiag, -1);
275   DimensionHandle m_rhs = c->Dim(rhs, -2);
276   TF_RETURN_IF_ERROR(c->Merge(m_lhs, m_rhs, &m_lhs));
277 
278   // Check that next-to-last dimension of diagonals is 1.
279   DimensionHandle unused;
280   TF_RETURN_IF_ERROR(c->WithValue(c->Dim(maindiag, -2), 1, &unused));
281 
282   // The output shape is the same as rhs shape.
283   c->set_output(0, rhs);
284   return Status::OK();
285 }
286 
287 // The first input is [...,3,M] and second input is [...,M,K].
288 // Output is [...,M,K].
TridiagonalSolveShapeFn(InferenceContext * c)289 Status TridiagonalSolveShapeFn(InferenceContext* c) {
290   ShapeHandle lhs;
291   ShapeHandle rhs;
292   // Check that rank is at least 2.
293   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs));
294   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));
295 
296   // Extract batch dimensions and check they are the same.
297   ShapeHandle lhs_batch_shape;
298   ShapeHandle rhs_batch_shape;
299   TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
300   TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
301   TF_RETURN_IF_ERROR(
302       c->Merge(lhs_batch_shape, rhs_batch_shape, &lhs_batch_shape));
303 
304   // Check that "M" is the same in both inputs.
305   DimensionHandle m_lhs = c->Dim(lhs, -1);
306   DimensionHandle m_rhs = c->Dim(rhs, -2);
307   TF_RETURN_IF_ERROR(c->Merge(m_lhs, m_rhs, &m_lhs));
308 
309   // Check that next-to-last dimension of the first input is 3.
310   TF_RETURN_IF_ERROR(c->WithValue(c->Dim(lhs, -2), 3, &m_lhs));
311 
312   // The output shape is the same as rhs shape.
313   c->set_output(0, rhs);
314   return Status::OK();
315 }
316 
317 }  // namespace
318 
319 REGISTER_OP("MatrixDeterminant")
320     .Input("input: T")
321     .Output("output: T")
322     .Attr("T: {half, float, double, complex64, complex128}")
__anon016102550202(InferenceContext* c) 323     .SetShapeFn([](InferenceContext* c) {
324       ShapeHandle input;
325       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
326 
327       DimensionHandle unused;
328       TF_RETURN_IF_ERROR(
329           c->Merge(c->Dim(input, -1), c->Dim(input, -2), &unused));
330 
331       ShapeHandle out;
332       TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &out));
333       c->set_output(0, out);
334       return Status::OK();
335     });
336 
337 REGISTER_OP("LogMatrixDeterminant")
338     .Input("input: T")
339     .Output("sign: T")
340     .Output("log_abs_determinant: T")
341     .Attr("T: {half, float, double, complex64, complex128}")
__anon016102550302(InferenceContext* c) 342     .SetShapeFn([](InferenceContext* c) {
343       ShapeHandle input;
344       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
345 
346       DimensionHandle unused;
347       TF_RETURN_IF_ERROR(
348           c->Merge(c->Dim(input, -1), c->Dim(input, -2), &unused));
349 
350       ShapeHandle s;
351       TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &s));
352       c->set_output(0, s);
353 
354       ShapeHandle out;
355       TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &out));
356       c->set_output(1, out);
357       return Status::OK();
358     });
359 
360 REGISTER_OP("MatrixInverse")
361     .Input("input: T")
362     .Output("output: T")
363     .Attr("adjoint: bool = False")
364     .Attr("T: {double, float, half, complex64, complex128}")
365     .SetShapeFn(BatchUnchangedSquareShapeFn);
366 
367 REGISTER_OP("MatrixExponential")
368     .Deprecated(
369         27, "Use Python implementation tf.linalg.matrix_exponential instead.")
370     .Input("input: T")
371     .Output("output: T")
372     .Attr("T: {double, float, half, complex64, complex128}")
373     .SetShapeFn(BatchUnchangedSquareShapeFn);
374 
375 REGISTER_OP("MatrixLogarithm")
376     .Input("input: T")
377     .Output("output: T")
378     .Attr("T: {complex64, complex128}")
379     .SetShapeFn(BatchUnchangedSquareShapeFn);
380 
381 REGISTER_OP("Cholesky")
382     .Input("input: T")
383     .Output("output: T")
384     .Attr("T: {double, float, half, complex64, complex128}")
385     .SetShapeFn(BatchUnchangedSquareShapeFn);
386 
387 REGISTER_OP("CholeskyGrad")
388     .Input("l: T")
389     .Input("grad: T")
390     .Output("output: T")
391     .Attr("T: {half, float, double}")
392     .SetShapeFn(BatchUnchangedSquareShapeFn);
393 
394 REGISTER_OP("SelfAdjointEig")
395     .Input("input: T")
396     .Output("output: T")
397     .Attr("T: {double, float, half}")
398     .Deprecated(11, "Use SelfAdjointEigV2 instead.")
__anon016102550402(InferenceContext* c) 399     .SetShapeFn([](InferenceContext* c) {
400       ShapeHandle input;
401       TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &input));
402 
403       DimensionHandle d = c->Dim(input, -1);
404       DimensionHandle d_plus_1;
405       TF_RETURN_IF_ERROR(c->Add(d, 1, &d_plus_1));
406 
407       ShapeHandle s;
408       TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &s));
409       TF_RETURN_IF_ERROR(c->Concatenate(s, c->Matrix(d_plus_1, d), &s));
410       c->set_output(0, s);
411       return Status::OK();
412     });
413 
414 REGISTER_OP("Eig")
415     .Input("input: T")
416     .Output("e: Tout")
417     .Output("v: Tout")
418     .Attr("compute_v: bool = True")
419     .Attr("T: {float, double, complex64, complex128}")
420     .Attr("Tout: {complex64, complex128}")
421     .SetShapeFn(SelfAdjointEigV2ShapeFn);
422 
423 REGISTER_OP("SelfAdjointEigV2")
424     .Input("input: T")
425     .Output("e: T")
426     .Output("v: T")
427     .Attr("compute_v: bool = True")
428     .Attr("T: {double, float, half, complex64, complex128}")
429     .SetShapeFn(SelfAdjointEigV2ShapeFn);
430 
431 REGISTER_OP("Lu")
432     .Input("input: T")
433     .Output("lu: T")
434     .Output("p: output_idx_type")
435     .Attr("T: {double, float, half, complex64, complex128}")
436     .Attr("output_idx_type: {int32, int64} = DT_INT32")
437     .SetShapeFn(LuShapeFn);
438 
439 REGISTER_OP("MatrixSolve")
440     .Input("matrix: T")
441     .Input("rhs: T")
442     .Output("output: T")
443     .Attr("adjoint: bool = False")
444     .Attr("T: {double, float, half, complex64, complex128}")
__anon016102550502(InferenceContext* c) 445     .SetShapeFn([](InferenceContext* c) {
446       return MatrixSolveShapeFn(c, true /* square (*/);
447     });
448 
449 REGISTER_OP("MatrixTriangularSolve")
450     .Input("matrix: T")
451     .Input("rhs: T")
452     .Output("output: T")
453     .Attr("lower: bool = True")
454     .Attr("adjoint: bool = False")
455     .Attr("T: {double, float, half, complex64, complex128}")
__anon016102550602(InferenceContext* c) 456     .SetShapeFn([](InferenceContext* c) {
457       return MatrixTriangularSolveShapeFn(c);
458     });
459 
460 REGISTER_OP("MatrixSolveLs")
461     .Input("matrix: T")
462     .Input("rhs: T")
463     .Input("l2_regularizer: double")
464     .Output("output: T")
465     .Attr("T: {double, float, half, complex64, complex128}")
466     .Attr("fast: bool = True")
__anon016102550702(InferenceContext* c) 467     .SetShapeFn([](InferenceContext* c) {
468       ShapeHandle l2_regularizer;
469       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &l2_regularizer));
470       return MatrixSolveShapeFn(c, false /* square */);
471     });
472 
473 REGISTER_OP("MatrixSquareRoot")
474     .Input("input: T")
475     .Output("output: T")
476     .Attr("T: {double, float, half, complex64, complex128}")
477     .SetShapeFn(BatchUnchangedSquareShapeFn);
478 
479 REGISTER_OP("Qr")
480     .Input("input: T")
481     .Output("q: T")
482     .Output("r: T")
483     .Attr("full_matrices: bool = False")
484     .Attr("T: {double, float, half, complex64, complex128}")
485     .SetShapeFn(QrShapeFn);
486 
487 REGISTER_OP("Svd")
488     .Input("input: T")
489     .Output("s: T")
490     .Output("u: T")
491     .Output("v: T")
492     .Attr("compute_uv: bool = True")
493     .Attr("full_matrices: bool = False")
494     .Attr("T: {double, float, half, complex64, complex128}")
495     .SetShapeFn(SvdShapeFn);
496 
497 REGISTER_OP("TridiagonalMatMul")
498     .Input("superdiag: T")
499     .Input("maindiag: T")
500     .Input("subdiag: T")
501     .Input("rhs: T")
502     .Output("output: T")
503     .Attr("T: {double, float, complex64, complex128}")
504     .SetShapeFn(TridiagonalMatMulShapeFn);
505 
506 REGISTER_OP("TridiagonalSolve")
507     .Input("diagonals: T")
508     .Input("rhs: T")
509     .Output("output: T")
510     .Attr("partial_pivoting: bool = True")
511     .Attr("T: {double, float, complex64, complex128}")
512     .SetShapeFn(TridiagonalSolveShapeFn);
513 
514 REGISTER_OP("Einsum")
515     .Input("inputs: N * T")
516     .Output("output: T")
517     .Attr("equation: string")
518     .Attr("N: int >= 1")
519     .Attr("T: type")
520     .SetShapeFn(shape_inference::EinsumShape);
521 
522 // Deprecated op registrations:
523 
524 // Can be deleted after 3feb2017.
525 REGISTER_OP("BatchSelfAdjointEig")
526     .Input("input: T")
527     .Output("output: T")
528     .Attr("T: {double, float}")
529     .Deprecated(11, "Use SelfAdjointEigV2 instead.")
530     .SetShapeFn(shape_inference::UnknownShape);
531 
532 // Can all be deleted after 9mar2017.
533 REGISTER_OP("BatchMatrixDeterminant")
534     .Input("input: T")
535     .Output("output: T")
536     .Attr("T: {float, double, complex64, complex128}")
537     .Deprecated(13, "Use MatrixDeterminant instead.")
538     .SetShapeFn(shape_inference::UnknownShape);
539 
540 REGISTER_OP("BatchMatrixInverse")
541     .Input("input: T")
542     .Output("output: T")
543     .Attr("adjoint: bool = False")
544     .Attr("T: {double, float}")
545     .Deprecated(13, "Use MatrixInverse instead.")
546     .SetShapeFn(shape_inference::UnknownShape);
547 
548 REGISTER_OP("BatchCholesky")
549     .Input("input: T")
550     .Output("output: T")
551     .Attr("T: {double, float}")
552     .Deprecated(13, "Use Cholesky instead.")
553     .SetShapeFn(shape_inference::UnknownShape);
554 
555 REGISTER_OP("BatchCholeskyGrad")
556     .Input("l: T")
557     .Input("grad: T")
558     .Output("output: T")
559     .Attr("T: {float, double}")
560     .Deprecated(13, "Use CholeskyGrad instead.")
561     .SetShapeFn(shape_inference::UnknownShape);
562 
563 REGISTER_OP("BatchSelfAdjointEigV2")
564     .Input("input: T")
565     .Output("e: T")
566     .Output("v: T")
567     .Attr("compute_v: bool = True")
568     .Attr("T: {double, float}")
569     .Deprecated(13, "Use SelfAdjointEigV2 instead.")
570     .SetShapeFn(shape_inference::UnknownShape);
571 
572 REGISTER_OP("BatchMatrixSolve")
573     .Input("matrix: T")
574     .Input("rhs: T")
575     .Output("output: T")
576     .Attr("adjoint: bool = False")
577     .Attr("T: {double, float}")
578     .Deprecated(13, "Use MatrixSolve instead.")
579     .SetShapeFn(shape_inference::UnknownShape);
580 
581 REGISTER_OP("BatchMatrixTriangularSolve")
582     .Input("matrix: T")
583     .Input("rhs: T")
584     .Output("output: T")
585     .Attr("lower: bool = True")
586     .Attr("adjoint: bool = False")
587     .Attr("T: {double, float}")
588     .Deprecated(13, "Use MatrixTriangularSolve instead.")
589     .SetShapeFn(shape_inference::UnknownShape);
590 
591 REGISTER_OP("BatchMatrixSolveLs")
592     .Input("matrix: T")
593     .Input("rhs: T")
594     .Input("l2_regularizer: double")
595     .Output("output: T")
596     .Attr("T: {double, float}")
597     .Attr("fast: bool = True")
598     .Deprecated(13, "Use MatrixSolveLs instead.")
599     .SetShapeFn(shape_inference::UnknownShape);
600 
601 REGISTER_OP("BatchSvd")
602     .Input("input: T")
603     .Output("s: T")
604     .Output("u: T")
605     .Output("v: T")
606     .Attr("compute_uv: bool = True")
607     .Attr("full_matrices: bool = False")
608     .Attr("T: {double, float, complex64, complex128}")
609     .Deprecated(13, "Use Svd instead.")
610     .SetShapeFn(shape_inference::UnknownShape);
611 
612 }  // namespace tensorflow
613