/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/svd.h"

#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

namespace {

// Given a matrix A, define the Householder reflector H as
//   H = I - beta * v_T * v if v is a row vector, or
//   H = I - beta * v * v_T if v is a column vector.
// A * H or H * A zeros out the trailing part of some row or column of A:
//
//   [x0, ..., x_{k-1}, x_k, x_{k+1}, ..., x_{n-1}] * H
//       = [x0, ..., x_{k-1}, xnorm, 0, ..., 0]
//
// where xnorm = norm([x_k, x_{k+1}, ..., x_{n-1}]).
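//
// HouseHolderResult bundles the outputs of the HouseRow/HouseCol helpers
// below: v is the Householder vector, beta is the reflection coefficient, and
// a is the input matrix with the reflection already applied.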
struct HouseHolderResult {
  XlaOp v;
  XlaOp beta;
  XlaOp a;
};

// Jacobi rotation (also known as Givens rotation):
//   G = [[ c, s],
//        [-s, c]]
// matmul(G_T, G) = I
struct JacobiRotation {
  XlaOp c;  // cosine.
  XlaOp s;  // sine.
};

// JacobiUpdate holds the intermediate orthogonal matrix and the Jacobi-rotated
// matrix.
struct JacobiUpdate {
  XlaOp v;
  XlaOp w;
};

// OneSidedJacobiRotation holds the left and right Jacobi rotations. Refer to
// GetOneSidedJacobiRotation for the effect of applying OneSidedJacobiRotation
// to a matrix.
struct OneSidedJacobiRotation {
  JacobiRotation rot_l;
  JacobiRotation rot_r;
};

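// FrobeniusNorms holds the quantities used by the convergence check of the
// one-sided Jacobi sweeps: the Frobenius norm of the off-diagonal part of the
// matrix and the Frobenius norm of the whole matrix. Iteration stops once
// off_diagonal_norm <= epsilon * total_norm for every batch element (see
// WhileLoopFn below).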
struct FrobeniusNorms {
  XlaOp off_diagonal_norm;
  XlaOp total_norm;
};

// Householder reflection on the trailing elements of a vector.
//
// H = I - beta * [1, v]' * [1, v]
//
// H * x = [..., xnorm, 0,     ..., 0]
//          ..., j,     j + 1, ..., n
//
// def house(x, j, eps):
//   sigma = np.linalg.norm(x[(j + 1):])
//   v = np.zeros_like(x)
//   v[(j + 1):] = x[(j + 1):]
//   if sigma < eps:
//     beta = 0
//   else:
//     mu = sigma * np.sqrt((x[j]/sigma)**2 + 1)
//     if x[j] <= 0:
//       v[j] = x[j] - mu
//     else:
//       v[j] = -sigma / (x[j] + mu) * sigma
//     beta = 2 / ((sigma / v[j])**2 + 1)
//     v = v / v[j]
//   v[j] = 1
//   return v, beta
//
// Householder reflection on the trailing elements of a row of a matrix. After
// applying it to the matrix, all elements in A[i, (j+1):] become zero, i.e.,
// with
//
//   H = I - beta * [1, v]' * [1, v],
//
// we have A[i, j:] * H = [xnorm, 0, 0, ..., 0].
//
StatusOr<HouseHolderResult> HouseRow(XlaOp a, XlaOp i, XlaOp j, XlaOp eps,
                                     PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
  const int64 num_dims = a_shape.rank();
  const int64 n = ShapeUtil::GetDimension(a_shape, -1);
  XlaOp zero = ScalarLike(i, 0);
  XlaOp x = DynamicSliceInMinorDims(a, {i, zero}, {1, n});

  const int64 num_batch_dims = num_dims - 2;
  std::vector<int64> batch_dims(num_batch_dims);
  for (int k = 0; k < num_batch_dims; ++k) {
    batch_dims[k] = ShapeUtil::GetDimension(a_shape, k);
  }

  TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
  auto idx = Iota(builder, ShapeUtil::MakeShape(S32, x_shape.dimensions()),
                  num_dims - 1);
  auto zeros = ZerosLike(x);
  auto v = Select(Gt(idx, j), x, zeros);

  auto one = ScalarLike(v, 1.0);

  auto sigma =
      Sqrt(Reduce(Square(v), ScalarLike(v, 0.0),
                  CreateScalarAddComputation(x_shape.element_type(), builder),
                  {num_dims - 1}));

  std::vector<int64> broadcast_dims(num_dims - 1);
  std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
  auto x_0j = DynamicSliceInMinorDims(x, {zero, j}, {1, 1});
  auto mu = Mul(sigma, Sqrt(Square(Div(x_0j, sigma, broadcast_dims)) + one),
                broadcast_dims);

  auto v_0j = Select(
      Le(x_0j, ScalarLike(x_0j, 0.0)), Sub(x_0j, mu),
      -Mul(sigma, Div(sigma, Add(x_0j, mu), broadcast_dims), broadcast_dims));

  auto beta = Div(ScalarLike(v_0j, 2.0),
                  (Square(Div(sigma, v_0j, broadcast_dims)) + one));

  v = Select(
      BroadcastInDim(Lt(sigma, eps), x_shape.dimensions(), broadcast_dims), v,
      v / v_0j);
  v = Select(Eq(idx, j), zeros + one, v);

  beta = Select(Lt(Add(sigma, ZerosLike(beta), broadcast_dims), eps),
                ZerosLike(beta), beta);

  HouseHolderResult result;
  result.v = v;
  result.beta = beta;
  result.a =
      Sub(a, Mul(beta, BatchDot(BatchDot(a, TransposeInMinorDims(v), precision),
                                v, precision)));

  return result;
}

// Householder reflection on the trailing elements of a column of a matrix.
// After applying it to the matrix, all elements in A[(i+1):, j] become zero,
// i.e., with
//
//   H = I - beta * [1; v] * [1; v]',
//
// we have H * A[i:, j] = [xnorm, 0, 0, ..., 0].
//
StatusOr<HouseHolderResult> HouseCol(XlaOp a, XlaOp i, XlaOp j, XlaOp eps,
                                     PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
  const int64 num_dims = a_shape.rank();
  const int64 m = ShapeUtil::GetDimension(a_shape, -2);
  XlaOp zero = ScalarLike(i, 0);
  XlaOp x = DynamicSliceInMinorDims(a, {zero, j}, {m, 1});

  const int64 num_batch_dims = num_dims - 2;
  std::vector<int64> batch_dims(num_batch_dims);
  for (int k = 0; k < num_batch_dims; ++k) {
    batch_dims[k] = ShapeUtil::GetDimension(a_shape, k);
  }

  TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
  auto idx = Iota(builder, ShapeUtil::MakeShape(S32, x_shape.dimensions()),
                  num_dims - 2);
  auto zeros = ZerosLike(x);
  auto v = Select(Gt(idx, i), x, zeros);

  auto one = ScalarLike(v, 1.0);

  auto sigma =
      Sqrt(Reduce(Square(v), ScalarLike(v, 0.0),
                  CreateScalarAddComputation(x_shape.element_type(), builder),
                  {num_dims - 2}));

  std::vector<int64> broadcast_dims(num_dims - 1);
  std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
  broadcast_dims[num_dims - 2] = num_dims - 1;
  auto x_0i = DynamicSliceInMinorDims(x, {i, zero}, {1, 1});
  auto mu = Mul(sigma, Sqrt(Square(Div(x_0i, sigma, broadcast_dims)) + one),
                broadcast_dims);

  auto v_0i = Select(
      Le(x_0i, ScalarLike(x_0i, 0.0)), Sub(x_0i, mu),
      -Mul(sigma, Div(sigma, Add(x_0i, mu), broadcast_dims), broadcast_dims));

  auto beta = Div(ScalarLike(v_0i, 2.0),
                  (Square(Div(sigma, v_0i, broadcast_dims)) + one));

  v = Select(
      BroadcastInDim(Lt(sigma, eps), x_shape.dimensions(), broadcast_dims), v,
      v / v_0i);
  v = Select(Eq(idx, i), zeros + one, v);

  beta = Select(Lt(Add(sigma, ZerosLike(beta), broadcast_dims), eps),
                ZerosLike(beta), beta);

  HouseHolderResult result;
  result.v = v;
  result.beta = beta;
  result.a = Sub(
      a, Mul(beta, BatchDot(v, BatchDot(TransposeInMinorDims(v), a, precision),
                            precision)));

  return result;
}

// Apply column and row householder reflections for bidiagonalization.
//
// def house_bidiag(A):
//   xz, yz = A.shape
//   LL = np.eye(xz)
//   RR = np.eye(yz)
//   for i in range(yz - 1):
//     v, beta = house_col(A, i, i, 1e-8)
//     L = np.eye(xz) - beta * np.outer(v, v)
//     LL = np.matmul(LL, L)
//     A = np.matmul(L, A)
//     if i < yz - 2:
//       v, beta = house_row(A, i, i + 1, 1e-8)
//       R = np.eye(yz) - beta * np.outer(v, v)
//       RR = np.matmul(RR, R)
//       A = np.matmul(A, R)
//   return LL, A, RR
//
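// The returned SVDResult holds the accumulated left transform in u, the
// accumulated right transform in v, and the bidiagonalized matrix in d.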
StatusOr<SVDResult> HouseHolderBidiagonalization(
    XlaOp a, XlaOp eps, PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
  const int64 num_dims = a_shape.rank();
  const int64 num_batch_dims = num_dims - 2;
  std::vector<int64> batch_dims(num_batch_dims);
  for (int i = 0; i < num_batch_dims; ++i) {
    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
  }
  const int64 m = ShapeUtil::GetDimension(a_shape, -2);
  const int64 n = ShapeUtil::GetDimension(a_shape, -1);
  XlaOp u_init = Broadcast(
      IdentityMatrix(builder, a_shape.element_type(), m, m), batch_dims);
  XlaOp v_init = Broadcast(
      IdentityMatrix(builder, a_shape.element_type(), n, n), batch_dims);

  auto while_cond_fn = [&](absl::Span<const XlaOp> values,
                           XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
    auto i = values[0];
    return Lt(i, ScalarLike(i, n - 2));
  };
  auto while_body_fn =
      [&](absl::Span<const XlaOp> values,
          XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
    auto i = values[0];
    auto one = ScalarLike(i, 1);

    auto u = values[1];
    auto v = values[2];
    auto a = values[3];
    auto eps = values[4];

    TF_ASSIGN_OR_RETURN(HouseHolderResult house_col,
                        HouseCol(a, i, i, eps, precision));
    u = Sub(u, Mul(house_col.beta,
                   BatchDot(BatchDot(u, house_col.v, precision),
                            TransposeInMinorDims(house_col.v), precision)));
    a = house_col.a;

    TF_ASSIGN_OR_RETURN(HouseHolderResult house_row,
                        HouseRow(a, i, i + one, eps, precision));
    v = Sub(
        v,
        Mul(house_row.beta,
            BatchDot(BatchDot(v, TransposeInMinorDims(house_row.v), precision),
                     house_row.v, precision)));
    a = house_row.a;

    std::vector<XlaOp> updated_values;
    updated_values.reserve(values.size());

    updated_values.push_back(i + one);
    updated_values.push_back(u);
    updated_values.push_back(v);
    updated_values.push_back(a);
    updated_values.push_back(eps);
    return updated_values;
  };

  std::vector<XlaOp> values(5);
  values[0] = Zero(builder, S32);
  values[1] = u_init;
  values[2] = v_init;
  values[3] = a;
  values[4] = eps;

  TF_ASSIGN_OR_RETURN(values,
                      WhileLoopHelper(while_cond_fn, while_body_fn, values,
                                      "HouseHolderBidiagonalization", builder));

  for (int k = 2; k > 0; --k) {
    if (n - k >= 0) {
      XlaOp index = ScalarLike(values[0], n - k);
      TF_ASSIGN_OR_RETURN(HouseHolderResult house_col,
                          HouseCol(values[3], index, index, eps, precision));
      values[1] =
          Sub(values[1],
              Mul(house_col.beta,
                  BatchDot(BatchDot(values[1], house_col.v, precision),
                           TransposeInMinorDims(house_col.v), precision)));
      values[3] = house_col.a;
    }
  }

  SVDResult result;
  result.u = values[1];
  result.v = values[2];
  result.d = values[3];
  return result;
}

// MakeJacobi computes a rotation matrix G = [[c, s], [-s, c]], such that
//   G_T * [[ps, pqs], [pqs, qs]] * G
// is diagonal.
//
// def make_jacobi(ps, qs, pqs, eps):
//   if np.abs(pqs) > eps:
//     tau = (qs - ps) / (2 * pqs)
//     if tau >= 0:
//       t = 1.0 / (tau + np.sqrt(1 + tau ** 2))
//     else:
//       t = -1.0 / (-tau + np.sqrt(1 + tau ** 2))
//     c = 1.0 / np.sqrt(1.0 + t ** 2)
//     s = t * c
//   else:
//     c = 1.0
//     s = 0.0
//   return c, s
//
StatusOr<JacobiRotation> MakeJacobi(XlaOp ps, XlaOp qs, XlaOp pqs, XlaOp eps) {
  auto zero = ScalarLike(ps, 0.0);
  auto one = ScalarLike(ps, 1.0);
  auto two = ScalarLike(ps, 2.0);

  auto tau = (qs - ps) / (pqs * two);
  auto t_pos = one / (tau + Sqrt(one + Square(tau)));
  auto t_neg = -one / (-tau + Sqrt(one + Square(tau)));
  auto t = Select(Ge(tau, zero), t_pos, t_neg);

  auto c_temp = Rsqrt(one + Square(t));
  auto s_temp = t * c_temp;

  auto c = Select(Ge(Abs(pqs), eps), c_temp, ZerosLike(c_temp) + one);
  auto s = Select(Ge(Abs(pqs), eps), s_temp, ZerosLike(s_temp));
  // Renormalize c and s to compensate for low-precision arithmetic; this step
  // is redundant if a high-precision float type, such as float64, is used.
  auto rnorm = Rsqrt(Square(c) + Square(s));

  JacobiRotation rot;

  rot.c = c * rnorm;
  rot.s = s * rnorm;

  return rot;
}

// One-sided Jacobi rotations. For a matrix
//   [a_pp, a_pq]
//   [a_qp, a_qq]
// applying Jacobi rotations on both sides diagonalizes it:
//   [b_pp, 0]
//   [0, b_qq]
//
// def jacobi_rot(a, p, q, eps):
//   t = a[p, p] + a[q, q]
//   d = a[q, p] - a[p, q]
//
//   if np.abs(d) < eps:
//     s = 0.0
//     c = 1.0
//   else:
//     u = t / d
//     tmp = np.sqrt(1.0 + u**2)
//     s = -1.0 / tmp
//     c = u / tmp
//
//   rot = np.array([[c, s], [-s, c]])
//   m_tmp = rot.T @ a[np.ix_([p, q], [p, q])]
//   c_r, s_r = make_jacobi(m_tmp[0, 0], m_tmp[1, 1], m_tmp[0, 1], eps)
//   rot_r = np.array([[c_r, s_r], [-s_r, c_r]])
//   rot_l = rot @ rot_r
//   return rot_l, rot_r
//
StatusOr<OneSidedJacobiRotation> GetOneSidedJacobiRotation(XlaOp a, XlaOp p,
                                                           XlaOp q, XlaOp eps) {
  XlaOp a_pp = DynamicSliceInMinorDims(a, {p, p}, {1, 1});
  XlaOp a_pq = DynamicSliceInMinorDims(a, {p, q}, {1, 1});
  XlaOp a_qp = DynamicSliceInMinorDims(a, {q, p}, {1, 1});
  XlaOp a_qq = DynamicSliceInMinorDims(a, {q, q}, {1, 1});

  XlaOp one = ScalarLike(a, 1.0);

  XlaOp t = a_pp + a_qq;
  XlaOp d = a_qp - a_pq;

  XlaOp u = Div(t, d);
  XlaOp tmp = Rsqrt(one + Square(u));

  JacobiRotation rot;

  XlaOp zeros = ZerosLike(tmp);
  XlaOp ones = zeros + one;

  rot.s = Select(Lt(Abs(d), eps), zeros, -tmp);
  rot.c = Select(Lt(Abs(d), eps), ones, Mul(u, tmp));

  XlaOp a_pp_new = rot.c * a_pp - rot.s * a_qp;
  XlaOp a_pq_new = rot.c * a_pq - rot.s * a_qq;
  XlaOp a_qq_new = rot.s * a_pq + rot.c * a_qq;

  OneSidedJacobiRotation rots;
  TF_ASSIGN_OR_RETURN(rots.rot_r,
                      MakeJacobi(a_pp_new, a_qq_new, a_pq_new, eps));

  rots.rot_l.c = rot.c * rots.rot_r.c - rot.s * rots.rot_r.s;
  rots.rot_l.s = rot.s * rots.rot_r.c + rot.c * rots.rot_r.s;

  return rots;
}

// Apply a one-sided Jacobi rotation determined by the elements at indices
// (p, p), (p, q), (q, p), (q, q), updating d, u, and v accordingly.
StatusOr<SVDResult> OneSidedJacobiUpdate(SVDResult svd_result, XlaOp p, XlaOp q,
                                         XlaOp eps) {
  XlaOp u = svd_result.u;
  XlaOp v = svd_result.v;
  XlaOp d = svd_result.d;
  XlaBuilder* builder = d.builder();
  TF_ASSIGN_OR_RETURN(Shape d_shape, builder->GetShape(d));
  const int64 num_dims = d_shape.rank();
  const int64 num_batch_dims = num_dims - 2;
  std::vector<int64> batch_dims(num_batch_dims);
  for (int i = 0; i < num_batch_dims; ++i) {
    batch_dims[i] = ShapeUtil::GetDimension(d_shape, i);
  }
  const int64 m = ShapeUtil::GetDimension(d_shape, -2);
  const int64 n = ShapeUtil::GetDimension(d_shape, -1);

  TF_ASSIGN_OR_RETURN(OneSidedJacobiRotation onesided_jacobi,
                      GetOneSidedJacobiRotation(d, p, q, eps));

  auto zero = ScalarLike(p, 0);

  // Zero out a_{pq} explicitly.
  std::vector<int64> pq_dims(batch_dims.begin(), batch_dims.end());
  pq_dims.push_back(1);
  pq_dims.push_back(1);
  auto pq_zero = ScalarLike(d, 0.0);
  auto pq_zeros = Broadcast(pq_zero, pq_dims);

  std::vector<int64> broadcast_dims(batch_dims.size());
  std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
  broadcast_dims.push_back(num_dims - 1);

  // Apply Jacobi Rotation on the left.
  auto slice_p = DynamicSliceInMinorDims(d, {p, zero}, {1, n});
  auto slice_q = DynamicSliceInMinorDims(d, {q, zero}, {1, n});
  auto slice_p_new =
      onesided_jacobi.rot_l.c * slice_p - onesided_jacobi.rot_l.s * slice_q;
  auto slice_q_new =
      onesided_jacobi.rot_l.s * slice_p + onesided_jacobi.rot_l.c * slice_q;
  d = DynamicUpdateSliceInMinorDims(d, slice_p_new, {p, zero});
  d = DynamicUpdateSliceInMinorDims(d, slice_q_new, {q, zero});

  // Apply Jacobi Rotation on the right.
  slice_p = DynamicSliceInMinorDims(d, {zero, p}, {m, 1});
  slice_q = DynamicSliceInMinorDims(d, {zero, q}, {m, 1});
  slice_p_new =
      onesided_jacobi.rot_r.c * slice_p - onesided_jacobi.rot_r.s * slice_q;
  slice_q_new =
      onesided_jacobi.rot_r.s * slice_p + onesided_jacobi.rot_r.c * slice_q;
  d = DynamicUpdateSliceInMinorDims(d, slice_p_new, {zero, p});
  d = DynamicUpdateSliceInMinorDims(d, slice_q_new, {zero, q});

  d = DynamicUpdateSliceInMinorDims(d, pq_zeros, {p, q});
  d = DynamicUpdateSliceInMinorDims(d, pq_zeros, {q, p});

  // Apply left Jacobi Rotation on U.
  slice_p = DynamicSliceInMinorDims(u, {zero, p}, {m, 1});
  slice_q = DynamicSliceInMinorDims(u, {zero, q}, {m, 1});
  slice_p_new =
      onesided_jacobi.rot_l.c * slice_p - onesided_jacobi.rot_l.s * slice_q;

  slice_p_new = Mul(
      slice_p_new,
      Rsqrt(Reduce(Square(slice_p_new), pq_zero,
                   CreateScalarAddComputation(d_shape.element_type(), builder),
                   {num_dims - 2})),
      broadcast_dims);

  slice_q_new =
      onesided_jacobi.rot_l.s * slice_p + onesided_jacobi.rot_l.c * slice_q;

  slice_q_new = Mul(
      slice_q_new,
      Rsqrt(Reduce(Square(slice_q_new), pq_zero,
                   CreateScalarAddComputation(d_shape.element_type(), builder),
                   {num_dims - 2})),
      broadcast_dims);

  u = DynamicUpdateSliceInMinorDims(u, slice_p_new, {zero, p});
  u = DynamicUpdateSliceInMinorDims(u, slice_q_new, {zero, q});

  // Apply right Jacobi Rotation on V.
  slice_p = DynamicSliceInMinorDims(v, {zero, p}, {n, 1});
  slice_q = DynamicSliceInMinorDims(v, {zero, q}, {n, 1});
  slice_p_new =
      onesided_jacobi.rot_r.c * slice_p - onesided_jacobi.rot_r.s * slice_q;

  slice_p_new = Mul(
      slice_p_new,
      Rsqrt(Reduce(Square(slice_p_new), pq_zero,
                   CreateScalarAddComputation(d_shape.element_type(), builder),
                   {num_dims - 2})),
      broadcast_dims);

  slice_q_new =
      onesided_jacobi.rot_r.s * slice_p + onesided_jacobi.rot_r.c * slice_q;

  slice_q_new = Mul(
      slice_q_new,
      Rsqrt(Reduce(Square(slice_q_new), pq_zero,
                   CreateScalarAddComputation(d_shape.element_type(), builder),
                   {num_dims - 2})),
      broadcast_dims);

  v = DynamicUpdateSliceInMinorDims(v, slice_p_new, {zero, p});
  v = DynamicUpdateSliceInMinorDims(v, slice_q_new, {zero, q});

  svd_result.d = d;
  svd_result.u = u;
  svd_result.v = v;

  return svd_result;
}

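// Computes the Frobenius norm of w and of its off-diagonal part. The
// off-diagonal norm is obtained as sqrt(max(||w||_F^2 - ||diag(w)||^2, 0)) and
// is compared against epsilon * total_norm to decide whether the Jacobi sweeps
// have converged.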
StatusOr<FrobeniusNorms> ComputeFrobeniusNorms(XlaOp w) {
  XlaBuilder* builder = w.builder();
  TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w));
  const int64 num_dims = shape.rank();
  auto frobenius_norm =
      Sqrt(Reduce(Square(w), ScalarLike(w, 0.0),
                  CreateScalarAddComputation(shape.element_type(), builder),
                  {num_dims - 2, num_dims - 1}));
  auto diag = GetMatrixDiagonal(w);
  auto diag_square =
      Reduce(Square(diag), ScalarLike(w, 0.0),
             CreateScalarAddComputation(shape.element_type(), builder),
             {num_dims - 2});

  FrobeniusNorms frobenius_norms;

  frobenius_norms.off_diagonal_norm =
      Sqrt(Max(Square(frobenius_norm) - diag_square, ScalarLike(w, 0.0)));
  frobenius_norms.total_norm = frobenius_norm;

  return frobenius_norms;
}

// Main body of the one-sided Jacobi method.
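// The outer loop runs the sweep counter k until either max_sweep_updates is
// reached or, for every batch element, the off-diagonal Frobenius norm of the
// matrix in values[3] drops to at most epsilon * total_norm. Each sweep visits
// all index pairs (p, q) with 0 <= p < matrix_dimension - 1 and
// p < q < matrix_dimension, applying OneSidedJacobiUpdate to each pair.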
StatusOr<std::vector<XlaOp>> WhileLoopFn(
    absl::Span<const XlaOp> initial_values,  //
    int matrix_dimension,                    //
    int max_sweep_updates,                   //
    absl::string_view name,                  //
    XlaBuilder* builder) {
  auto while_cond_fn = [&](absl::Span<const XlaOp> values,
                           XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
    auto k = values[0];
    auto max_sweeps = ScalarLike(k, max_sweep_updates);
    auto sweep_update_cond = Gt(max_sweeps, k);

    auto norms = ComputeFrobeniusNorms(values[3]).ValueOrDie();
    auto tol = norms.total_norm * values[4];
    auto tol_cond = ReduceAll(Lt(tol, norms.off_diagonal_norm),
                              xla::ConstantR0<bool>(cond_builder, false),
                              CreateScalarOrComputation(PRED, cond_builder));

    return And(sweep_update_cond, tol_cond);
  };

  auto while_body_fn =
      [&](absl::Span<const XlaOp> values,
          XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
    auto while_cond_fn_inner =
        [&](absl::Span<const XlaOp> values_inner,
            XlaBuilder* inner_cond_builder) -> StatusOr<XlaOp> {
      auto p = values_inner[0];
      return Lt(p, ScalarLike(p, matrix_dimension - 1));
    };

    auto while_body_fn_inner =
        [&](absl::Span<const XlaOp> values_inner,
            XlaBuilder* inner_body_builder) -> StatusOr<std::vector<XlaOp>> {
      auto while_cond_fn_innermost =
          [&](absl::Span<const XlaOp> values_innermost,
              XlaBuilder* innermost_cond_builder) -> StatusOr<XlaOp> {
        auto q = values_innermost[1];
        return Lt(q, ScalarLike(q, matrix_dimension));
      };
      auto while_body_fn_innermost =
          [&](absl::Span<const XlaOp> values_innermost,
              XlaBuilder* innermost_body_builder)
          -> StatusOr<std::vector<XlaOp>> {
        auto p = values_innermost[0];
        auto q = values_innermost[1];

        SVDResult onesided_jacobi_update;
        onesided_jacobi_update.u = values_innermost[2];
        onesided_jacobi_update.v = values_innermost[3];
        onesided_jacobi_update.d = values_innermost[4];

        auto eps = values_innermost[5];

        TF_ASSIGN_OR_RETURN(
            onesided_jacobi_update,
            OneSidedJacobiUpdate(onesided_jacobi_update, p, q, eps));

        std::vector<XlaOp> updated_values_innermost;
        updated_values_innermost.reserve(values_innermost.size());

        updated_values_innermost.push_back(p);
        updated_values_innermost.push_back(q + ScalarLike(q, 1));
        updated_values_innermost.push_back(onesided_jacobi_update.u);
        updated_values_innermost.push_back(onesided_jacobi_update.v);
        updated_values_innermost.push_back(onesided_jacobi_update.d);
        updated_values_innermost.push_back(eps);

        return updated_values_innermost;
      };

      std::vector<XlaOp> values_innermost(6);
      auto p = values_inner[0];
      auto q = p + ScalarLike(p, 1);
      values_innermost[0] = p;                // index p.
      values_innermost[1] = q;                // index q.
      values_innermost[2] = values_inner[1];  // u.
      values_innermost[3] = values_inner[2];  // v.
      values_innermost[4] = values_inner[3];  // d.
      values_innermost[5] = values_inner[4];  // eps.
      TF_ASSIGN_OR_RETURN(
          values_innermost,
          WhileLoopHelper(while_cond_fn_innermost, while_body_fn_innermost,
                          values_innermost, absl::StrCat(name, "-Innermost"),
                          inner_body_builder));

      std::vector<XlaOp> updated_values_inner;
      updated_values_inner.reserve(values_inner.size());

      updated_values_inner.push_back(p + ScalarLike(p, 1));
      updated_values_inner.push_back(values_innermost[2]);
      updated_values_inner.push_back(values_innermost[3]);
      updated_values_inner.push_back(values_innermost[4]);
      updated_values_inner.push_back(values_innermost[5]);
      return updated_values_inner;
    };
    // Indexes.
    XlaOp k = values[0];

    std::vector<XlaOp> values_inner(5);
    values_inner[0] = ScalarLike(k, 0);  // index p.
    values_inner[1] = values[1];         // u.
    values_inner[2] = values[2];         // v.
    values_inner[3] = values[3];         // d.
    values_inner[4] = values[4];         // eps.
    TF_ASSIGN_OR_RETURN(
        values_inner,
        WhileLoopHelper(while_cond_fn_inner, while_body_fn_inner, values_inner,
                        absl::StrCat(name, "-Inner"), body_builder));

    std::vector<XlaOp> updated_values;
    updated_values.reserve(values_inner.size());

    updated_values.push_back(k + ScalarLike(k, 1));
    updated_values.push_back(values_inner[1]);
    updated_values.push_back(values_inner[2]);
    updated_values.push_back(values_inner[3]);
    updated_values.push_back(values_inner[4]);

    return updated_values;
  };
  std::vector<XlaOp> values;
  TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn,
                                              initial_values, name, builder));

  return values;
}

// Sort the singular values in descending order, and make sure they are
// non-negative by flipping the signs of negative diagonal values and
// transferring the signs to V. For numerical stability, renormalize U and V.
StatusOr<SVDResult> SortBySingularValuesAndPostProcessing(SVDResult result) {
  XlaBuilder* builder = result.d.builder();
  TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(result.d));
  const int64 num_dims = shape.rank();
  auto dimensions = shape.dimensions();
  const int64 m = ShapeUtil::GetDimension(shape, -2);
  const int64 n = ShapeUtil::GetDimension(shape, -1);

  std::vector<int64> broadcast_dims(num_dims - 1);
  std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
  broadcast_dims[num_dims - 2] = num_dims - 1;

  auto d = GetMatrixDiagonal(result.d);

  auto zeros = ZerosLike(d);
  auto one = ScalarLike(d, 1.0);

  // Make all the singular values non-negative by transferring the signs to V.
  auto sign = Select(Ge(d, zeros), zeros + one, zeros - one);
  d = Select(Ge(d, zeros), d, -d);
  result.v = Mul(result.v, sign, broadcast_dims);

  d = BroadcastInDim(d, dimensions, broadcast_dims);

  // As m >= n, only the first n column vectors need to be permuted; the
  // remaining m - n vectors are appended after the sorting is done.
  XlaOp sort_u_result =
      Sort({-d, SliceInMinorDims(result.u, {0, 0}, {m, n})},
           CreateScalarLtComputation(
               {shape.element_type(), shape.element_type()}, builder),
           num_dims - 1);

  // TODO(kuny): using CreateScalarGtComputation after b/124862300 is fixed.
  XlaOp sort_v_result =
      Sort({SliceInMinorDims(-d, {0, 0}, {n, n}), result.v},
           CreateScalarLtComputation(
               {shape.element_type(), shape.element_type()}, builder),
           num_dims - 1);
  // Make sure all the singular values are non-negative.
  result.d = Max(-GetMatrixDiagonal(GetTupleElement(sort_v_result, 0)),
                 ScalarLike(d, 0.0));

  result.v = GetTupleElement(sort_v_result, 1);
  result.v = Mul(
      result.v,
      Rsqrt(Reduce(Square(result.v), ScalarLike(d, 0.0),
                   CreateScalarAddComputation(shape.element_type(), builder),
                   {num_dims - 2})),
      broadcast_dims);

  // Append the rest of m - n vectors.
  result.u = ConcatInDim(builder,
                         {GetTupleElement(sort_u_result, 1),
                          SliceInMinorDims(result.u, {0, n}, {m, m})},
                         num_dims - 1);
  result.u = Mul(
      result.u,
      Rsqrt(Reduce(Square(result.u), ScalarLike(d, 0.0),
                   CreateScalarAddComputation(shape.element_type(), builder),
                   {num_dims - 2})),
      broadcast_dims);

  return result;
}

}  // namespace

// def jacobi_svd(A):
//   U, D, V = house_bidiag(A)
//   m, n = D.shape
//   iter, max_iter = 0, 100
//   frobenius_norm = np.linalg.norm(D)
//   diag_norm = np.linalg.norm(np.diag(D))
//   off_diag_norm = np.sqrt(
//       frobenius_norm - diag_norm) * np.sqrt(frobenius_norm + diag_norm)
//   while off_diag_norm > 1e-6 * frobenius_norm and iter < max_iter:
//     iter += 1
//     for p in range(m - 1):
//       for q in range(p + 1, n):
//         rot_l, rot_r = jacobi_rot(D, p, q, 1e-8)
//         D[[p, q], :] = np.matmul(rot_l.T, D[[p, q], :])
//         D[:, [p, q]] = np.matmul(D[:, [p, q]], rot_r)
//         U[:, [p, q]] = np.matmul(U[:, [p, q]], rot_l)
//         V[:, [p, q]] = np.matmul(V[:, [p, q]], rot_r)
//     frobenius_norm = np.linalg.norm(D)
//     diag_norm = np.linalg.norm(np.diag(D))
//     off_diag_norm = np.sqrt(
//         frobenius_norm - diag_norm) * np.sqrt(frobenius_norm + diag_norm)
//
//   return U, np.diag(D), V
//
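// A usage sketch (hypothetical caller code, not part of this library; the
// result shapes follow from the implementation below for an operand `a` of
// shape [..., M, N] with M >= N):
//
//   SVDResult s = SVD(a, /*max_iter=*/100, /*epsilon=*/1e-6,
//                     PrecisionConfig::HIGHEST);
//   // s.u: [..., M, M] left singular vectors,
//   // s.d: [..., N]    singular values, non-negative and in descending order,
//   // s.v: [..., N, N] right singular vectors.
//
// When M < N, `a` is transposed internally and u and v are swapped in the
// returned result.
//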
SVDResult SVD(XlaOp a, int64 max_iter, float epsilon,
              PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  auto return_error = [&](const Status& status) {
    SVDResult result;
    result.u = builder->ReportError(status);
    result.v = builder->ReportError(status);
    result.d = builder->ReportError(status);
    return result;
  };
  auto shape_with_status = builder->GetShape(a);
  if (!shape_with_status.status().ok()) {
    return return_error(shape_with_status.status());
  }
  Shape a_shape = shape_with_status.ValueOrDie();
  const int64 num_dims = a_shape.rank();
  const int64 num_batch_dims = num_dims - 2;
  std::vector<int64> batch_dims(num_batch_dims);
  for (int i = 0; i < num_batch_dims; ++i) {
    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
  }
  int64 m = ShapeUtil::GetDimension(a_shape, -2);
  int64 n = ShapeUtil::GetDimension(a_shape, -1);
  bool maybe_transpose = m < n;

  if (maybe_transpose) {
    a = TransposeInMinorDims(a);
    std::swap(m, n);
  }

  auto eps = ScalarLike(a, epsilon);

  SVDResult svd_result =
      HouseHolderBidiagonalization(a, eps, precision).ValueOrDie();

  auto output_with_status = WhileLoopFn(
      {
          Zero(builder, S32),  // k
          svd_result.u,        // u
          svd_result.v,        // v
          svd_result.d,        // d
          eps,                 // epsilon
      },                       //
      n,                       //
      max_iter,                //
      "CyclicOneSidedJacobi",  //
      builder);
  if (!output_with_status.status().ok()) {
    return return_error(output_with_status.status());
  }

  auto output = output_with_status.ValueOrDie();

  svd_result.u = output[1];
  svd_result.v = output[2];
  svd_result.d = output[3];
  svd_result = SortBySingularValuesAndPostProcessing(svd_result).ValueOrDie();
  if (maybe_transpose) {
    std::swap(svd_result.u, svd_result.v);
  }
  return svd_result;
}

}  // namespace xla