/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/lib/qr.h"

#include <algorithm>  // std::copy, std::min
#include <memory>
#include <numeric>  // std::iota
#include <vector>

#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

namespace {

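// Concatenates two dimension-index vectors; for example,
// ConcatVectors({2, 3}, {4, 5}) yields {2, 3, 4, 5}.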
std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
                                 absl::Span<const int64> ys) {
  std::vector<int64> output(xs.size() + ys.size());
  std::copy(xs.begin(), xs.end(), output.begin());
  std::copy(ys.begin(), ys.end(), output.begin() + xs.size());
  return output;
}

// Computes a Householder reflection of the form:
// H = I - tau v v.T.
// such that
// H . ( x1  ) = ( x1   )
//     ( x2  ) = ( x2   )
//     ( ... ) = ( ...  )
//     ( xk  ) = ( beta )
//     ( ... )   ( 0    )
//     ( ... )   ( 0    )
// Unlike the usual formulation, we allow the caller to supply 'k' rather than
// only providing the relevant part of 'x' to maintain XLA's static shape
// invariant. In addition, the implementation supports batching.
// Pseudo-code, without batching:
//   alpha = x[k]
//   x_copy = np.copy(x)
//   x_copy[:k+1] = 0
//   xnorm = norm2(x_copy)
//   if xnorm == 0:
//     beta = alpha
//     tau = 0
//     v = np.zeros_like(x)
//   else:
//     beta = - np.sign(alpha) * dlapy2(alpha, xnorm)
//     tau = (beta - alpha) / beta
//     v = x / (alpha - beta)
//   v[k] = 1
//   return (v, tau, beta)
// TODO(phawkins): LAPACK's xLARFG implementation has code for handling
// overflows in the norm/beta calculations. Perhaps do the same here.
Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
             const int64 m, XlaOp* v, XlaOp* tau, XlaOp* beta) {
  XlaBuilder* const builder = x.builder();
  TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
  const PrimitiveType type = x_shape.element_type();

  std::vector<int64> batch_dim_ids(batch_dims.size());
  std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0);
  const int64 minor_dim = batch_dims.size();

  XlaOp zero = ScalarLike(x, 0.0);
  XlaOp one = ScalarLike(x, 1.0);

  // alpha = x[k]
  XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims);

  // Compute x[k+1:] (padded with zeros in elements 0..k)
  XlaOp iota = Iota(builder, S32, m);
  XlaOp x_after_k = Mul(x, ConvertElementType(Gt(iota, k), type),
                        /*broadcast_dimensions=*/{minor_dim});

  // sigma = np.dot(x[k+1:], x[k+1:])
  auto sigma = Reduce(x_after_k * x_after_k, zero,
                      CreateScalarAddComputation(type, builder), {minor_dim});
  // mu = np.sqrt(x[k]*x[k] + sigma)
  auto mu = Sqrt(Square(alpha) + sigma);

  auto sigma_is_zero = Eq(sigma, zero);

  *beta = Select(sigma_is_zero, alpha, -Sign(alpha) * mu);
  *tau = Select(sigma_is_zero, Broadcast(zero, batch_dims),
                (*beta - alpha) / *beta);
  auto divisor =
      Select(sigma_is_zero, Broadcast(one, batch_dims), alpha - *beta);

  auto e_k = Broadcast(ConvertElementType(Eq(iota, k), type),
                       std::vector<int64>(batch_dims.size(), 1));

  // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
  // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
  *v = e_k + Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids);
  return Status::OK();
}

// Householder QR decomposition. Algorithm 5.2.1 from Golub and Van
// Loan "Matrix Computations", 4th Edition. This is an unblocked implementation
// used as an inner routine of the blocked implementation.
// Algorithm is adapted slightly so the shapes inside the loop are static, at
// the cost of some redundant computation. Since this is used as an inner block
// kernel, it accumulates the Householder transformations (vs, taus) rather
// than the matrix q.
// Equivalent Python code, without batching:
// def qr(a):
//   m = a.shape[0]
//   n = a.shape[1]
//   vs = np.zeros([m, n])
//   taus = np.zeros([n])
//   for j in xrange(min(m, n)):
//     v, tau, beta = house(a[:, j], j)
//     # Unusually, we apply the Householder transformation to the entirety of
//     # a, wasting FLOPs to maintain the static shape invariant that XLA
//     # requires. For columns that precede j this has no effect.
//     a[:, :] -= tau * np.dot(v[:, np.newaxis],
//                              np.dot(v[np.newaxis, :], a[:, :]))
//     # Form column j explicitly rather than relying on the precision of the
//     # Householder update.
//     a[j, j] = beta
//     a[j+1:, j] = np.zeros([m - j - 1], dtype=a.dtype)
//     vs[:, j] = v
//     taus[j] = tau
//   return (a, vs, taus)
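//
// The orthogonal factor is implicit in (vs, taus): should a caller need it,
// q is recoverable (in the notation above) as the product
//   q = H_0 H_1 ... H_{min(m,n)-1},
//   H_j = I - taus[j] * np.outer(vs[:, j], vs[:, j]),
// which is what ComputeWYRepresentation below accumulates in blocked form.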
struct QRBlockResult {
  // The factored R value
  XlaOp r;

  // Representation of the Householder matrices I - tau v v.T
  XlaOp taus;  // Shape: [..., n]
  XlaOp vs;    // Shape: [..., m, n]
};
StatusOr<QRBlockResult> QRBlock(XlaOp a, PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
  const int num_dims = a_shape.rank();
  if (num_dims < 2) {
    return InvalidArgument("Argument to QR must have rank >= 2; got shape %s",
                           a_shape.ToString());
  }
  PrimitiveType type = a_shape.element_type();

  const int64 m = ShapeUtil::GetDimension(a_shape, -2);
  const int64 n = ShapeUtil::GetDimension(a_shape, -1);

  const int64 num_batch_dims = num_dims - 2;
  std::vector<int64> batch_dims(num_batch_dims);
  for (int i = 0; i < num_batch_dims; ++i) {
    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
  }

  std::vector<int64> batch_dim_indices(num_batch_dims);
  std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);

  auto qr_body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
                        XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
    auto a = values[0];
    auto vs = values[1];
    auto taus = values[2];

    // v, tau, beta = house(a[:, j], j)
    auto x = DynamicSliceInMinorDims(a, {j}, {1});
    XlaOp v, tau, beta;
    TF_RETURN_IF_ERROR(House(Collapse(x, {num_dims - 2, num_dims - 1}), j,
                             batch_dims, m, &v, &tau, &beta));

    std::vector<int64> shape = batch_dims;
    shape.push_back(1);
    shape.push_back(m);
    auto v_broadcast = Reshape(v, shape);
    // a[:, :] -= tau * np.dot(v[:, np.newaxis],
    //                          np.dot(v[np.newaxis, :], a[:, :]))
    auto vva = BatchDot(v_broadcast, a, precision);
    vva = BatchDot(TransposeInMinorDims(v_broadcast), vva, precision);
    a = a - Mul(tau, vva,
                /*broadcast_dimensions=*/batch_dim_indices);

    // It is more precise to populate column 'j' explicitly, rather than
    // computing it implicitly by applying the Householder transformation.
    // a[j, j] = beta
    // a[j+1:, j] = np.zeros([m - j - 1], dtype=a.dtype)
    auto iota = Reshape(Iota(a.builder(), S32, m), {m, 1});
    auto predecessor_mask = ConvertElementType(Lt(iota, j), type);
    auto mask = Broadcast(ConvertElementType(Eq(iota, j), type),
                          std::vector<int64>(batch_dims.size(), 1));
    auto new_x = Mul(x, predecessor_mask,
                     /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) +
                 Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices);
    a = DynamicUpdateSliceInMinorDims(a, new_x, {j});

    // vs[:, j] = v
    vs = DynamicUpdateSliceInMinorDims(
        vs, Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j});
    // taus[j] = tau
    taus = DynamicUpdateSliceInMinorDims(
        taus, Reshape(tau, ConcatVectors(batch_dims, {1})), {j});
    return std::vector<XlaOp>{a, vs, taus};
  };

  auto vs = Zeros(
      builder, ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n})));
  auto taus = Zeros(builder,
                    ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n})));

  TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(std::min(m, n), S32, qr_body_fn,
                                                {a, vs, taus}, "qr", builder));

  QRBlockResult result;
  result.r = values[0];
  result.vs = values[1];
  result.taus = values[2];
  return result;
}

// Computes W and Y such that I + W Y^T is equivalent to the sequence of
// Householder transformations given by vs and taus.
// Golub and Van Loan, "Matrix Computations", algorithm 5.1.2.
// Y = np.zeros([m, n])
// W = np.zeros([m, n])
// Y[:, 0] = vs[:, 0]
// W[:, 0] = -taus[0] * vs[:, 0]
// for j in xrange(1, n):
//   v = vs[:, j]
//   z = -taus[j] * v - taus[j] * np.dot(W, np.dot(Y.T, v))
//   W[:, j] = z
//   Y[:, j] = v
// return W
// There is no need to return Y since at termination of the loop it is equal to
// vs.
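//
// As a sanity check on the recurrence (illustrative, two reflections):
//   (I - t0 v0 v0.T)(I - t1 v1 v1.T)
//     = I - t0 v0 v0.T - t1 v1 v1.T + t0 t1 (v0.T v1) v0 v1.T
//     = I + W Y.T  with  Y = [v0, v1],
//       W = [-t0 v0, -t1 v1 - t1 np.dot(W0, np.dot(Y0.T, v1))],
// where W0 = [-t0 v0], Y0 = [v0], exactly one step of the loop above.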
StatusOr<XlaOp> ComputeWYRepresentation(PrimitiveType type,
                                        absl::Span<const int64> batch_dims,
                                        XlaOp vs, XlaOp taus, int64 m, int64 n,
                                        PrecisionConfig::Precision precision) {
  std::vector<int64> batch_dim_indices(batch_dims.size());
  std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
  int64 n_index = batch_dims.size() + 1;

  auto body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
                     XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
    auto w = values[0];
    auto y = values[1];
    const auto vs = values[2];
    const auto taus = values[3];

    // Want j values in range [1, ... n).
    j = j + ConstantR0<int32>(builder, 1);
    // v has shape [..., m, 1]
    auto v = DynamicSliceInMinorDims(vs, {j}, {1});
    // tau has shape [..., 1]
    auto tau = DynamicSliceInMinorDims(taus, {j}, {1});

    // yv has shape [..., n, 1]
    auto yv = BatchDot(TransposeInMinorDims(y), v, precision);
    // wyv has shape [..., m, 1]
    auto wyv = BatchDot(w, yv, precision);

    // z = -taus[j] * (v + np.dot(W, np.dot(Y.T, v)))
    auto z = Mul(
        -tau, v + wyv,
        /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));

    w = DynamicUpdateSliceInMinorDims(w, z, {j});
    y = DynamicUpdateSliceInMinorDims(y, v, {j});

    return std::vector<XlaOp>{w, y, vs, taus};
  };

  XlaBuilder* builder = vs.builder();
  auto w = Zeros(builder,
                 ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n})));
  auto y = w;
  auto v = SliceInMinorDims(vs, {0}, {1});
  auto tau = SliceInMinorDims(taus, {0}, {1});
  y = UpdateSliceInMinorDims(y, v, {0});
  // W[:, 0] = -taus[0] * vs[:, 0]
  auto bv =
      Mul(-tau, v,
          /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
  w = UpdateSliceInMinorDims(w, bv, {0});

  TF_ASSIGN_OR_RETURN(
      auto values,
      ForEachIndex(n - 1, S32, body_fn, {w, y, vs, taus}, "wy", builder));
  return values[0];
}

}  // namespace

// Block Householder QR Factorization. Algorithm 5.2.2 of Golub and Van Loan.
// def qr_blocked(a, block_size):
//   m = a.shape[0]
//   n = a.shape[1]
//   q = np.eye(m)
//   for i in xrange(0, min(m, n), block_size):
//     k = min(block_size, min(m, n) - i)
//     (a, vs, taus) = qr(a[i:, i:i+k])
//     y = vs
//     w = ComputeWYRepresentation(vs, taus, m-i, k)
//     a[i:, i+k:] += np.dot(y, np.dot(w.T, a[i:, i+k:]))
//     q[:, i:] += np.dot(q[:, i:], np.dot(w, y.T))
//   return (q, a)
// TODO(phawkins): consider using UT transformations (in the form I - V U V')
// rather than WY transformations.
StatusOr<QRDecompositionResult> QRDecomposition(
    XlaOp a, bool full_matrices, int64 block_size,
    PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
  const int num_dims = a_shape.rank();
  if (num_dims < 2) {
    return InvalidArgument("Argument to QR must have rank >= 2; got shape %s",
                           a_shape.ToString());
  }
  PrimitiveType type = a_shape.element_type();

  const int64 m = ShapeUtil::GetDimension(a_shape, -2);
  const int64 n = ShapeUtil::GetDimension(a_shape, -1);
  const int64 p = std::min(m, n);

  if (block_size < 1) {
    return InvalidArgument("block_size argument to QR must be >= 1; got %d",
                           block_size);
  }

  const int64 num_batch_dims = num_dims - 2;
  std::vector<int64> batch_dims(num_batch_dims);
  for (int i = 0; i < num_batch_dims; ++i) {
    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
  }

  auto q = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims);
  for (int64 i = 0; i < p; i += block_size) {
    int64 k = std::min(block_size, p - i);

    auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k});
    TF_ASSIGN_OR_RETURN(auto qr_block, QRBlock(a_block, precision));

    a = UpdateSliceInMinorDims(a, qr_block.r, {i, i});

    // Compute the I + W Y^T block representation of a product of Householder
    // matrices.
    TF_ASSIGN_OR_RETURN(
        auto w, ComputeWYRepresentation(type, batch_dims, qr_block.vs,
                                        qr_block.taus, m - i, k, precision));
    auto y = qr_block.vs;

    // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:]))
    auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n});
    auto a_update = BatchDot(TransposeInMinorDims(w), a_panel, precision);
    a_update = BatchDot(y, a_update, precision);
    a_panel = a_panel + a_update;
    a = UpdateSliceInMinorDims(a, a_panel, {i, i + k});

    // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)
    auto q_panel = SliceInMinorDims(q, {0, i}, {m, m});
    auto q_update = BatchDot(q_panel, w, precision);
    q_update = BatchDot(q_update, TransposeInMinorDims(y), precision);
    q_panel = q_panel + q_update;
    q = UpdateSliceInMinorDims(q, q_panel, {0, i});
  }
  QRDecompositionResult result;

  // full_matrices is false when only a partial result is needed. Slice to the
  // needed dimensions here.
  if (!full_matrices) {
    q = SliceInMinorDims(q, {0, 0}, {m, p});
    a = SliceInMinorDims(a, {0, 0}, {p, n});
  }
  result.q = q;
  result.r = a;
  return result;
}
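
// A hypothetical caller sketch (illustrative only; assumes the declarations
// in qr.h, with shapes and block_size chosen for the example):
//
//   XlaBuilder b("qr");
//   XlaOp a = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {128, 64}), "a");
//   auto qr = QRDecomposition(a, /*full_matrices=*/false, /*block_size=*/32,
//                             PrecisionConfig::HIGHEST);
//   if (qr.ok()) {
//     Tuple(&b, {qr.ValueOrDie().q, qr.ValueOrDie().r});
//   }
//
// With full_matrices false and m = 128, n = 64, this yields q of shape
// [128, 64] and r of shape [64, 64].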

}  // namespace xla