/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/lib/qr.h"

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

namespace {

std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
                                 absl::Span<const int64> ys) {
  std::vector<int64> output(xs.size() + ys.size());
  std::copy(xs.begin(), xs.end(), output.begin());
  std::copy(ys.begin(), ys.end(), output.begin() + xs.size());
  return output;
}

// Computes a Householder reflection of the form:
// H = I - tau v v.T.
// such that
// H . ( x1  ) = ( x1   )
//     ( x2  ) = ( x2   )
//     ( ... ) = ( ...  )
//     ( xk  ) = ( beta )
//     ( ... )   ( 0    )
//     ( ... )   ( 0    )
// Unlike the usual formulation, we allow the caller to supply 'k' rather than
// only providing the relevant part of 'x' to maintain XLA's static shape
// invariant. In addition, the implementation supports batching.
// Pseudo-code, without batching:
//   alpha = x[k]
//   x_copy = np.copy(x)
//   x_copy[:k+1] = 0
//   xnorm = norm2(x_copy)
//   if xnorm == 0:
//     beta = alpha
//     tau = 0
//     v = np.zeros_like(x)
//   else:
//     beta = - np.sign(alpha) * dlapy2(alpha, xnorm)
//     tau = (beta - alpha) / beta
//     v = x / (alpha - beta)
//     v[k] = 1
//   return (v, tau, beta)
// TODO(phawkins): LAPACK's xLARFG implementation has code for handling
// overflows in the norm/beta calculations. Perhaps do the same here.
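// As a concrete, unbatched illustration of the pseudo-code above, take
// x = [3, 4, 0] and k = 0. Then alpha = 3, xnorm = 4, and
//   beta = -sign(3) * sqrt(3*3 + 4*4) = -5
//   tau  = (beta - alpha) / beta = (-5 - 3) / -5 = 1.6
//   v    = x / (alpha - beta) = [3/8, 1/2, 0], then v[0] = 1, so v = [1, 0.5, 0]
// and H x = x - tau * v * dot(v, x) = [3, 4, 0] - 1.6 * 5 * [1, 0.5, 0]
//         = [-5, 0, 0], i.e. (beta, 0, 0) as required.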
Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
             const int64 m, XlaOp* v, XlaOp* tau, XlaOp* beta) {
  XlaBuilder* const builder = x.builder();
  TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
  const PrimitiveType type = x_shape.element_type();

  std::vector<int64> batch_dim_ids(batch_dims.size());
  std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0);
  const int64 minor_dim = batch_dims.size();

  XlaOp zero = ScalarLike(x, 0.0);
  XlaOp one = ScalarLike(x, 1.0);

  // alpha = x[k]
  XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims);

  // Compute x[k+1:] (padded with zeros in elements 0..k)
  XlaOp iota = Iota(builder, S32, m);
  XlaOp x_after_k = Mul(x, ConvertElementType(Gt(iota, k), type),
                        /*broadcast_dimensions=*/{minor_dim});

  // sigma = np.dot(x[k+1:], x[k+1:])
  auto sigma = Reduce(x_after_k * x_after_k, zero,
                      CreateScalarAddComputation(type, builder), {minor_dim});
  // mu = np.sqrt(x[k]*x[k] + sigma)
  auto mu = Sqrt(Square(alpha) + sigma);

  auto sigma_is_zero = Eq(sigma, zero);

  *beta = Select(sigma_is_zero, alpha, -Sign(alpha) * mu);
  *tau = Select(sigma_is_zero, Broadcast(zero, batch_dims),
                (*beta - alpha) / *beta);
  auto divisor =
      Select(sigma_is_zero, Broadcast(one, batch_dims), alpha - *beta);

  auto e_k = Broadcast(ConvertElementType(Eq(iota, k), type),
                       std::vector<int64>(batch_dims.size(), 1));

  // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
  // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
  *v = e_k + Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids);
  return Status::OK();
}

// Householder QR decomposition. Algorithm 5.2.1 from Golub and Van
// Loan "Matrix Computations", 4th Edition. This is an unblocked implementation
// used as an inner routine of the blocked implementation.
// The algorithm is adapted slightly so the shapes inside the loop are static,
// at the cost of some redundant computation. Since this is used as an inner
// block kernel, it accumulates the Householder transformations (vs, taus)
// rather than the matrix q.
// Equivalent Python code, without batching:
// def qr(a):
//   m = a.shape[0]
//   n = a.shape[1]
//   vs = np.zeros([m, n])
//   taus = np.zeros([n])
//   for j in xrange(min(m, n)):
//     v, tau, beta = house(a[:, j], j)
//     # Unusually, we apply the Householder transformation to the entirety of
//     # a, wasting FLOPs to maintain the static shape invariant that XLA
//     # requires. For columns that precede j this has no effect.
//     a[:, :] -= tau * np.dot(v[:, np.newaxis],
//                             np.dot(v[np.newaxis, :], a[:, :]))
//     # Form column j explicitly rather than relying on the precision of the
//     # Householder update.
//     a[j, j] = beta
//     a[j+1:, j] = np.zeros([m - j - 1], dtype=a.dtype)
//     vs[:, j] = v
//     taus[j] = tau
//   return (a, vs, taus)
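// The Householder reflectors returned in (vs, taus) determine the implicit
// orthogonal factor: Q = H_0 H_1 ... H_{p-1} with p = min(m, n) and
//   H_j = I - taus[j] * vs[:, j] vs[:, j].T,
// which ComputeWYRepresentation below aggregates into a single I + W Y.T form.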
struct QRBlockResult {
  // The factored R value
  XlaOp r;

  // Representation of the Householder matrices I - tau v v.T
  XlaOp taus;  // Shape: [..., n]
  XlaOp vs;    // Shape: [..., m, n]
};
StatusOr<QRBlockResult> QRBlock(XlaOp a, PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
  const int num_dims = a_shape.rank();
  if (num_dims < 2) {
    return InvalidArgument("Argument to QR must have rank >= 2; got shape %s",
                           a_shape.ToString());
  }
  PrimitiveType type = a_shape.element_type();

  const int64 m = ShapeUtil::GetDimension(a_shape, -2);
  const int64 n = ShapeUtil::GetDimension(a_shape, -1);

  const int64 num_batch_dims = num_dims - 2;
  std::vector<int64> batch_dims(num_batch_dims);
  for (int i = 0; i < num_batch_dims; ++i) {
    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
  }

  std::vector<int64> batch_dim_indices(num_batch_dims);
  std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);

  auto qr_body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
                        XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
    auto a = values[0];
    auto vs = values[1];
    auto taus = values[2];

    // v, tau, beta = house(a[:, j], j)
    auto x = DynamicSliceInMinorDims(a, {j}, {1});
    XlaOp v, tau, beta;
    TF_RETURN_IF_ERROR(House(Collapse(x, {num_dims - 2, num_dims - 1}), j,
                             batch_dims, m, &v, &tau, &beta));

    std::vector<int64> shape = batch_dims;
    shape.push_back(1);
    shape.push_back(m);
    auto v_broadcast = Reshape(v, shape);
    // a[:, :] -= tau * np.dot(v[:, np.newaxis],
    //                         np.dot(v[np.newaxis, :], a[:, :]))
    auto vva = BatchDot(v_broadcast, a, precision);
    vva = BatchDot(TransposeInMinorDims(v_broadcast), vva, precision);
    a = a - Mul(tau, vva,
                /*broadcast_dimensions=*/batch_dim_indices);

    // It is more precise to populate column 'j' explicitly, rather than
    // computing it implicitly by applying the Householder transformation.
    //   a[j, j] = beta
    //   a[j+1:, j] = np.zeros([m-j-1], dtype=a.dtype)
    auto iota = Reshape(Iota(a.builder(), S32, m), {m, 1});
    auto predecessor_mask = ConvertElementType(Lt(iota, j), type);
    auto mask = Broadcast(ConvertElementType(Eq(iota, j), type),
                          std::vector<int64>(batch_dims.size(), 1));
    auto new_x = Mul(x, predecessor_mask,
                     /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) +
                 Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices);
    a = DynamicUpdateSliceInMinorDims(a, new_x, {j});

    // vs[:, j] = v
    vs = DynamicUpdateSliceInMinorDims(
        vs, Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j});
    // taus[j] = tau
    taus = DynamicUpdateSliceInMinorDims(
        taus, Reshape(tau, ConcatVectors(batch_dims, {1})), {j});
    return std::vector<XlaOp>{a, vs, taus};
  };

  auto vs = Zeros(
      builder, ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n})));
  auto taus = Zeros(builder,
                    ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n})));

  TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(std::min(m, n), S32, qr_body_fn,
                                                {a, vs, taus}, "qr", builder));

  QRBlockResult result;
  result.r = values[0];
  result.vs = values[1];
  result.taus = values[2];
  return result;
}

// Computes W and Y such that I + W Y.T is equivalent to the product of the
// Householder transformations given by vs and taus.
// Golub and van Loan, "Matrix Computations", algorithm 5.1.2.
//   Y = np.zeros([m, n])
//   W = np.zeros([m, n])
//   Y[:, 0] = vs[:, 0]
//   W[:, 0] = -taus[0] * vs[:, 0]
//   for j in xrange(1, n):
//     v = vs[:, j]
//     z = -taus[j] * v - taus[j] * np.dot(W, np.dot(Y.T, v))
//     W[:, j] = z
//     Y[:, j] = v
//   return W
// There is no need to return Y since at termination of the loop it is equal to
// vs.
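// To see why the update above is correct, suppose I + W Y.T already represents
// the first j reflectors and the next reflector is H = I - tau v v.T. Then
//   (I + W Y.T) H = I + W Y.T - tau v v.T - tau W (Y.T v) v.T
//                 = I + [W | z] [Y | v].T  with  z = -tau * (v + W (Y.T v)),
// which is exactly the z computed in the loop body below.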
StatusOr<XlaOp> ComputeWYRepresentation(PrimitiveType type,
                                        absl::Span<const int64> batch_dims,
                                        XlaOp vs, XlaOp taus, int64 m, int64 n,
                                        PrecisionConfig::Precision precision) {
  std::vector<int64> batch_dim_indices(batch_dims.size());
  std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
  int64 n_index = batch_dims.size() + 1;

  auto body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
                     XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
    auto w = values[0];
    auto y = values[1];
    const auto vs = values[2];
    const auto taus = values[3];

    // Want j values in range [1, ... n).
    j = j + ConstantR0<int32>(builder, 1);
    // v has shape [..., m, 1]
    auto v = DynamicSliceInMinorDims(vs, {j}, {1});
    // beta (= taus[j]) has shape [..., 1]
    auto beta = DynamicSliceInMinorDims(taus, {j}, {1});

    // yv has shape [..., n, 1]
    auto yv = BatchDot(TransposeInMinorDims(y), v, precision);
    // wyv has shape [..., m, 1]
    auto wyv = BatchDot(w, yv, precision);

    auto z = Mul(
        -beta, v + wyv,
        /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));

    w = DynamicUpdateSliceInMinorDims(w, z, {j});
    y = DynamicUpdateSliceInMinorDims(y, v, {j});

    return std::vector<XlaOp>{w, y, vs, taus};
  };

  XlaBuilder* builder = vs.builder();
  auto w = Zeros(builder,
                 ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n})));
  auto y = w;
  auto v = SliceInMinorDims(vs, {0}, {1});
  auto beta = SliceInMinorDims(taus, {0}, {1});
  y = UpdateSliceInMinorDims(y, v, {0});
  auto bv =
      Mul(-beta, v,
          /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
  w = UpdateSliceInMinorDims(w, bv, {0});

  TF_ASSIGN_OR_RETURN(
      auto values,
      ForEachIndex(n - 1, S32, body_fn, {w, y, vs, taus}, "wy", builder));
  return values[0];
}

}  // namespace

// Block Householder QR Factorization. Algorithm 5.2.2 of Golub and van Loan.
// def qr_blocked(a, block_size):
//   m = a.shape[0]
//   n = a.shape[1]
//   q = np.eye(m)
//   for i in xrange(0, min(m, n), block_size):
//     k = min(block_size, min(m, n) - i)
//     (a, vs, taus) = qr(a[i:, i:i+k])
//     y = vs
//     w = ComputeWYRepresentation(vs, taus, m-i, k)
//     a[i:, i+k:] += np.dot(y, np.dot(w.T, a[i:, i+k:]))
//     q[:, i:] += np.dot(q[:, i:], np.dot(w, y.T))
//   return (q, a)
// TODO(phawkins): consider using UT transformations (in the form I - V U V')
// rather than WY transformations.
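// In the loop below, W and Y (with Q_block denoting I + W Y.T) describe the
// product of the Householder reflectors of the current column block, so the
// trailing panel is updated as
//   a[i:, i+k:] = Q_block.T a[i:, i+k:] = (I + Y W.T) a[i:, i+k:]
// and the accumulated orthogonal factor as
//   q[:, i:]    = q[:, i:] Q_block     = q[:, i:] (I + W Y.T),
// which is what the two pairs of BatchDot calls compute.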
StatusOr<QRDecompositionResult> QRDecomposition(
    XlaOp a, bool full_matrices, int64 block_size,
    PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
  const int num_dims = a_shape.rank();
  if (num_dims < 2) {
    return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s",
                           a_shape.ToString());
  }
  PrimitiveType type = a_shape.element_type();

  const int64 m = ShapeUtil::GetDimension(a_shape, -2);
  const int64 n = ShapeUtil::GetDimension(a_shape, -1);
  const int64 p = std::min(m, n);

  if (block_size < 1) {
    return InvalidArgument("block_size argument to QR must be >= 1; got %d",
                           block_size);
  }

  const int64 num_batch_dims = num_dims - 2;
  std::vector<int64> batch_dims(num_batch_dims);
  for (int i = 0; i < num_batch_dims; ++i) {
    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
  }

  auto q = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims);
  for (int64 i = 0; i < p; i += block_size) {
    int64 k = std::min(block_size, p - i);

    auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k});
    TF_ASSIGN_OR_RETURN(auto qr_block, QRBlock(a_block, precision));

    a = UpdateSliceInMinorDims(a, qr_block.r, {i, i});

    // Compute the I + W Y.T block representation of the product of Householder
    // matrices.
    TF_ASSIGN_OR_RETURN(
        auto w, ComputeWYRepresentation(type, batch_dims, qr_block.vs,
                                        qr_block.taus, m - i, k, precision));
    auto y = qr_block.vs;

    // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:]))
    auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n});
    auto a_update = BatchDot(TransposeInMinorDims(w), a_panel, precision);
    a_update = BatchDot(y, a_update, precision);
    a_panel = a_panel + a_update;
    a = UpdateSliceInMinorDims(a, a_panel, {i, i + k});

    // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)
    auto q_panel = SliceInMinorDims(q, {0, i}, {m, m});
    auto q_update = BatchDot(q_panel, w, precision);
    q_update = BatchDot(q_update, TransposeInMinorDims(y), precision);
    q_panel = q_panel + q_update;
    q = UpdateSliceInMinorDims(q, q_panel, {0, i});
  }
  QRDecompositionResult result;

  // full_matrices is false when only a partial result is needed. Slice to the
  // needed dimensions here.
  if (!full_matrices) {
    q = SliceInMinorDims(q, {0, 0}, {m, p});
    a = SliceInMinorDims(a, {0, 0}, {p, n});
  }
  result.q = q;
  result.r = a;
  return result;
}

}  // namespace xla