/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
16 #include "tensorflow/compiler/xla/service/qr_expander.h"
18 #include <memory>
19 #include <vector>
21 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
22 #include "tensorflow/compiler/xla/client/lib/constants.h"
23 #include "tensorflow/compiler/xla/client/lib/loops.h"
24 #include "tensorflow/compiler/xla/client/lib/math.h"
25 #include "tensorflow/compiler/xla/client/lib/matrix.h"
26 #include "tensorflow/compiler/xla/client/lib/slicing.h"
27 #include "tensorflow/compiler/xla/client/xla_builder.h"
28 #include "tensorflow/compiler/xla/literal.h"
29 #include "tensorflow/compiler/xla/primitive_util.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/status_macros.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/core/lib/core/errors.h"
36 namespace xla {
38 namespace {
ConcatVectors(absl::Span<const int64> xs,absl::Span<const int64> ys)40 std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
41 absl::Span<const int64> ys) {
42 std::vector<int64> output;
43 output.reserve(xs.size() + ys.size());
44 std::copy(xs.begin(), xs.end(), std::back_inserter(output));
45 std::copy(ys.begin(), ys.end(), std::back_inserter(output));
46 return output;
47 }
49 // Computes a Householder reflection of the form:
50 // H = I - tau v v.T.
51 // such that
52 // H . ( x1 ) = ( x1 )
53 // ( x2 ) = ( x2 )
54 // ( ... ) = ( ... )
55 // ( xk ) = ( beta )
56 // ( ... ) ( 0 )
57 // ( ... ) ( 0 )
58 // Unlike the usual formulation, we allow the caller to supply 'k' rather than
59 // only providing the relevant part of 'x' to maintain XLA's static shape
60 // invariant. In addition, the implementation supports batching.
61 // Pseudo-code, without batching:
62 // alpha = x[k]
63 // x_copy = np.copy(x)
64 // x_copy[:k+1] = 0
65 // xnorm = norm2(x_copy)
66 // if xnorm == 0 and np.imag(alpha) == 0:
67 // beta = alpha
68 // tau = 0
69 // v = np.zeros_like(x)
70 // else:
71 // beta = -np.sign(np.real(alpha)) * np.sqrt(alpha * np.conj(alpha) + xnorm)
72 // if np.issubdtype(x.dtype, np.complexfloating):
73 // tau = (beta - alpha) / beta
74 // else:
75 // tau = (beta - np.real(alpha) / beta) + (-np.imag(alpha) / beta) * 1j
76 // v = x / (alpha - beta)
77 // v[k] = 1
78 // return (v, tau, beta)
79 // TODO(phawkins): LAPACK's xLARFG implementation has code for handling
80 // overflows in the norm/beta calculations. Perhaps do the same here.
House(XlaOp x,XlaOp k,absl::Span<const int64> batch_dims,const int64 m,XlaOp * v,XlaOp * tau,XlaOp * beta)81 Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
82 const int64 m, XlaOp* v, XlaOp* tau, XlaOp* beta) {
83 XlaBuilder* const builder = x.builder();
84 TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
85 const PrimitiveType type = x_shape.element_type();
87 std::vector<int64> batch_dim_ids(batch_dims.size());
88 std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0);
89 const int64 minor_dim = batch_dims.size();
91 XlaOp zero = ScalarLike(x, 0.0);
93 // alpha = x[k]
94 XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims);
96 // Compute x[k+1:] (padded with zeros in elements 0..k)
97 XlaOp iota = Iota(builder, S32, m);
98 XlaOp x_after_k = Mul(x, ConvertElementType(Gt(iota, k), type),
99 /*broadcast_dimensions=*/{minor_dim});
101 XlaOp sigma_is_zero;
102 if (primitive_util::IsComplexType(type)) {
103 // sigma = np.dot(x[k+1:], np.conj(x[k+1:]))
104 // TODO(phawkins): this calculation may be numerically unstable.
105 auto x_squared = Real(x_after_k * Conj(x_after_k));
106 auto sigma =
107 Reduce(x_squared, ScalarLike(x_squared, 0.0),
108 CreateScalarAddComputation(
109 primitive_util::ComplexComponentType(type), builder),
110 {minor_dim});
111 // mu = np.sqrt(x[k]*np.con(x[k]) + sigma)
112 auto mu = Sqrt(Real(alpha * Conj(alpha)) + sigma);
114 sigma_is_zero = Eq(sigma, ScalarLike(sigma, 0));
115 sigma_is_zero = And(sigma_is_zero, Eq(Imag(alpha), ScalarLike(sigma, 0)));
117 *beta = Select(Lt(Real(alpha), ScalarLike(sigma, 0)), ScalarLike(mu, 1),
118 ScalarLike(mu, -1)) *
119 mu;
120 *beta = Select(sigma_is_zero, Real(alpha), *beta);
121 *tau = Complex((*beta - Real(alpha)) / *beta, -Imag(alpha) / *beta);
122 } else {
123 // sigma = np.dot(x[k+1:], x[k+1:])
124 // TODO(phawkins): this calculation may be numerically unstable.
125 auto sigma = Reduce(x_after_k * x_after_k, zero,
126 CreateScalarAddComputation(type, builder), {minor_dim});
127 // mu = np.sqrt(x[k]*x[k] + sigma)
128 auto mu = Sqrt(Square(alpha) + sigma);
129 sigma_is_zero = Eq(sigma, zero);
131 XlaOp one = ScalarLike(x, 1.0);
132 *beta = Select(Lt(alpha, zero), one, -one) * mu;
133 *beta = Select(sigma_is_zero, alpha, *beta);
134 *tau = (*beta - alpha) / *beta;
135 }
136 *tau = Select(sigma_is_zero, ZerosLike(*tau), *tau);
138 auto divisor =
139 Select(sigma_is_zero, Broadcast(ScalarLike(alpha, 1), batch_dims),
140 alpha - ConvertElementType(*beta, type));
142 auto e_k = Broadcast(ConvertElementType(Eq(iota, k), type),
143 std::vector<int64>(batch_dims.size(), 1));
145 // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
146 // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
147 *v = e_k + Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids);
148 return Status::OK();
149 }
151 } // namespace
153 // Householder QR decomposition. Algorithm 5.2.1 from Golub and Van
154 // Loan "Matrix Computations", 4th Edition. This is an unblocked implementation
155 // used as an inner routine of the blocked implementation.
156 // Algorithm is adapted slightly so the shapes inside the loop are static, at
157 // the cost of some redundant computation. Since this is used as an inner block
158 // kernel, accumulates the Householder transformations (vs, taus) rather than
159 // the matrix q.
160 // Equivalent Python code, without batching:
161 // def qr(a):
162 // m = a.shape[0]
163 // n = a.shape[1]
164 // taus = np.zeros([n])
165 // for j in xrange(min(m, n)):
166 // v, tau, beta = house(a[:, j], j)
167 // a[:, j+1:] -= np.conj(tau) * np.dot(v[:, np.newaxis],
168 // np.dot(np.conj(v[np.newaxis, :]), a[:, j+1:]))
169 // # Form column j explicitly rather than relying on the precision of the
170 // # Householder update.
171 // a[j, j] = beta
172 // a[j+1:, j] = v[j+1:]
173 // taus[j] = tau
174 // return (a, taus)
QrBlock(XlaOp a,PrecisionConfig::Precision precision)175 StatusOr<QrExpander::QrResult> QrExpander::QrBlock(
176 XlaOp a, PrecisionConfig::Precision precision) {
177 XlaBuilder* builder = a.builder();
178 TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
179 const int num_dims = a_shape.rank();
180 if (num_dims < 2) {
181 return InvalidArgument("Argument to QR must have rank >= 2; got shape %s",
182 a_shape.ToString());
183 }
184 PrimitiveType type = a_shape.element_type();
186 const int64 m = ShapeUtil::GetDimension(a_shape, -2);
187 const int64 n = ShapeUtil::GetDimension(a_shape, -1);
189 const int64 num_batch_dims = num_dims - 2;
190 std::vector<int64> batch_dims(num_batch_dims);
191 for (int i = 0; i < num_batch_dims; ++i) {
192 batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
193 }
195 std::vector<int64> batch_dim_indices(num_batch_dims);
196 std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
198 auto qr_body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
199 XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
200 auto a = values[0];
201 auto taus = values[1];
203 // v, tau, beta = house(a[:, j], j)
204 auto x = DynamicSliceInMinorDims(a, {j}, {1});
205 XlaOp v, tau, beta;
206 TF_RETURN_IF_ERROR(House(Collapse(x, {num_dims - 2, num_dims - 1}), j,
207 batch_dims, m, &v, &tau, &beta));
209 const int64 minor_dim = batch_dims.size();
210 auto iota_mn = Iota(
211 builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {m, n})),
212 minor_dim + 1);
214 std::vector<int64> shape = batch_dims;
215 shape.push_back(1);
216 shape.push_back(m);
217 auto v_broadcast = Reshape(v, shape);
218 // a[:, j+1:] -= np.conj(tau) * (v[:, np.newaxis] @
219 // (np.conj(v[np.newaxis, :]) @ a[:, j+1:]))
220 // We use masking rather than a loop-variant shape to handle the j+1:
221 // indexing.
222 auto vva = BatchDot(MaybeConjugate(v_broadcast, true),
223 Select(Lt(j, iota_mn), a, ZerosLike(a)), precision);
224 vva = BatchDot(v_broadcast, true, vva, false, precision);
225 a = a - Mul(MaybeConjugate(tau, true), vva,
226 /*broadcast_dimensions=*/batch_dim_indices);
228 // a[j, j] = beta
229 // a[j+1:,j] = v[j+1:]
230 auto iota = Reshape(Iota(a.builder(), S32, m), {m, 1});
231 auto predecessor_mask = ConvertElementType(Lt(iota, j), type);
232 auto mask = Broadcast(ConvertElementType(Eq(iota, j), type),
233 std::vector<int64>(batch_dims.size(), 1));
234 auto successor_mask = Gt(Iota(a.builder(), S32, m), j);
235 auto new_x = Mul(x, predecessor_mask,
236 /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) +
237 Mul(ConvertElementType(beta, type), mask,
238 /*broadcast_dimensions=*/batch_dim_indices);
239 new_x = Add(
240 new_x, Select(Broadcast(successor_mask, batch_dims), v, ZerosLike(v)),
241 /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {minor_dim}));
242 // Update a[:,j]
243 std::vector<int64> dim_ids(num_dims);
244 std::iota(dim_ids.begin(), dim_ids.end(), 0);
245 new_x = BroadcastInDim(new_x, ConcatVectors(batch_dims, {m, n}),
246 /*broadcast_dimensions=*/dim_ids);
247 a = Select(Eq(iota_mn, j), new_x, a);
249 // taus[j] = tau
250 std::vector<int64> tau_broadcast_dims(batch_dims.size());
251 std::iota(tau_broadcast_dims.begin(), tau_broadcast_dims.end(), 0);
253 auto iota_n =
254 Iota(builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {n})),
255 minor_dim);
256 auto taus_zeros = ZerosLike(taus);
257 auto taus_update = Select(
258 Eq(iota_n, j),
259 Add(taus_zeros, tau, /*broadcast_dimensions=*/tau_broadcast_dims),
260 taus_zeros);
261 taus = taus + taus_update;
262 return std::vector<XlaOp>{a, taus};
263 };
265 auto taus = Zeros(
266 builder,
267 ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {std::min(m, n)})));
269 TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(std::min(m, n), S32, qr_body_fn,
270 {a, taus}, "qr", builder));
272 QrResult result;
273 result.a = values[0];
274 result.taus = values[1];
275 return result;
276 }
278 // Computes an upper triangular matrix T such that (I - Y @ T @ Y^t) is a
279 // product of the elementary Householder reflectors given by `vs` and `taus`.
280 //
281 // Schreiber, Robert, and Charles Van Loan. "A storage-efficient WY
282 // representation for products of Householder transformations." SIAM Journal on
283 // Scientific and Statistical Computing 10.1 (1989): 53-57.
284 //
285 // def compact_wy(vs, taus):
286 // m, n = vs.shape[-2:]
287 // t = np.eye(n) * -taus
288 // # We premultiply Y.T @ vs, since we would prefer to compute a single matrix
289 // # multiplication to many matrix-vector products.
290 // vtv = -taus[None, :] * np.triu(np.conj(vs.T) @ vs, 1) + np.eye(n)
291 // for i in range(1, n):
292 // t[:, i] = scipy.linalg.blas.strmm(t, vtv[:, i])
293 // return t
CompactWYRepresentation(PrimitiveType type,absl::Span<const int64> batch_dims,XlaOp vs,XlaOp taus,int64 m,int64 n,PrecisionConfig::Precision precision)294 StatusOr<XlaOp> QrExpander::CompactWYRepresentation(
295 PrimitiveType type, absl::Span<const int64> batch_dims, XlaOp vs,
296 XlaOp taus, int64 m, int64 n, PrecisionConfig::Precision precision) {
297 XlaBuilder* builder = vs.builder();
299 std::vector<int64> batch_dim_indices(batch_dims.size());
300 std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
301 int64 n_index = batch_dims.size() + 1;
303 auto body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
304 XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
305 // w has shape [..., m, n]
306 auto t = values[0];
307 const auto vtv = values[1];
309 // yv has shape [..., n, 1]
310 auto yv = DynamicSliceInMinorDims(vtv, {j}, {1});
312 // z has shape [..., n, 1]
313 auto z = BatchDot(t, yv, precision);
315 t = DynamicUpdateSliceInMinorDims(t, z, {j});
317 return std::vector<XlaOp>{t, vtv};
318 };
320 auto tau_scale = BroadcastInDim(-taus, ConcatVectors(batch_dims, {1, n}),
321 ConcatVectors(batch_dim_indices, {n_index}));
323 auto eye = Broadcast(IdentityMatrix(builder, type, n, n), batch_dims);
324 auto t = eye;
326 auto vtv = BatchDot(MaybeConjugate(vs, true), /*transpose_x=*/true, vs,
327 /*transpose_y=*/false, precision);
328 vtv = Select(TriangleMask(vtv, 0), ZerosLike(vtv), vtv);
329 vtv = (vtv + eye) * tau_scale;
331 TF_ASSIGN_OR_RETURN(auto values,
332 ForEachIndex(n, S32, body_fn, {t, vtv}, "wy", builder));
333 return values[0];
334 }
336 // Block Householder QR Factorization. Algorithm 5.2.2 of Golub and van Loan.
337 // def qr_blocked(a, block_size):
338 // m = a.shape[0]
339 // n = a.shape[1]
340 // q = np.eye(m)
341 // for i in xrange(0, min(m, n), block_size):
342 // k = min(block_size, min(m, n) - s)
343 // (a, taus) = qr(a[i:, i:i+k])
344 // y = np.eye(m, n) + np.tril(a, -1)
345 // t = CompactWYRepresentation(vs, taus, m-i, k)
346 // a[i:, i+k:] += (y @ np.conj(t.T)) @ (np.conj(y.T) @ a[i:, i+k:])
347 // q[:, i:] += (q[:, i:] @ y) @ np.conj((y @ np.conj(t.T)).T)
348 // return (q, a)
BuildQrDecomposition(XlaOp a,int64 block_size,PrecisionConfig::Precision precision)349 StatusOr<XlaOp> QrExpander::BuildQrDecomposition(
350 XlaOp a, int64 block_size, PrecisionConfig::Precision precision) {
351 XlaBuilder* builder = a.builder();
352 TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
353 const int num_dims = a_shape.rank();
354 if (num_dims < 2) {
355 return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s",
356 a_shape.ToString());
357 }
358 PrimitiveType type = a_shape.element_type();
360 const int64 m = ShapeUtil::GetDimension(a_shape, -2);
361 const int64 n = ShapeUtil::GetDimension(a_shape, -1);
362 const int64 p = std::min(m, n);
364 if (block_size < 1) {
365 return InvalidArgument("block_size argument to QR must be >= 1; got %d",
366 block_size);
367 }
369 const int64 num_batch_dims = num_dims - 2;
370 std::vector<int64> batch_dims(num_batch_dims);
371 for (int i = 0; i < num_batch_dims; ++i) {
372 batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
373 }
375 auto q = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims);
376 for (int64 i = 0; i < p; i += block_size) {
377 int64 k = std::min(block_size, p - i);
379 auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k});
380 TF_ASSIGN_OR_RETURN(auto qr_block, QrBlock(a_block, precision));
381 auto y = Add(
382 IdentityMatrix(builder, type, m - i, k),
383 Select(TriangleMask(qr_block.a, -1), qr_block.a, ZerosLike(qr_block.a)),
384 /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1});
386 a = UpdateSliceInMinorDims(a, qr_block.a, {i, i});
388 // Compute the I + Y @ T @ Y^t block representation of a product of
389 // Householder matrices.
391 auto t, CompactWYRepresentation(type, batch_dims, y, qr_block.taus,
392 m - i, k, precision));
394 // a[i:, i+k:] += (y @ np.conj(t.T)) @ (np.conj(y.T) @ a[i:, i+k:])
395 auto yt = BatchDot(y, /*transpose_x=*/false, MaybeConjugate(t, true),
396 /*transpose_y=*/true, precision);
397 auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n});
398 auto a_update =
399 BatchDot(MaybeConjugate(y, true), /*transpose_x=*/true, a_panel,
400 /*transpose_y=*/false, precision);
401 a_update = BatchDot(yt, a_update, precision);
402 a_panel = a_panel + a_update;
403 a = UpdateSliceInMinorDims(a, a_panel, {i, i + k});
405 // q[:, i:] += (q[:, i:] @ y) @ np.conj((y @ np.conj(t.T)).T)
406 auto q_panel = SliceInMinorDims(q, {0, i}, {m, m});
407 auto q_update = BatchDot(q_panel, y, precision);
408 q_update =
409 BatchDot(q_update, /*transpose_x=*/false, MaybeConjugate(yt, true),
410 /*transpose_y=*/true, precision);
411 q_panel = q_panel + q_update;
412 q = UpdateSliceInMinorDims(q, q_panel, {0, i});
413 }
415 return Tuple(builder, {q, UpperTriangle(a)});
416 }
InstructionMatchesPattern(HloInstruction * instruction)418 bool QrExpander::InstructionMatchesPattern(HloInstruction* instruction) {
419 return instruction->opcode() == HloOpcode::kCustomCall &&
420 instruction->custom_call_target() == "QrDecomposition";
421 }
ExpandInstruction(HloInstruction * instruction)423 StatusOr<HloInstruction*> QrExpander::ExpandInstruction(
424 HloInstruction* instruction) {
425 const string name =
426 absl::StrFormat("xla.qr_%s", instruction->operand(0)->shape().ToString());
428 HloModule* module = instruction->parent()->parent();
430 HloComputation*& computation =
431 computation_cache_.emplace(name, nullptr).first->second;
432 if (!computation) {
433 // Builds a new expansion.
434 //
435 // TODO(b/62327888): We do something unusual here: we build the computation
436 // using the XlaBuilder API, which is nominally an XLA client API. We do
437 // this because the external APIs for building complicated computations
438 // (XlaBuilder) are much more ergonomic than the internal ones. As it turns
439 // out, XlaBuilder isn't really a client API—what it does is build a
440 // HloModuleProto protocol buffer, that we can then deserialize and clone
441 // into our HloModule. Ideally we would avoid the protocol buffer step;
442 // that is left as an exercise for future work.
443 XlaBuilder builder(name);
444 XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");
446 XlaOp l, BuildQrDecomposition(a,
447 /*block_size=*/128,
448 /*precision=*/PrecisionConfig::HIGHEST));
450 TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build(l));
452 TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
453 xla_computation.GetProgramShape());
454 HloModuleConfig config(program_shape);
455 TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
456 xla_computation.proto(), config));
457 HloCloneContext context(module);
458 computation =
459 module->DeepCloneComputation(new_module->entry_computation(), &context);
460 }
462 return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
463 instruction->shape(), instruction->operands(), computation));
464 }
466 } // namespace xla