• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/xla/client/lib/svd.h"
16 
17 #include <memory>
18 #include <numeric>
19 #include <utility>
20 #include <vector>
21 
22 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
23 #include "tensorflow/compiler/xla/client/lib/comparators.h"
24 #include "tensorflow/compiler/xla/client/lib/constants.h"
25 #include "tensorflow/compiler/xla/client/lib/loops.h"
26 #include "tensorflow/compiler/xla/client/lib/math.h"
27 #include "tensorflow/compiler/xla/client/lib/matrix.h"
28 #include "tensorflow/compiler/xla/client/lib/slicing.h"
29 #include "tensorflow/compiler/xla/client/xla_builder.h"
30 #include "tensorflow/compiler/xla/literal_util.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/status_macros.h"
33 #include "tensorflow/compiler/xla/statusor.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 
36 namespace xla {
37 
38 namespace {
39 
40 // Given a matrix A, define H,
41 //   H = A * (I - beta * v_T * v) if v is a row vector, or
42 //   H = (I - beta * v * v_T) if v is column vector.
43 // A * H or H * A zeros out trailing part of some row or column of A.
44 //
45 // [x0, ..., x_{k-1}, xk, x_{k+1}, ..., x_{n-1}] * H
46 //       = [x0, ..., x_{k-1}, xnorm, 0, ..., 0]
47 //
48 // Here xnorm = norm([x_k, x_{k+1}, ..., x_{n - 1}])
49 struct HouseHolderResult {
50   XlaOp v;
51   XlaOp beta;
52   XlaOp a;
53 };
54 
55 // Jacobi rotation (also known as Givens rotation):
56 // G = [[ c, s],
57 //      [-s, c]]
58 // matmul(G_T, G) = I
59 struct JacobiRotation {
60   XlaOp c;  // cosine.
61   XlaOp s;  // sine.
62 };
63 
64 // JacobiUpdate holds the intermediate orthogonal matrix, Jacobi-rotated matrix.
65 struct JacobiUpdate {
66   XlaOp v;
67   XlaOp w;
68 };
69 
70 // OneSidedJacobiRotation holds the left and right Jacobi rotations. Refer to
71 // GetOneSidedJacobiRotation for the effect of applying OneSidedJacobiRotation
72 // to a matrix.
73 struct OneSidedJacobiRotation {
74   JacobiRotation rot_l;
75   JacobiRotation rot_r;
76 };
77 
78 struct FrobeniusNorms {
79   XlaOp off_diagonal_norm;
80   XlaOp total_norm;
81 };
82 
83 // Householder reflection on the trailing elements of a vector.
84 //
85 // H = I - beta * [1, v]' * [1, v]
86 //
87 // H * x = [..., xnorm, 0, ..., 0]
88 //          ..., j, j + 1, ..., n
89 //
90 // def house(x, j, eps):
91 //    sigma = np.linalg.norm(x[(j + 1):])
92 //    v = np.zeros_like(x)
93 //    v[(j + 1):] = x[(j + 1):]
94 //    if sigma < eps:
95 //        beta = 0
96 //    else:
97 //        mu = sigma * np.sqrt((x[j]/sigma)**2 + 1)
98 //        if x[j] <= 0:
99 //            v[j] = x[j] - mu
100 //        else:
101 //            v[j] = -sigma / (x[j] + mu) * sigma
102 //        beta = 2 / ((sigma / v[j])**2 + 1)
103 //        v = v / v[j]
104 //    v[j] = 1
105 //    return v, beta
106 //
107 // Householder reflection on the trailing elements of a row of a matrix. After
108 // applying it on the matrix, all elements in [i, (j+1):] become zeros, i.e.,
109 //
110 // H = I - beta * [1, v]' * [1, v], then,
111 //
112 // A[i, j:] * H = [sigma, 0, 0, ..., 0]
113 //
HouseRow(XlaOp a,XlaOp i,XlaOp j,XlaOp eps,PrecisionConfig::Precision precision)114 StatusOr<HouseHolderResult> HouseRow(XlaOp a, XlaOp i, XlaOp j, XlaOp eps,
115                                      PrecisionConfig::Precision precision) {
116   XlaBuilder* builder = a.builder();
117   TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
118   const int64 num_dims = a_shape.rank();
119   const int64 n = ShapeUtil::GetDimension(a_shape, -1);
120   XlaOp zero = ScalarLike(i, 0);
121   XlaOp x = DynamicSliceInMinorDims(a, {i, zero}, {1, n});
122 
123   const int64 num_batch_dims = num_dims - 2;
124   std::vector<int64> batch_dims(num_batch_dims);
125   for (int k = 0; k < num_batch_dims; ++k) {
126     batch_dims[k] = ShapeUtil::GetDimension(a_shape, k);
127   }
128 
129   TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
130   auto idx = Iota(builder, ShapeUtil::MakeShape(S32, x_shape.dimensions()),
131                   num_dims - 1);
132   auto zeros = ZerosLike(x);
133   auto v = Select(Gt(idx, j), x, zeros);
134 
135   auto one = ScalarLike(v, 1.0);
136 
137   auto sigma =
138       Sqrt(Reduce(Square(v), ScalarLike(v, 0.0),
139                   CreateScalarAddComputation(x_shape.element_type(), builder),
140                   {num_dims - 1}));
141 
142   std::vector<int64> broadcast_dims(num_dims - 1);
143   std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
144   auto x_0j = DynamicSliceInMinorDims(x, {zero, j}, {1, 1});
145   auto mu = Mul(sigma, Sqrt(Square(Div(x_0j, sigma, broadcast_dims)) + one),
146                 broadcast_dims);
147 
148   auto v_0j = Select(
149       Le(x_0j, ScalarLike(x_0j, 0.0)), Sub(x_0j, mu),
150       -Mul(sigma, Div(sigma, Add(x_0j, mu), broadcast_dims), broadcast_dims));
151 
152   auto beta = Div(ScalarLike(v_0j, 2.0),
153                   (Square(Div(sigma, v_0j, broadcast_dims)) + one));
154 
155   v = Select(
156       BroadcastInDim(Lt(sigma, eps), x_shape.dimensions(), broadcast_dims), v,
157       v / v_0j);
158   v = Select(Eq(idx, j), zeros + one, v);
159 
160   beta = Select(Lt(Add(sigma, ZerosLike(beta), broadcast_dims), eps),
161                 ZerosLike(beta), beta);
162 
163   HouseHolderResult result;
164   result.v = v;
165   result.beta = beta;
166   result.a =
167       Sub(a, Mul(beta, BatchDot(BatchDot(a, TransposeInMinorDims(v), precision),
168                                 v, precision)));
169 
170   return result;
171 }
172 
173 // Householder reflection on the trailing elements of a col of a matrix. After
174 // applying it on the matrix, all elements in [(i+1):, j] become zeros, i.e.,
175 //
176 // H = I - beta * [1; v] * [1; v]', then,
177 //
178 // H * A[i:, j] = [xnorm, 0, 0, ..., 0]
179 //
HouseCol(XlaOp a,XlaOp i,XlaOp j,XlaOp eps,PrecisionConfig::Precision precision)180 StatusOr<HouseHolderResult> HouseCol(XlaOp a, XlaOp i, XlaOp j, XlaOp eps,
181                                      PrecisionConfig::Precision precision) {
182   XlaBuilder* builder = a.builder();
183   TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
184   const int64 num_dims = a_shape.rank();
185   const int64 m = ShapeUtil::GetDimension(a_shape, -2);
186   XlaOp zero = ScalarLike(i, 0);
187   XlaOp x = DynamicSliceInMinorDims(a, {zero, j}, {m, 1});
188 
189   const int64 num_batch_dims = num_dims - 2;
190   std::vector<int64> batch_dims(num_batch_dims);
191   for (int k = 0; k < num_batch_dims; ++k) {
192     batch_dims[k] = ShapeUtil::GetDimension(a_shape, k);
193   }
194 
195   TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
196   auto idx = Iota(builder, ShapeUtil::MakeShape(S32, x_shape.dimensions()),
197                   num_dims - 2);
198   auto zeros = ZerosLike(x);
199   auto v = Select(Gt(idx, i), x, zeros);
200 
201   auto one = ScalarLike(v, 1.0);
202 
203   auto sigma =
204       Sqrt(Reduce(Square(v), ScalarLike(v, 0.0),
205                   CreateScalarAddComputation(x_shape.element_type(), builder),
206                   {num_dims - 2}));
207 
208   std::vector<int64> broadcast_dims(num_dims - 1);
209   std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
210   broadcast_dims[num_dims - 2] = num_dims - 1;
211   auto x_0i = DynamicSliceInMinorDims(x, {i, zero}, {1, 1});
212   auto mu = Mul(sigma, Sqrt(Square(Div(x_0i, sigma, broadcast_dims)) + one),
213                 broadcast_dims);
214 
215   auto v_0i = Select(
216       Le(x_0i, ScalarLike(x_0i, 0.0)), Sub(x_0i, mu),
217       -Mul(sigma, Div(sigma, Add(x_0i, mu), broadcast_dims), broadcast_dims));
218 
219   auto beta = Div(ScalarLike(v_0i, 2.0),
220                   (Square(Div(sigma, v_0i, broadcast_dims)) + one));
221 
222   v = Select(
223       BroadcastInDim(Lt(sigma, eps), x_shape.dimensions(), broadcast_dims), v,
224       v / v_0i);
225   v = Select(Eq(idx, i), zeros + one, v);
226 
227   beta = Select(Lt(Add(sigma, ZerosLike(beta), broadcast_dims), eps),
228                 ZerosLike(beta), beta);
229 
230   HouseHolderResult result;
231   result.v = v;
232   result.beta = beta;
233   result.a = Sub(
234       a, Mul(beta, BatchDot(v, BatchDot(TransposeInMinorDims(v), a, precision),
235                             precision)));
236 
237   return result;
238 }
239 
240 // Apply column and row householder reflections for bidiagonalization.
241 //
242 // def house_bidiag(A):
243 //    xz, yz = A.shape
244 //    LL = np.eye(xz)
245 //    RR = np.eye(yz)
246 //    for i in range(yz - 1):
247 //        v, beta = house_col(A, i, i, 1e-8)
248 //        L = np.eye(xz) - beta * np.outer(v, v)
249 //        LL = np.matmul(LL, L)
250 //        A = np.matmul(L, A)
251 //        if i < yz - 2:
252 //            v, beta = house_row(A, i, i + 1, 1e-8)
253 //            R = np.eye(yz) - beta * np.outer(v, v)
254 //            RR = np.matmul(RR, R)
255 //            A = np.matmul(A, R)
256 //    return LL, A, RR
257 //
HouseHolderBidiagonalization(XlaOp a,XlaOp eps,PrecisionConfig::Precision precision)258 StatusOr<SVDResult> HouseHolderBidiagonalization(
259     XlaOp a, XlaOp eps, PrecisionConfig::Precision precision) {
260   XlaBuilder* builder = a.builder();
261   TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
262   const int64 num_dims = a_shape.rank();
263   const int64 num_batch_dims = num_dims - 2;
264   std::vector<int64> batch_dims(num_batch_dims);
265   for (int i = 0; i < num_batch_dims; ++i) {
266     batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
267   }
268   const int64 m = ShapeUtil::GetDimension(a_shape, -2);
269   const int64 n = ShapeUtil::GetDimension(a_shape, -1);
270   XlaOp u_init = Broadcast(
271       IdentityMatrix(builder, a_shape.element_type(), m, m), batch_dims);
272   XlaOp v_init = Broadcast(
273       IdentityMatrix(builder, a_shape.element_type(), n, n), batch_dims);
274 
275   auto while_cond_fn = [&](absl::Span<const XlaOp> values,
276                            XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
277     auto i = values[0];
278     return Lt(i, ScalarLike(i, n - 2));
279   };
280   auto while_body_fn =
281       [&](absl::Span<const XlaOp> values,
282           XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
283     auto i = values[0];
284     auto one = ScalarLike(i, 1);
285 
286     auto u = values[1];
287     auto v = values[2];
288     auto a = values[3];
289     auto eps = values[4];
290 
291     TF_ASSIGN_OR_RETURN(HouseHolderResult house_col,
292                         HouseCol(a, i, i, eps, precision));
293     u = Sub(u, Mul(house_col.beta,
294                    BatchDot(BatchDot(u, house_col.v, precision),
295                             TransposeInMinorDims(house_col.v), precision)));
296     a = house_col.a;
297 
298     TF_ASSIGN_OR_RETURN(HouseHolderResult house_row,
299                         HouseRow(a, i, i + one, eps, precision));
300     v = Sub(
301         v,
302         Mul(house_row.beta,
303             BatchDot(BatchDot(v, TransposeInMinorDims(house_row.v), precision),
304                      house_row.v, precision)));
305     a = house_row.a;
306 
307     std::vector<XlaOp> updated_values;
308     updated_values.reserve(values.size());
309 
310     updated_values.push_back(i + one);
311     updated_values.push_back(u);
312     updated_values.push_back(v);
313     updated_values.push_back(a);
314     updated_values.push_back(eps);
315     return updated_values;
316   };
317 
318   std::vector<XlaOp> values(5);
319   values[0] = Zero(builder, S32);
320   values[1] = u_init;
321   values[2] = v_init;
322   values[3] = a;
323   values[4] = eps;
324 
325   TF_ASSIGN_OR_RETURN(values,
326                       WhileLoopHelper(while_cond_fn, while_body_fn, values,
327                                       "HouseHolderBidiagonalization", builder));
328 
329   for (int k = 2; k > 0; --k) {
330     if (n - k >= 0) {
331       XlaOp index = ScalarLike(values[0], n - k);
332       TF_ASSIGN_OR_RETURN(HouseHolderResult house_col,
333                           HouseCol(values[3], index, index, eps, precision));
334       values[1] =
335           Sub(values[1],
336               Mul(house_col.beta,
337                   BatchDot(BatchDot(values[1], house_col.v, precision),
338                            TransposeInMinorDims(house_col.v), precision)));
339       values[3] = house_col.a;
340     }
341   }
342 
343   SVDResult result;
344   result.u = values[1];
345   result.v = values[2];
346   result.d = values[3];
347   return result;
348 }
349 
350 // MakeJacobi computes a rotation matrix G = [[c, s], [-s, c]], such that
351 //                        G_T * [[ps, pqs], [pqs, qs]] * G
352 // is diagonalized.
353 //
354 //  def make_jacobi(ps, qs, pqs, eps):
355 //     if np.abs(a_pq) > eps:
356 //         tau = (a_qq - a_pp) / (2 * a_pq)
357 //         if tau >= 0:
358 //             t = 1.0 / (tau + np.sqrt(1 + tau ** 2))
359 //         else:
360 //             t = -1.0 / (-tau + np.sqrt(1 + tau ** 2))
361 //         c = 1.0 / np.sqrt(1.0 + t ** 2)
362 //         s = t * c
363 //     else:
364 //         c = 1.0
365 //         s = 0.0
366 //     return c, s
367 //
MakeJacobi(XlaOp ps,XlaOp qs,XlaOp pqs,XlaOp eps)368 StatusOr<JacobiRotation> MakeJacobi(XlaOp ps, XlaOp qs, XlaOp pqs, XlaOp eps) {
369   auto zero = ScalarLike(ps, 0.0);
370   auto one = ScalarLike(ps, 1.0);
371   auto two = ScalarLike(ps, 2.0);
372 
373   auto tau = (qs - ps) / (pqs * two);
374   auto t_pos = one / (tau + Sqrt(one + Square(tau)));
375   auto t_neg = -one / (-tau + Sqrt(one + Square(tau)));
376   auto t = Select(Ge(tau, zero), t_pos, t_neg);
377 
378   auto c_temp = Rsqrt(one + Square(t));
379   auto s_temp = t * c_temp;
380 
381   auto c = Select(Ge(Abs(pqs), eps), c_temp, ZerosLike(c_temp) + one);
382   auto s = Select(Ge(Abs(pqs), eps), s_temp, ZerosLike(s_temp));
383   // Renormalize c and s to compensate for low precision arithmetic, this step
384   // is redundant if high precision float is used, like float64.
385   auto rnorm = Rsqrt(Square(c) + Square(s));
386 
387   JacobiRotation rot;
388 
389   rot.c = c * rnorm;
390   rot.s = s * rnorm;
391 
392   return rot;
393 }
394 
395 // One sided Jacobi rotations. For a matrix,
396 //  [a_pp, a_pq]
397 //  [a_qp, a_qq]
398 // After applying Jacobi rotations on both sides, the matrix is diagonalized.
399 //  [b_pp, 0]
400 //  [0, b_qq]
401 //
402 // def jacobi_rot(a, p, q, eps):
403 //     t = a[p, p] + a[q, q]
404 //     d = a[q, p] - a[p, q]
405 //
406 //     if np.abs(d) < eps:
407 //         s = 0.0
408 //         c = 1.0
409 //     else:
410 //         u = t / d
411 //         tmp = np.sqrt(1.0 + u**2)
412 //         s = -1.0 / tmp
413 //         c = u / tmp
414 //
415 //     rot = np.array([[c, s], [-s, c]])
416 //     m_tmp = rot.T @ a[[p, q], [p, q]]
417 //     c_r, s_r = make_jacobi(m_tmp[0, 0], m_tmp[1, 1], m_tmp[0, 1])
418 //     rot_r = np.array([[c_r, s_r], [-s_r, c_r]])
419 //     rot_l = rot @ rot_r
420 //    return rot_l, rot_r
421 //
GetOneSidedJacobiRotation(XlaOp a,XlaOp p,XlaOp q,XlaOp eps)422 StatusOr<OneSidedJacobiRotation> GetOneSidedJacobiRotation(XlaOp a, XlaOp p,
423                                                            XlaOp q, XlaOp eps) {
424   XlaOp a_pp = DynamicSliceInMinorDims(a, {p, p}, {1, 1});
425   XlaOp a_pq = DynamicSliceInMinorDims(a, {p, q}, {1, 1});
426   XlaOp a_qp = DynamicSliceInMinorDims(a, {q, p}, {1, 1});
427   XlaOp a_qq = DynamicSliceInMinorDims(a, {q, q}, {1, 1});
428 
429   XlaOp one = ScalarLike(a, 1.0);
430 
431   XlaOp t = a_pp + a_qq;
432   XlaOp d = a_qp - a_pq;
433 
434   XlaOp u = Div(t, d);
435   XlaOp tmp = Rsqrt(one + Square(u));
436 
437   JacobiRotation rot;
438 
439   XlaOp zeros = ZerosLike(tmp);
440   XlaOp ones = zeros + one;
441 
442   rot.s = Select(Lt(Abs(d), eps), zeros, -tmp);
443   rot.c = Select(Lt(Abs(d), eps), ones, Mul(u, tmp));
444 
445   XlaOp a_pp_new = rot.c * a_pp - rot.s * a_qp;
446   XlaOp a_pq_new = rot.c * a_pq - rot.s * a_qq;
447   XlaOp a_qq_new = rot.s * a_pq + rot.c * a_qq;
448 
449   OneSidedJacobiRotation rots;
450   TF_ASSIGN_OR_RETURN(rots.rot_r,
451                       MakeJacobi(a_pp_new, a_qq_new, a_pq_new, eps));
452 
453   rots.rot_l.c = rot.c * rots.rot_r.c - rot.s * rots.rot_r.s;
454   rots.rot_l.s = rot.s * rots.rot_r.c + rot.c * rots.rot_r.s;
455 
456   return rots;
457 }
458 
459 // Apply one-sided Jacobi on elements at indices pp, pq, qp, qq.
OneSidedJacobiUpdate(SVDResult svd_result,XlaOp p,XlaOp q,XlaOp eps)460 StatusOr<SVDResult> OneSidedJacobiUpdate(SVDResult svd_result, XlaOp p, XlaOp q,
461                                          XlaOp eps) {
462   XlaOp u = svd_result.u;
463   XlaOp v = svd_result.v;
464   XlaOp d = svd_result.d;
465   XlaBuilder* builder = d.builder();
466   TF_ASSIGN_OR_RETURN(Shape d_shape, builder->GetShape(d));
467   const int64 num_dims = d_shape.rank();
468   const int64 num_batch_dims = num_dims - 2;
469   std::vector<int64> batch_dims(num_batch_dims);
470   for (int i = 0; i < num_batch_dims; ++i) {
471     batch_dims[i] = ShapeUtil::GetDimension(d_shape, i);
472   }
473   const int64 m = ShapeUtil::GetDimension(d_shape, -2);
474   const int64 n = ShapeUtil::GetDimension(d_shape, -1);
475 
476   TF_ASSIGN_OR_RETURN(OneSidedJacobiRotation onesided_jacobi,
477                       GetOneSidedJacobiRotation(d, p, q, eps));
478 
479   auto zero = ScalarLike(p, 0);
480 
481   // Zero out a_{pq} explicitly.
482   std::vector<int64> pq_dims(batch_dims.begin(), batch_dims.end());
483   pq_dims.push_back(1);
484   pq_dims.push_back(1);
485   auto pq_zero = ScalarLike(d, 0.0);
486   auto pq_zeros = Broadcast(pq_zero, pq_dims);
487 
488   std::vector<int64> broadcast_dims(batch_dims.size());
489   std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
490   broadcast_dims.push_back(num_dims - 1);
491 
492   // Apply Jacobi Rotation on the left.
493   auto slice_p = DynamicSliceInMinorDims(d, {p, zero}, {1, n});
494   auto slice_q = DynamicSliceInMinorDims(d, {q, zero}, {1, n});
495   auto slice_p_new =
496       onesided_jacobi.rot_l.c * slice_p - onesided_jacobi.rot_l.s * slice_q;
497   auto slice_q_new =
498       onesided_jacobi.rot_l.s * slice_p + onesided_jacobi.rot_l.c * slice_q;
499   d = DynamicUpdateSliceInMinorDims(d, slice_p_new, {p, zero});
500   d = DynamicUpdateSliceInMinorDims(d, slice_q_new, {q, zero});
501 
502   // Apply Jacobi Rotation on the right.
503   slice_p = DynamicSliceInMinorDims(d, {zero, p}, {m, 1});
504   slice_q = DynamicSliceInMinorDims(d, {zero, q}, {m, 1});
505   slice_p_new =
506       onesided_jacobi.rot_r.c * slice_p - onesided_jacobi.rot_r.s * slice_q;
507   slice_q_new =
508       onesided_jacobi.rot_r.s * slice_p + onesided_jacobi.rot_r.c * slice_q;
509   d = DynamicUpdateSliceInMinorDims(d, slice_p_new, {zero, p});
510   d = DynamicUpdateSliceInMinorDims(d, slice_q_new, {zero, q});
511 
512   d = DynamicUpdateSliceInMinorDims(d, pq_zeros, {p, q});
513   d = DynamicUpdateSliceInMinorDims(d, pq_zeros, {q, p});
514 
515   // Apply left Jacobi Rotation on U.
516   slice_p = DynamicSliceInMinorDims(u, {zero, p}, {m, 1});
517   slice_q = DynamicSliceInMinorDims(u, {zero, q}, {m, 1});
518   slice_p_new =
519       onesided_jacobi.rot_l.c * slice_p - onesided_jacobi.rot_l.s * slice_q;
520 
521   slice_p_new = Mul(
522       slice_p_new,
523       Rsqrt(Reduce(Square(slice_p_new), pq_zero,
524                    CreateScalarAddComputation(d_shape.element_type(), builder),
525                    {num_dims - 2})),
526       broadcast_dims);
527 
528   slice_q_new =
529       onesided_jacobi.rot_l.s * slice_p + onesided_jacobi.rot_l.c * slice_q;
530 
531   slice_q_new = Mul(
532       slice_q_new,
533       Rsqrt(Reduce(Square(slice_q_new), pq_zero,
534                    CreateScalarAddComputation(d_shape.element_type(), builder),
535                    {num_dims - 2})),
536       broadcast_dims);
537 
538   u = DynamicUpdateSliceInMinorDims(u, slice_p_new, {zero, p});
539   u = DynamicUpdateSliceInMinorDims(u, slice_q_new, {zero, q});
540 
541   // Apply right Jacobi Rotation on V.
542   slice_p = DynamicSliceInMinorDims(v, {zero, p}, {n, 1});
543   slice_q = DynamicSliceInMinorDims(v, {zero, q}, {n, 1});
544   slice_p_new =
545       onesided_jacobi.rot_r.c * slice_p - onesided_jacobi.rot_r.s * slice_q;
546 
547   slice_p_new = Mul(
548       slice_p_new,
549       Rsqrt(Reduce(Square(slice_p_new), pq_zero,
550                    CreateScalarAddComputation(d_shape.element_type(), builder),
551                    {num_dims - 2})),
552       broadcast_dims);
553 
554   slice_q_new =
555       onesided_jacobi.rot_r.s * slice_p + onesided_jacobi.rot_r.c * slice_q;
556 
557   slice_q_new = Mul(
558       slice_q_new,
559       Rsqrt(Reduce(Square(slice_q_new), pq_zero,
560                    CreateScalarAddComputation(d_shape.element_type(), builder),
561                    {num_dims - 2})),
562       broadcast_dims);
563 
564   v = DynamicUpdateSliceInMinorDims(v, slice_p_new, {zero, p});
565   v = DynamicUpdateSliceInMinorDims(v, slice_q_new, {zero, q});
566 
567   svd_result.d = d;
568   svd_result.u = u;
569   svd_result.v = v;
570 
571   return svd_result;
572 }
573 
ComputeFrobeniusNorms(XlaOp w)574 StatusOr<FrobeniusNorms> ComputeFrobeniusNorms(XlaOp w) {
575   XlaBuilder* builder = w.builder();
576   TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w));
577   const int64 num_dims = shape.rank();
578   auto frobenius_norm =
579       Sqrt(Reduce(Square(w), ScalarLike(w, 0.0),
580                   CreateScalarAddComputation(shape.element_type(), builder),
581                   {num_dims - 2, num_dims - 1}));
582   auto diag = GetMatrixDiagonal(w);
583   auto diag_square =
584       Reduce(Square(diag), ScalarLike(w, 0.0),
585              CreateScalarAddComputation(shape.element_type(), builder),
586              {num_dims - 2});
587 
588   FrobeniusNorms frobenius_norms;
589 
590   frobenius_norms.off_diagonal_norm =
591       Sqrt(Max(Square(frobenius_norm) - diag_square, ScalarLike(w, 0.0)));
592   frobenius_norms.total_norm = frobenius_norm;
593 
594   return frobenius_norms;
595 }
596 
597 // Main boby of One-sided Jacobi Method.
WhileLoopFn(absl::Span<const XlaOp> initial_values,int matrix_dimension,int max_sweep_updates,absl::string_view name,XlaBuilder * builder)598 StatusOr<std::vector<XlaOp>> WhileLoopFn(
599     absl::Span<const XlaOp> initial_values,  //
600     int matrix_dimension,                    //
601     int max_sweep_updates,                   //
602     absl::string_view name,                  //
603     XlaBuilder* builder) {
604   auto while_cond_fn = [&](absl::Span<const XlaOp> values,
605                            XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
606     auto k = values[0];
607     auto max_sweeps = ScalarLike(k, max_sweep_updates);
608     auto sweep_update_cond = Gt(max_sweeps, k);
609 
610     auto norms = ComputeFrobeniusNorms(values[3]).ValueOrDie();
611     auto tol = norms.total_norm * values[4];
612     auto tol_cond = ReduceAll(Lt(tol, norms.off_diagonal_norm),
613                               xla::ConstantR0<bool>(cond_builder, false),
614                               CreateScalarOrComputation(PRED, cond_builder));
615 
616     return And(sweep_update_cond, tol_cond);
617   };
618 
619   auto while_body_fn =
620       [&](absl::Span<const XlaOp> values,
621           XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
622     auto while_cond_fn_inner =
623         [&](absl::Span<const XlaOp> values_inner,
624             XlaBuilder* inner_cond_builder) -> StatusOr<XlaOp> {
625       auto p = values_inner[0];
626       return Lt(p, ScalarLike(p, matrix_dimension - 1));
627     };
628 
629     auto while_body_fn_inner =
630         [&](absl::Span<const XlaOp> values_inner,
631             XlaBuilder* inner_body_builder) -> StatusOr<std::vector<XlaOp>> {
632       auto while_cond_fn_innermost =
633           [&](absl::Span<const XlaOp> values_innermost,
634               XlaBuilder* innermost_cond_builder) -> StatusOr<XlaOp> {
635         auto q = values_innermost[1];
636         return Lt(q, ScalarLike(q, matrix_dimension));
637       };
638       auto while_body_fn_innermost =
639           [&](absl::Span<const XlaOp> values_innermost,
640               XlaBuilder* innermost_body_builder)
641           -> StatusOr<std::vector<XlaOp>> {
642         auto p = values_innermost[0];
643         auto q = values_innermost[1];
644 
645         SVDResult onesided_jacobi_update;
646         onesided_jacobi_update.u = values_innermost[2];
647         onesided_jacobi_update.v = values_innermost[3];
648         onesided_jacobi_update.d = values_innermost[4];
649 
650         auto eps = values_innermost[5];
651 
652         TF_ASSIGN_OR_RETURN(
653             onesided_jacobi_update,
654             OneSidedJacobiUpdate(onesided_jacobi_update, p, q, eps));
655 
656         std::vector<XlaOp> updated_values_innermost;
657         updated_values_innermost.reserve(values_innermost.size());
658 
659         updated_values_innermost.push_back(p);
660         updated_values_innermost.push_back(q + ScalarLike(q, 1));
661         updated_values_innermost.push_back(onesided_jacobi_update.u);
662         updated_values_innermost.push_back(onesided_jacobi_update.v);
663         updated_values_innermost.push_back(onesided_jacobi_update.d);
664         updated_values_innermost.push_back(eps);
665 
666         return updated_values_innermost;
667       };
668 
669       std::vector<XlaOp> values_innermost(6);
670       auto p = values_inner[0];
671       auto q = p + ScalarLike(p, 1);
672       values_innermost[0] = p;                // index p.
673       values_innermost[1] = q;                // index q.
674       values_innermost[2] = values_inner[1];  // u.
675       values_innermost[3] = values_inner[2];  // v.
676       values_innermost[4] = values_inner[3];  // d.
677       values_innermost[5] = values_inner[4];  // eps.
678       TF_ASSIGN_OR_RETURN(
679           values_innermost,
680           WhileLoopHelper(while_cond_fn_innermost, while_body_fn_innermost,
681                           values_innermost, absl::StrCat(name, "-Innermost"),
682                           inner_body_builder));
683 
684       std::vector<XlaOp> updated_values_inner;
685       updated_values_inner.reserve(values_inner.size());
686 
687       updated_values_inner.push_back(p + ScalarLike(p, 1));
688       updated_values_inner.push_back(values_innermost[2]);
689       updated_values_inner.push_back(values_innermost[3]);
690       updated_values_inner.push_back(values_innermost[4]);
691       updated_values_inner.push_back(values_innermost[5]);
692       return updated_values_inner;
693     };
694     // Indexes.
695     XlaOp k = values[0];
696 
697     std::vector<XlaOp> values_inner(5);
698     values_inner[0] = ScalarLike(k, 0);  // index p.
699     values_inner[1] = values[1];         // u.
700     values_inner[2] = values[2];         // v.
701     values_inner[3] = values[3];         // d.
702     values_inner[4] = values[4];         // eps.
703     TF_ASSIGN_OR_RETURN(
704         values_inner,
705         WhileLoopHelper(while_cond_fn_inner, while_body_fn_inner, values_inner,
706                         absl::StrCat(name, "-Inner"), body_builder));
707 
708     std::vector<XlaOp> updated_values;
709     updated_values.reserve(values_inner.size());
710 
711     updated_values.push_back(k + ScalarLike(k, 1));
712     updated_values.push_back(values_inner[1]);
713     updated_values.push_back(values_inner[2]);
714     updated_values.push_back(values_inner[3]);
715     updated_values.push_back(values_inner[4]);
716 
717     return updated_values;
718   };
719   std::vector<XlaOp> values;
720   TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn,
721                                               initial_values, name, builder));
722 
723   return values;
724 }
725 
726 // Sort singular values in decending order, and make sure they are non-negative
727 // by flipping the signs of negative diagonal values and transferring the signs
728 // to V. And for numeric stability, renormalize U and V.
SortBySingularValuesAndPostProcessing(SVDResult result)729 StatusOr<SVDResult> SortBySingularValuesAndPostProcessing(SVDResult result) {
730   XlaBuilder* builder = result.d.builder();
731   TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(result.d));
732   const int64 num_dims = shape.rank();
733   auto dimensions = shape.dimensions();
734   const int64 m = ShapeUtil::GetDimension(shape, -2);
735   const int64 n = ShapeUtil::GetDimension(shape, -1);
736 
737   std::vector<int64> broadcast_dims(num_dims - 1);
738   std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
739   broadcast_dims[num_dims - 2] = num_dims - 1;
740 
741   auto d = GetMatrixDiagonal(result.d);
742 
743   auto zeros = ZerosLike(d);
744   auto one = ScalarLike(d, 1.0);
745 
746   // Make all the singular values to be non-negative by transferring the signs
747   // to V.
748   auto sign = Select(Ge(d, zeros), zeros + one, zeros - one);
749   d = Select(Ge(d, zeros), d, -d);
750   result.v = Mul(result.v, sign, broadcast_dims);
751 
752   d = BroadcastInDim(d, dimensions, broadcast_dims);
753 
754   // As m >= n, only first m columns vectors are needed to be permuted, and the
755   // rest of m - n vectors are appended after the sorting is done.
756   XlaOp sort_u_result =
757       Sort({-d, SliceInMinorDims(result.u, {0, 0}, {m, n})},
758            CreateScalarLtComputation(
759                {shape.element_type(), shape.element_type()}, builder),
760            num_dims - 1);
761 
762   // TODO(kuny): using CreateScalarGtComputation after b/124862300 is fixed.
763   XlaOp sort_v_result =
764       Sort({SliceInMinorDims(-d, {0, 0}, {n, n}), result.v},
765            CreateScalarLtComputation(
766                {shape.element_type(), shape.element_type()}, builder),
767            num_dims - 1);
768   // Make sure all the signular values are non-negative.
769   result.d = Max(-GetMatrixDiagonal(GetTupleElement(sort_v_result, 0)),
770                  ScalarLike(d, 0.0));
771 
772   result.v = GetTupleElement(sort_v_result, 1);
773   result.v = Mul(
774       result.v,
775       Rsqrt(Reduce(Square(result.v), ScalarLike(d, 0.0),
776                    CreateScalarAddComputation(shape.element_type(), builder),
777                    {num_dims - 2})),
778       broadcast_dims);
779 
780   // Append the rest of m - n vectors.
781   result.u = ConcatInDim(builder,
782                          {GetTupleElement(sort_u_result, 1),
783                           SliceInMinorDims(result.u, {0, n}, {m, m})},
784                          num_dims - 1);
785   result.u = Mul(
786       result.u,
787       Rsqrt(Reduce(Square(result.u), ScalarLike(d, 0.0),
788                    CreateScalarAddComputation(shape.element_type(), builder),
789                    {num_dims - 2})),
790       broadcast_dims);
791 
792   return result;
793 }
794 
795 }  // namespace
796 
797 // def jacobi_svd(A):
798 //    U, D, V = house_bidiag(A)
799 //    m, n = D.shape
800 //    iter, max_iter = 0, 100
801 //    frobenius_norm = np.linalg.norm(D)
802 //    diag_norm = np.linalg.norm(np.diag(D))
803 //    off_diag_norm = np.sqrt(
804 //        frobenius_norm - diag_norm) * np.sqrt(frobenius_norm + diag_norm)
805 //    while off_diag_norm > 1e-6 * frobenius_norm and iter < max_iter:
806 //        iter += 1
807 //        for p in range(m - 1):
808 //            for q in range(p + 1, n):
809 //                rot_l, rot_r = jacobi_rot(D[p][p], D[p][q], D[q][p], D[q][q])
810 //                D[[p, q], :] = np.matmul(rot_l.T, D[[p, q], :])
811 //                D[:, [p, q]] = np.matmul(D[:, [p, q]], rot_r)
812 //                U[:, [p, q]] = np.matmul(U[:, [p, q]], rot_l)
813 //                V[:, [p, q]] = np.matmul(V[:, [p, q]], rot_r)
814 //        frobenius_norm = np.linalg.norm(D)
815 //        diag_norm = np.linalg.norm(np.diag(D))
816 //        off_diag_norm = np.sqrt(
817 //            frobenius_norm - diag_norm) * np.sqrt(frobenius_norm + diag_norm)
818 //
819 //    return U, np.diag(D), V
820 //
SVD(XlaOp a,int64 max_iter,float epsilon,PrecisionConfig::Precision precision)821 SVDResult SVD(XlaOp a, int64 max_iter, float epsilon,
822               PrecisionConfig::Precision precision) {
823   XlaBuilder* builder = a.builder();
824   auto return_error = [&](const Status& status) {
825     SVDResult result;
826     result.u = builder->ReportError(status);
827     result.v = builder->ReportError(status);
828     result.d = builder->ReportError(status);
829     return result;
830   };
831   auto shape_with_status = builder->GetShape(a);
832   if (!shape_with_status.status().ok()) {
833     return return_error(shape_with_status.status());
834   }
835   Shape a_shape = shape_with_status.ValueOrDie();
836   const int64 num_dims = a_shape.rank();
837   const int64 num_batch_dims = num_dims - 2;
838   std::vector<int64> batch_dims(num_batch_dims);
839   for (int i = 0; i < num_batch_dims; ++i) {
840     batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
841   }
842   int64 m = ShapeUtil::GetDimension(a_shape, -2);
843   int64 n = ShapeUtil::GetDimension(a_shape, -1);
844   bool maybe_transpose = m < n;
845 
846   if (maybe_transpose) {
847     a = TransposeInMinorDims(a);
848     std::swap(m, n);
849   }
850 
851   auto eps = ScalarLike(a, epsilon);
852 
853   SVDResult svd_result =
854       HouseHolderBidiagonalization(a, eps, precision).ValueOrDie();
855 
856   auto output_with_status = WhileLoopFn(
857       {
858           Zero(builder, S32),  // k
859           svd_result.u,        // u
860           svd_result.v,        // v
861           svd_result.d,        // d
862           eps,                 // epsilon
863       },                       //
864       n,                       //
865       max_iter,                //
866       "CyclicOneSidedJacobi",  //
867       builder);
868   if (!output_with_status.status().ok()) {
869     return return_error(output_with_status.status());
870   }
871 
872   auto output = output_with_status.ValueOrDie();
873 
874   svd_result.u = output[1];
875   svd_result.v = output[2];
876   svd_result.d = output[3];
877   svd_result = SortBySingularValuesAndPostProcessing(svd_result).ValueOrDie();
878   if (maybe_transpose) {
879     std::swap(svd_result.u, svd_result.v);
880   }
881   return svd_result;
882 }
883 
884 }  // namespace xla
885