/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

namespace {

// Return in <out> the result of making the end of <input> a square matrix.
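// For example (illustrative shapes): an input of [2, 3, 5, 5] produces
// [2, 3, 5, 5]; an input of [2, 5, ?] produces [2, 5, 5], since Merge
// propagates the known dimension; and [2, 4, 5] fails because 4 != 5.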
Status MakeBatchSquareMatrix(InferenceContext* c, ShapeHandle input,
                             ShapeHandle* out) {
  ShapeHandle s;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 2, &s));

  DimensionHandle d;
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(s, -2), c->Dim(s, -1), &d));

  ShapeHandle batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(s, 0, -2, &batch_shape));
  TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(d, d), out));
  return Status::OK();
}

Status BatchUnchangedSquareShapeFn(InferenceContext* c) {
  ShapeHandle out;
  TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &out));
  c->set_output(0, out);
  return Status::OK();
}

// The first input is [...,M,N] and the second input is [...,M,K].
// Output is [...,N,K]. If <square>, then the first input is [...,M,M].
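// For example (illustrative shapes): with square=true, a matrix input of
// [2, 3, 3] and an rhs of [2, 3, 4] produce an output of [2, 3, 4], while
// an rhs of [5, 3, 4] fails when the batch subshapes [2] and [5] are merged.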
Status MatrixSolveShapeFn(InferenceContext* c, bool square) {
  ShapeHandle lhs;
  ShapeHandle rhs;
  if (square) {
    TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &lhs));
  } else {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs));
  }
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));

  ShapeHandle lhs_batch_shape;
  ShapeHandle rhs_batch_shape;
  // Make the common batch subshape.
  TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
  // Make sure the batch dimensions match between lhs and rhs.
  TF_RETURN_IF_ERROR(
      c->Merge(lhs_batch_shape, rhs_batch_shape, &lhs_batch_shape));

  DimensionHandle m;
  // lhs and rhs must have the same value of m to be compatible.
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, -2), c->Dim(rhs, -2), &m));
  DimensionHandle n = c->Dim(lhs, -1);
  if (square) {
    TF_RETURN_IF_ERROR(c->Merge(m, n, &n));
  }

  ShapeHandle out;
  // Build the final shape (batch_shape + n + k) in <out>.
  TF_RETURN_IF_ERROR(c->Concatenate(lhs_batch_shape, c->Vector(n), &out));
  TF_RETURN_IF_ERROR(c->Concatenate(out, c->Vector(c->Dim(rhs, -1)), &out));
  c->set_output(0, out);
  return Status::OK();
}

// The first input is [...,M,M] and the second input is [...,M,N]; their
// batch dimensions are broadcast against each other.
// Output is [...,M,N].
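// For example (illustrative shapes): a matrix input of [3, 1, 5, 5] and an
// rhs of [4, 5, 2] broadcast their batch shapes [3, 1] and [4] to [3, 4],
// giving an output of [3, 4, 5, 2].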
Status MatrixTriangularSolveShapeFn(InferenceContext* c) {
  ShapeHandle lhs;
  ShapeHandle rhs;
  TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &lhs));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));

  ShapeHandle lhs_batch_shape;
  ShapeHandle rhs_batch_shape;
  ShapeHandle output_batch_shape;
  // Make the broadcast batch subshape.
  TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
  TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
      c, lhs_batch_shape, rhs_batch_shape, true, &output_batch_shape));
  DimensionHandle m;
  // lhs and rhs must have the same value of m to be compatible.
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, -1), c->Dim(rhs, -2), &m));

  ShapeHandle out;
  // Build the final shape (batch_shape + m + n) in <out>.
  TF_RETURN_IF_ERROR(
      c->Concatenate(output_batch_shape, c->Matrix(m, c->Dim(rhs, -1)), &out));
  c->set_output(0, out);
  return Status::OK();
}

// Input is [...,N,N]. Outputs are:
//   [...,N]; [0], if compute_v is false,
//   [...,N]; [...,N,N], if compute_v is true.
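// For example (illustrative shapes): an input of [2, 4, 4] produces e with
// shape [2, 4] and, when compute_v is true, v with shape [2, 4, 4]; when
// compute_v is false, output 1 is the placeholder vector shape [0].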
Status SelfAdjointEigV2ShapeFn(InferenceContext* c) {
  ShapeHandle input;
  TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &input));
  DimensionHandle n;
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -2), c->Dim(input, -1), &n));
  ShapeHandle batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
  ShapeHandle e_shape;
  TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &e_shape));
  c->set_output(0, e_shape);
  bool compute_v;
  TF_RETURN_IF_ERROR(c->GetAttr("compute_v", &compute_v));
  if (compute_v) {
    ShapeHandle v_shape;
    TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(n, n), &v_shape));
    c->set_output(1, v_shape);
  } else {
    c->set_output(1, c->Vector(0ll));
  }
  return Status::OK();
}

// Input is [...,N,N].
// First and second outputs are:
//   [...,N,N]; [...,N].
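// For example (illustrative shapes): an input of [2, 3, 3] produces lu with
// shape [2, 3, 3] and the permutation p with shape [2, 3].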
Status LuShapeFn(InferenceContext* c) {
  ShapeHandle input;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));

  DimensionHandle n;
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -2), c->Dim(input, -1), &n));

  ShapeHandle batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));

  ShapeHandle lu_shape;
  ShapeHandle p_shape;

  TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(n, n), &lu_shape));
  TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &p_shape));

  c->set_output(0, lu_shape);
  c->set_output(1, p_shape);
  return Status::OK();
}

// Input is [...,M,N].
// First and second outputs are:
//   [...,M,M]; [...,M,N], if full_matrices is true,
//   [...,M,P]; [...,P,N], if full_matrices is false,
// where P = min(M,N).
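// For example (illustrative shapes): an input of [2, 6, 4] (P = 4) produces
// q = [2, 6, 6] and r = [2, 6, 4] when full_matrices is true, or
// q = [2, 6, 4] and r = [2, 4, 4] when it is false.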
Status QrShapeFn(InferenceContext* c) {
  ShapeHandle input;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
  DimensionHandle m = c->Dim(input, -2);
  DimensionHandle n = c->Dim(input, -1);
  DimensionHandle p;
  TF_RETURN_IF_ERROR(c->Min(m, n, &p));
  ShapeHandle batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
  ShapeHandle q_shape;
  ShapeHandle r_shape;
  bool full_matrices;
  TF_RETURN_IF_ERROR(c->GetAttr("full_matrices", &full_matrices));
  if (full_matrices) {
    TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, m), &q_shape));
    TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, n), &r_shape));
  } else {
    TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, p), &q_shape));
    TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(p, n), &r_shape));
  }
  c->set_output(0, q_shape);
  c->set_output(1, r_shape);
  return Status::OK();
}

// Input is [...,M,N]. First output is [...,min(M,N)].
// Second and third outputs are:
//   [0]; [0], if compute_uv is false,
//   [...,M,M]; [...,N,N], if compute_uv is true and full_matrices is true,
//   [...,M,P]; [...,N,P], if compute_uv is true and full_matrices is false,
// where P = min(M,N).
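// For example (illustrative shapes): an input of [2, 6, 4] (P = 4) produces
// s = [2, 4]; with compute_uv=true and full_matrices=false, u = [2, 6, 4]
// and v = [2, 4, 4].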
Status SvdShapeFn(InferenceContext* c) {
  ShapeHandle input;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
  DimensionHandle m = c->Dim(input, -2);
  DimensionHandle n = c->Dim(input, -1);
  DimensionHandle p;
  TF_RETURN_IF_ERROR(c->Min(m, n, &p));
  ShapeHandle batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
  ShapeHandle e_shape;
  TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(p), &e_shape));
  c->set_output(0, e_shape);
  bool compute_uv;
  TF_RETURN_IF_ERROR(c->GetAttr("compute_uv", &compute_uv));
  if (compute_uv) {
    ShapeHandle u_shape;
    ShapeHandle v_shape;
    bool full_matrices;
    TF_RETURN_IF_ERROR(c->GetAttr("full_matrices", &full_matrices));
    if (full_matrices) {
      TF_RETURN_IF_ERROR(
          c->Concatenate(batch_shape, c->Matrix(m, m), &u_shape));
      TF_RETURN_IF_ERROR(
          c->Concatenate(batch_shape, c->Matrix(n, n), &v_shape));
    } else {
      TF_RETURN_IF_ERROR(
          c->Concatenate(batch_shape, c->Matrix(m, p), &u_shape));
      TF_RETURN_IF_ERROR(
          c->Concatenate(batch_shape, c->Matrix(n, p), &v_shape));
    }
    c->set_output(1, u_shape);
    c->set_output(2, v_shape);
  } else {
    c->set_output(1, c->Vector(0ll));
    c->set_output(2, c->Vector(0ll));
  }
  return Status::OK();
}

// Inputs: [...,1,M], [...,1,M], [...,1,M], [...,M,N].
// Output is [...,M,N].
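// For example (illustrative shapes): superdiag, maindiag, and subdiag of
// shape [2, 1, 5] with an rhs of shape [2, 5, 3] produce an output of shape
// [2, 5, 3].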
Status TridiagonalMatMulShapeFn(InferenceContext* c) {
  ShapeHandle superdiag;
  ShapeHandle maindiag;
  ShapeHandle subdiag;
  ShapeHandle rhs;

  // Check that rank is at least 2.
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &superdiag));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &maindiag));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 2, &subdiag));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 2, &rhs));

  // Extract batch dimensions and check they are the same.
  ShapeHandle superdiag_batch_shape;
  ShapeHandle maindiag_batch_shape;
  ShapeHandle subdiag_batch_shape;
  ShapeHandle rhs_batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(superdiag, 0, -2, &superdiag_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(maindiag, 0, -2, &maindiag_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(subdiag, 0, -2, &subdiag_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
  TF_RETURN_IF_ERROR(
      c->Merge(superdiag_batch_shape, rhs_batch_shape, &rhs_batch_shape));
  TF_RETURN_IF_ERROR(
      c->Merge(maindiag_batch_shape, rhs_batch_shape, &rhs_batch_shape));
  TF_RETURN_IF_ERROR(
      c->Merge(subdiag_batch_shape, rhs_batch_shape, &rhs_batch_shape));

  // Check that the diagonals have the same shape.
  TF_RETURN_IF_ERROR(c->Merge(superdiag, maindiag, &maindiag));
  TF_RETURN_IF_ERROR(c->Merge(subdiag, maindiag, &maindiag));

  // Check that the size of the tridiagonal matrix is the same as the height
  // of the matrix on the right.
  DimensionHandle m_lhs = c->Dim(maindiag, -1);
  DimensionHandle m_rhs = c->Dim(rhs, -2);
  TF_RETURN_IF_ERROR(c->Merge(m_lhs, m_rhs, &m_lhs));

  // Check that the next-to-last dimension of the diagonals is 1.
  DimensionHandle unused;
  TF_RETURN_IF_ERROR(c->WithValue(c->Dim(maindiag, -2), 1, &unused));

  // The output shape is the same as the rhs shape.
  c->set_output(0, rhs);
  return Status::OK();
}

// The first input is [...,3,M] and the second input is [...,M,K].
// Output is [...,M,K].
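// For example (illustrative shapes): diagonals of shape [2, 3, 5] and an rhs
// of shape [2, 5, 4] produce an output of shape [2, 5, 4]; a diagonals input
// whose next-to-last dimension is not 3 is rejected.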
Status TridiagonalSolveShapeFn(InferenceContext* c) {
  ShapeHandle lhs;
  ShapeHandle rhs;
  // Check that rank is at least 2.
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));

  // Extract batch dimensions and check they are the same.
  ShapeHandle lhs_batch_shape;
  ShapeHandle rhs_batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
  TF_RETURN_IF_ERROR(
      c->Merge(lhs_batch_shape, rhs_batch_shape, &lhs_batch_shape));

  // Check that "M" is the same in both inputs.
  DimensionHandle m_lhs = c->Dim(lhs, -1);
  DimensionHandle m_rhs = c->Dim(rhs, -2);
  TF_RETURN_IF_ERROR(c->Merge(m_lhs, m_rhs, &m_lhs));

  // Check that the next-to-last dimension of the first input is 3.
  TF_RETURN_IF_ERROR(c->WithValue(c->Dim(lhs, -2), 3, &m_lhs));

  // The output shape is the same as the rhs shape.
  c->set_output(0, rhs);
  return Status::OK();
}

}  // namespace

REGISTER_OP("MatrixDeterminant")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {half, float, double, complex64, complex128}")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));

      DimensionHandle unused;
      TF_RETURN_IF_ERROR(
          c->Merge(c->Dim(input, -1), c->Dim(input, -2), &unused));

      ShapeHandle out;
      TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &out));
      c->set_output(0, out);
      return Status::OK();
    });

REGISTER_OP("LogMatrixDeterminant")
    .Input("input: T")
    .Output("sign: T")
    .Output("log_abs_determinant: T")
    .Attr("T: {half, float, double, complex64, complex128}")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));

      DimensionHandle unused;
      TF_RETURN_IF_ERROR(
          c->Merge(c->Dim(input, -1), c->Dim(input, -2), &unused));

      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &s));
      c->set_output(0, s);

      ShapeHandle out;
      TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &out));
      c->set_output(1, out);
      return Status::OK();
    });

REGISTER_OP("MatrixInverse")
    .Input("input: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("MatrixExponential")
    .Deprecated(
        27, "Use Python implementation tf.linalg.matrix_exponential instead.")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("MatrixLogarithm")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("Cholesky")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("CholeskyGrad")
    .Input("l: T")
    .Input("grad: T")
    .Output("output: T")
    .Attr("T: {half, float, double}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

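// SelfAdjointEig (deprecated in favor of SelfAdjointEigV2) packs both
// results into a single output: for an input of [...,N,N] the output is
// [...,N+1,N], with the eigenvalues expected in the first row of the
// trailing matrix and the eigenvectors in the remaining N rows.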
REGISTER_OP("SelfAdjointEig")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half}")
    .Deprecated(11, "Use SelfAdjointEigV2 instead.")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &input));

      DimensionHandle d = c->Dim(input, -1);
      DimensionHandle d_plus_1;
      TF_RETURN_IF_ERROR(c->Add(d, 1, &d_plus_1));

      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &s));
      TF_RETURN_IF_ERROR(c->Concatenate(s, c->Matrix(d_plus_1, d), &s));
      c->set_output(0, s);
      return Status::OK();
    });

REGISTER_OP("Eig")
    .Input("input: T")
    .Output("e: Tout")
    .Output("v: Tout")
    .Attr("compute_v: bool = True")
    .Attr("T: {float, double, complex64, complex128}")
    .Attr("Tout: {complex64, complex128}")
    .SetShapeFn(SelfAdjointEigV2ShapeFn);

REGISTER_OP("SelfAdjointEigV2")
    .Input("input: T")
    .Output("e: T")
    .Output("v: T")
    .Attr("compute_v: bool = True")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(SelfAdjointEigV2ShapeFn);

REGISTER_OP("Lu")
    .Input("input: T")
    .Output("lu: T")
    .Output("p: output_idx_type")
    .Attr("T: {double, float, half, complex64, complex128}")
    .Attr("output_idx_type: {int32, int64} = DT_INT32")
    .SetShapeFn(LuShapeFn);

REGISTER_OP("MatrixSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn([](InferenceContext* c) {
      return MatrixSolveShapeFn(c, true /* square */);
    });

REGISTER_OP("MatrixTriangularSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("lower: bool = True")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn([](InferenceContext* c) {
      return MatrixTriangularSolveShapeFn(c);
    });

REGISTER_OP("MatrixSolveLs")
    .Input("matrix: T")
    .Input("rhs: T")
    .Input("l2_regularizer: double")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .Attr("fast: bool = True")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle l2_regularizer;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &l2_regularizer));
      return MatrixSolveShapeFn(c, false /* square */);
    });

REGISTER_OP("MatrixSquareRoot")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("Qr")
    .Input("input: T")
    .Output("q: T")
    .Output("r: T")
    .Attr("full_matrices: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(QrShapeFn);

REGISTER_OP("Svd")
    .Input("input: T")
    .Output("s: T")
    .Output("u: T")
    .Output("v: T")
    .Attr("compute_uv: bool = True")
    .Attr("full_matrices: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(SvdShapeFn);

REGISTER_OP("TridiagonalMatMul")
    .Input("superdiag: T")
    .Input("maindiag: T")
    .Input("subdiag: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("T: {double, float, complex64, complex128}")
    .SetShapeFn(TridiagonalMatMulShapeFn);

REGISTER_OP("TridiagonalSolve")
    .Input("diagonals: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("partial_pivoting: bool = True")
    .Attr("T: {double, float, complex64, complex128}")
    .SetShapeFn(TridiagonalSolveShapeFn);

REGISTER_OP("Einsum")
    .Input("inputs: N * T")
    .Output("output: T")
    .Attr("equation: string")
    .Attr("N: int >= 1")
    .Attr("T: type")
    .SetShapeFn(shape_inference::EinsumShape);

// Deprecated op registrations:

// Can be deleted after 3feb2017.
REGISTER_OP("BatchSelfAdjointEig")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float}")
    .Deprecated(11, "Use SelfAdjointEigV2 instead.")
    .SetShapeFn(shape_inference::UnknownShape);

// Can all be deleted after 9mar2017.
REGISTER_OP("BatchMatrixDeterminant")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {float, double, complex64, complex128}")
    .Deprecated(13, "Use MatrixDeterminant instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixInverse")
    .Input("input: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use MatrixInverse instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchCholesky")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use Cholesky instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchCholeskyGrad")
    .Input("l: T")
    .Input("grad: T")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Deprecated(13, "Use CholeskyGrad instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchSelfAdjointEigV2")
    .Input("input: T")
    .Output("e: T")
    .Output("v: T")
    .Attr("compute_v: bool = True")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use SelfAdjointEigV2 instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use MatrixSolve instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixTriangularSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("lower: bool = True")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use MatrixTriangularSolve instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixSolveLs")
    .Input("matrix: T")
    .Input("rhs: T")
    .Input("l2_regularizer: double")
    .Output("output: T")
    .Attr("T: {double, float}")
    .Attr("fast: bool = True")
    .Deprecated(13, "Use MatrixSolveLs instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchSvd")
    .Input("input: T")
    .Output("s: T")
    .Output("u: T")
    .Output("v: T")
    .Attr("compute_uv: bool = True")
    .Attr("full_matrices: bool = False")
    .Attr("T: {double, float, complex64, complex128}")
    .Deprecated(13, "Use Svd instead.")
    .SetShapeFn(shape_inference::UnknownShape);

}  // namespace tensorflow