/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

namespace {

// Return in <out> the result of making the end of <input> a square matrix.
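// For example, an <input> of [3, 2, 5, 5] yields [3, 2, 5, 5], while
// [3, 2, 5, 4] fails the Merge below because its two trailing dimensions
// differ.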
Status MakeBatchSquareMatrix(InferenceContext* c, ShapeHandle input,
                             ShapeHandle* out) {
  ShapeHandle s;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 2, &s));

  DimensionHandle d;
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(s, -2), c->Dim(s, -1), &d));

  ShapeHandle batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(s, 0, -2, &batch_shape));
  TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(d, d), out));
  return Status::OK();
}

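// Shape function for ops that take a batch of square matrices and produce an
// output of the same shape (e.g. Cholesky and MatrixInverse below).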
Status BatchUnchangedSquareShapeFn(InferenceContext* c) {
  ShapeHandle out;
  TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &out));
  c->set_output(0, out);
  return Status::OK();
}

// The first input is [...,K,M] and second input is [...,M,N].
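// Output is [...,M,N], with the batch dimensions broadcast together. For
// example, a matrix of [1, 3, 10] (K = 3 bands, M = 10) and a rhs of
// [5, 10, 4] produce an output of [5, 10, 4].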
Status BandedTriangularSolveShapeFn(InferenceContext* c) {
  ShapeHandle lhs;
  ShapeHandle rhs;

  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));

  // Check K > 0.
  DimensionHandle num_bands = c->Dim(lhs, -2);
  DimensionHandle m = c->Dim(lhs, -1);
  if (c->ValueKnown(num_bands) && c->Value(num_bands) <= 0) {
    return errors::InvalidArgument("Number of bands must be positive, but is ",
                                   c->Value(num_bands));
  }
  if (c->ValueKnown(num_bands) && c->ValueKnown(m) &&
      c->Value(num_bands) > c->Value(m)) {
    return errors::InvalidArgument("Number of bands ", c->Value(num_bands),
                                   " cannot exceed the size of the matrix ",
                                   c->Value(m));
  }

  ShapeHandle lhs_batch_shape;
  ShapeHandle rhs_batch_shape;
  ShapeHandle output_batch_shape;
  // Make the common batch subshape.
  TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
  TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
      c, lhs_batch_shape, rhs_batch_shape, true, &output_batch_shape));

  // lhs and rhs have the same value for M to be compatible.
  TF_RETURN_IF_ERROR(c->Merge(m, c->Dim(rhs, -2), &m));

  // Build final shape (batch_shape + m + n) in <out>.
  ShapeHandle out;
  TF_RETURN_IF_ERROR(
      c->Concatenate(output_batch_shape, c->Matrix(m, c->Dim(rhs, -1)), &out));

  c->set_output(0, out);
  return Status::OK();
}

// The first input is [...,M,N] and the second input is [...,M,K].
// Output is [...,N,K]. If <square>, then the first input is [...,M,M].
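// For example, with square = false (MatrixSolveLs), a matrix of [2, 7, 5]
// and a rhs of [2, 7, 3] produce an output of [2, 5, 3]. Unlike the banded
// triangular solve above, the batch dimensions are merged, not broadcast.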
Status MatrixSolveShapeFn(InferenceContext* c, bool square) {
  ShapeHandle lhs;
  ShapeHandle rhs;
  if (square) {
    TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &lhs));
  } else {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs));
  }
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));

  ShapeHandle lhs_batch_shape;
  ShapeHandle rhs_batch_shape;
  // Make the common batch subshape.
  TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
  // Make sure the batch dimensions match between lhs and rhs.
  TF_RETURN_IF_ERROR(
      c->Merge(lhs_batch_shape, rhs_batch_shape, &lhs_batch_shape));

  DimensionHandle m;
  // lhs and rhs have the same value for m to be compatible.
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, -2), c->Dim(rhs, -2), &m));
  DimensionHandle n = c->Dim(lhs, -1);
  if (square) {
    TF_RETURN_IF_ERROR(c->Merge(m, n, &n));
  }

  ShapeHandle out;
  // Build final shape (batch_shape + n + k) in <out>.
  TF_RETURN_IF_ERROR(c->Concatenate(lhs_batch_shape, c->Vector(n), &out));
  TF_RETURN_IF_ERROR(c->Concatenate(out, c->Vector(c->Dim(rhs, -1)), &out));
  c->set_output(0, out);
  return Status::OK();
}

// The first input is [...,M,M] and second input is [...,M,N].
// Output is [...,M,N].
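// The batch dimensions are broadcast; for example, a matrix of [1, 4, 4] and
// a rhs of [3, 4, 2] produce an output of [3, 4, 2].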
Status MatrixTriangularSolveShapeFn(InferenceContext* c) {
  ShapeHandle lhs;
  ShapeHandle rhs;
  TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &lhs));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));

  ShapeHandle lhs_batch_shape;
  ShapeHandle rhs_batch_shape;
  ShapeHandle output_batch_shape;
  // Make the common batch subshape.
  TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
  TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
      c, lhs_batch_shape, rhs_batch_shape, true, &output_batch_shape));
  DimensionHandle m;
  // lhs and rhs have the same value for m to be compatible.
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(lhs, -1), c->Dim(rhs, -2), &m));

  ShapeHandle out;
  // Build final shape (batch_shape + m + n) in <out>.
  TF_RETURN_IF_ERROR(
      c->Concatenate(output_batch_shape, c->Matrix(m, c->Dim(rhs, -1)), &out));
  c->set_output(0, out);
  return Status::OK();
}

// Input is [...,N,N]. Outputs are:
//   [...,N];[0], if compute_v is false,
//   [...,N];[...,N,N], if compute_v is true.
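// For example, an input of [2, 5, 5] with compute_v = true yields
// e = [2, 5] and v = [2, 5, 5]; with compute_v = false, v is the empty
// placeholder [0].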
Status SelfAdjointEigV2ShapeFn(InferenceContext* c) {
  ShapeHandle input;
  TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &input));
  DimensionHandle n;
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -2), c->Dim(input, -1), &n));
  ShapeHandle batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
  ShapeHandle e_shape;
  TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &e_shape));
  c->set_output(0, e_shape);
  bool compute_v;
  TF_RETURN_IF_ERROR(c->GetAttr("compute_v", &compute_v));
  if (compute_v) {
    ShapeHandle v_shape;
    TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(n, n), &v_shape));
    c->set_output(1, v_shape);
  } else {
    c->set_output(1, c->Vector(0ll));
  }
  return Status::OK();
}

// Input is [...,N,N].
// First and second outputs are:
//   [...,N,N]; [...,N].
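// For example, an input of [4, 3, 3] yields lu = [4, 3, 3] and the
// permutation indices p = [4, 3].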
Status LuShapeFn(InferenceContext* c) {
  ShapeHandle input;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));

  DimensionHandle n;
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(input, -2), c->Dim(input, -1), &n));

  ShapeHandle batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));

  ShapeHandle lu_shape;
  ShapeHandle p_shape;

  TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(n, n), &lu_shape));
  TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(n), &p_shape));

  c->set_output(0, lu_shape);
  c->set_output(1, p_shape);
  return Status::OK();
}

// Input is [...,M,N].
// First and second outputs are:
//   [...,M,M]; [...,M,N], if full_matrices is true,
//   [...,M,P]; [...,P,N], if full_matrices is false,
// where P = min(M,N).
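// For example, an input of [2, 6, 4] yields q = [2, 6, 6] and r = [2, 6, 4]
// with full_matrices = true, or q = [2, 6, 4] and r = [2, 4, 4] otherwise.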
Status QrShapeFn(InferenceContext* c) {
  ShapeHandle input;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
  DimensionHandle m = c->Dim(input, -2);
  DimensionHandle n = c->Dim(input, -1);
  DimensionHandle p;
  TF_RETURN_IF_ERROR(c->Min(m, n, &p));
  ShapeHandle batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
  ShapeHandle q_shape;
  ShapeHandle r_shape;
  bool full_matrices;
  TF_RETURN_IF_ERROR(c->GetAttr("full_matrices", &full_matrices));
  if (full_matrices) {
    TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, m), &q_shape));
    TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, n), &r_shape));
  } else {
    TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(m, p), &q_shape));
    TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Matrix(p, n), &r_shape));
  }
  c->set_output(0, q_shape);
  c->set_output(1, r_shape);
  return Status::OK();
}

// Input is [...,M,N]. First output is [...,min(M,N)].
// Second and third outputs are:
//   [0]; [0], if compute_uv is false.
//   [...,M,M]; [...,N,N], if compute_uv is true and full_matrices is true,
//   [...,M,P]; [...,N,P], if compute_uv is true and full_matrices is false,
// where P = min(M,N).
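// For example, an input of [2, 6, 4] (P = 4) yields s = [2, 4]; with
// compute_uv = true, u = [2, 6, 6] and v = [2, 4, 4] if full_matrices is
// true, else u = [2, 6, 4] and v = [2, 4, 4].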
Status SvdShapeFn(InferenceContext* c) {
  ShapeHandle input;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
  DimensionHandle m = c->Dim(input, -2);
  DimensionHandle n = c->Dim(input, -1);
  DimensionHandle p;
  TF_RETURN_IF_ERROR(c->Min(m, n, &p));
  ShapeHandle batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &batch_shape));
  ShapeHandle e_shape;
  TF_RETURN_IF_ERROR(c->Concatenate(batch_shape, c->Vector(p), &e_shape));
  c->set_output(0, e_shape);
  bool compute_uv;
  TF_RETURN_IF_ERROR(c->GetAttr("compute_uv", &compute_uv));
  if (compute_uv) {
    ShapeHandle u_shape;
    ShapeHandle v_shape;
    bool full_matrices;
    TF_RETURN_IF_ERROR(c->GetAttr("full_matrices", &full_matrices));
    if (full_matrices) {
      TF_RETURN_IF_ERROR(
          c->Concatenate(batch_shape, c->Matrix(m, m), &u_shape));
      TF_RETURN_IF_ERROR(
          c->Concatenate(batch_shape, c->Matrix(n, n), &v_shape));
    } else {
      TF_RETURN_IF_ERROR(
          c->Concatenate(batch_shape, c->Matrix(m, p), &u_shape));
      TF_RETURN_IF_ERROR(
          c->Concatenate(batch_shape, c->Matrix(n, p), &v_shape));
    }
    c->set_output(1, u_shape);
    c->set_output(2, v_shape);
  } else {
    c->set_output(1, c->Vector(0ll));
    c->set_output(2, c->Vector(0ll));
  }
  return Status::OK();
}

// Inputs: [...,1,M], [...,1,M], [...,1,M], [...,M,N].
// Output is [...,M,N].
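// For example, three diagonals of [2, 1, 5] and a rhs of [2, 5, 3] produce
// an output of [2, 5, 3].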
Status TridiagonalMatMulShapeFn(InferenceContext* c) {
  ShapeHandle superdiag;
  ShapeHandle maindiag;
  ShapeHandle subdiag;
  ShapeHandle rhs;

  // Check that rank is at least 2.
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &superdiag));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &maindiag));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 2, &subdiag));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 2, &rhs));

  // Extract batch dimensions and check they are the same.
  ShapeHandle superdiag_batch_shape;
  ShapeHandle maindiag_batch_shape;
  ShapeHandle subdiag_batch_shape;
  ShapeHandle rhs_batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(superdiag, 0, -2, &superdiag_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(maindiag, 0, -2, &maindiag_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(subdiag, 0, -2, &subdiag_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
  TF_RETURN_IF_ERROR(
      c->Merge(superdiag_batch_shape, rhs_batch_shape, &rhs_batch_shape));
  TF_RETURN_IF_ERROR(
      c->Merge(maindiag_batch_shape, rhs_batch_shape, &rhs_batch_shape));
  TF_RETURN_IF_ERROR(
      c->Merge(subdiag_batch_shape, rhs_batch_shape, &rhs_batch_shape));

  // Check that diagonals have the same shape.
  TF_RETURN_IF_ERROR(c->Merge(superdiag, maindiag, &maindiag));
  TF_RETURN_IF_ERROR(c->Merge(subdiag, maindiag, &maindiag));

  // Check that size of tri-diagonal matrix is the same as height of matrix on
  // the right.
  DimensionHandle m_lhs = c->Dim(maindiag, -1);
  DimensionHandle m_rhs = c->Dim(rhs, -2);
  TF_RETURN_IF_ERROR(c->Merge(m_lhs, m_rhs, &m_lhs));

  // Check that next-to-last dimension of diagonals is 1.
  DimensionHandle unused;
  TF_RETURN_IF_ERROR(c->WithValue(c->Dim(maindiag, -2), 1, &unused));

  // The output shape is the same as rhs shape.
  c->set_output(0, rhs);
  return Status::OK();
}

// The first input is [...,3,M] and second input is [...,M,K].
// Output is [...,M,K].
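// For example, a diagonals input of [2, 3, 8] (the three diagonals of an
// 8 x 8 tridiagonal matrix) and a rhs of [2, 8, 4] produce an output of
// [2, 8, 4].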
Status TridiagonalSolveShapeFn(InferenceContext* c) {
  ShapeHandle lhs;
  ShapeHandle rhs;
  // Check that rank is at least 2.
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &lhs));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &rhs));

  // Extract batch dimensions and check they are the same.
  ShapeHandle lhs_batch_shape;
  ShapeHandle rhs_batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(lhs, 0, -2, &lhs_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(rhs, 0, -2, &rhs_batch_shape));
  TF_RETURN_IF_ERROR(
      c->Merge(lhs_batch_shape, rhs_batch_shape, &lhs_batch_shape));

  // Check that "M" is the same in both inputs.
  DimensionHandle m_lhs = c->Dim(lhs, -1);
  DimensionHandle m_rhs = c->Dim(rhs, -2);
  TF_RETURN_IF_ERROR(c->Merge(m_lhs, m_rhs, &m_lhs));

  // Check that next-to-last dimension of the first input is 3.
  TF_RETURN_IF_ERROR(c->WithValue(c->Dim(lhs, -2), 3, &m_lhs));

  // The output shape is the same as rhs shape.
  c->set_output(0, rhs);
  return Status::OK();
}

}  // namespace

REGISTER_OP("MatrixDeterminant")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {half, float, double, complex64, complex128}")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));

      DimensionHandle unused;
      TF_RETURN_IF_ERROR(
          c->Merge(c->Dim(input, -1), c->Dim(input, -2), &unused));

      ShapeHandle out;
      TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &out));
      c->set_output(0, out);
      return Status::OK();
    });

REGISTER_OP("LogMatrixDeterminant")
    .Input("input: T")
    .Output("sign: T")
    .Output("log_abs_determinant: T")
    .Attr("T: {half, float, double, complex64, complex128}")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));

      DimensionHandle unused;
      TF_RETURN_IF_ERROR(
          c->Merge(c->Dim(input, -1), c->Dim(input, -2), &unused));

      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &s));
      c->set_output(0, s);

      ShapeHandle out;
      TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &out));
      c->set_output(1, out);
      return Status::OK();
    });

REGISTER_OP("MatrixInverse")
    .Input("input: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("MatrixExponential")
    .Deprecated(
        27, "Use Python implementation tf.linalg.matrix_exponential instead.")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("MatrixLogarithm")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("Cholesky")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("CholeskyGrad")
    .Input("l: T")
    .Input("grad: T")
    .Output("output: T")
    .Attr("T: {half, float, double}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("SelfAdjointEig")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half}")
    .Deprecated(11, "Use SelfAdjointEigV2 instead.")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &input));

      DimensionHandle d = c->Dim(input, -1);
      DimensionHandle d_plus_1;
      TF_RETURN_IF_ERROR(c->Add(d, 1, &d_plus_1));

      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &s));
      TF_RETURN_IF_ERROR(c->Concatenate(s, c->Matrix(d_plus_1, d), &s));
      c->set_output(0, s);
      return Status::OK();
    });

REGISTER_OP("Eig")
    .Input("input: T")
    .Output("e: Tout")
    .Output("v: Tout")
    .Attr("compute_v: bool = True")
    .Attr("T: {float, double, complex64, complex128}")
    .Attr("Tout: {complex64, complex128}")
    .SetShapeFn(SelfAdjointEigV2ShapeFn);

REGISTER_OP("SelfAdjointEigV2")
    .Input("input: T")
    .Output("e: T")
    .Output("v: T")
    .Attr("compute_v: bool = True")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(SelfAdjointEigV2ShapeFn);

REGISTER_OP("Lu")
    .Input("input: T")
    .Output("lu: T")
    .Output("p: output_idx_type")
    .Attr("T: {double, float, half, complex64, complex128}")
    .Attr("output_idx_type: {int32, int64} = DT_INT32")
    .SetShapeFn(LuShapeFn);

REGISTER_OP("MatrixSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn([](InferenceContext* c) {
      return MatrixSolveShapeFn(c, true /* square */);
    });

REGISTER_OP("BandedTriangularSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("lower: bool = True")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn([](InferenceContext* c) {
      return BandedTriangularSolveShapeFn(c);
    });

REGISTER_OP("MatrixTriangularSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("lower: bool = True")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn([](InferenceContext* c) {
      return MatrixTriangularSolveShapeFn(c);
    });

REGISTER_OP("MatrixSolveLs")
    .Input("matrix: T")
    .Input("rhs: T")
    .Input("l2_regularizer: double")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .Attr("fast: bool = True")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle l2_regularizer;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &l2_regularizer));
      return MatrixSolveShapeFn(c, false /* square */);
    });

REGISTER_OP("MatrixSquareRoot")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(BatchUnchangedSquareShapeFn);

REGISTER_OP("Qr")
    .Input("input: T")
    .Output("q: T")
    .Output("r: T")
    .Attr("full_matrices: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(QrShapeFn);

REGISTER_OP("Svd")
    .Input("input: T")
    .Output("s: T")
    .Output("u: T")
    .Output("v: T")
    .Attr("compute_uv: bool = True")
    .Attr("full_matrices: bool = False")
    .Attr("T: {double, float, half, complex64, complex128}")
    .SetShapeFn(SvdShapeFn);

REGISTER_OP("TridiagonalMatMul")
    .Input("superdiag: T")
    .Input("maindiag: T")
    .Input("subdiag: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("T: {double, float, complex64, complex128}")
    .SetShapeFn(TridiagonalMatMulShapeFn);

REGISTER_OP("TridiagonalSolve")
    .Input("diagonals: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("partial_pivoting: bool = True")
    .Attr("T: {double, float, complex64, complex128}")
    .SetShapeFn(TridiagonalSolveShapeFn);

REGISTER_OP("Einsum")
    .Input("inputs: N * T")
    .Output("output: T")
    .Attr("equation: string")
    .Attr("N: int >= 1")
    .Attr("T: type")
    .SetShapeFn(shape_inference::EinsumShape);

// Deprecated op registrations:

// Can be deleted after 3feb2017.
REGISTER_OP("BatchSelfAdjointEig")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float}")
    .Deprecated(11, "Use SelfAdjointEigV2 instead.")
    .SetShapeFn(shape_inference::UnknownShape);

// Can all be deleted after 9mar2017.
REGISTER_OP("BatchMatrixDeterminant")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {float, double, complex64, complex128}")
    .Deprecated(13, "Use MatrixDeterminant instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixInverse")
    .Input("input: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use MatrixInverse instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchCholesky")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use Cholesky instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchCholeskyGrad")
    .Input("l: T")
    .Input("grad: T")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Deprecated(13, "Use CholeskyGrad instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchSelfAdjointEigV2")
    .Input("input: T")
    .Output("e: T")
    .Output("v: T")
    .Attr("compute_v: bool = True")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use SelfAdjointEigV2 instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use MatrixSolve instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixTriangularSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("lower: bool = True")
    .Attr("adjoint: bool = False")
    .Attr("T: {double, float}")
    .Deprecated(13, "Use MatrixTriangularSolve instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchMatrixSolveLs")
    .Input("matrix: T")
    .Input("rhs: T")
    .Input("l2_regularizer: double")
    .Output("output: T")
    .Attr("T: {double, float}")
    .Attr("fast: bool = True")
    .Deprecated(13, "Use MatrixSolveLs instead.")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("BatchSvd")
    .Input("input: T")
    .Output("s: T")
    .Output("u: T")
    .Output("v: T")
    .Attr("compute_uv: bool = True")
    .Attr("full_matrices: bool = False")
    .Attr("T: {double, float, complex64, complex128}")
    .Deprecated(13, "Use Svd instead.")
    .SetShapeFn(shape_inference::UnknownShape);

}  // namespace tensorflow