/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/qr_expander.h"

#include <algorithm>
#include <memory>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

namespace {

std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
                                 absl::Span<const int64> ys) {
  std::vector<int64> output;
  output.reserve(xs.size() + ys.size());
  std::copy(xs.begin(), xs.end(), std::back_inserter(output));
  std::copy(ys.begin(), ys.end(), std::back_inserter(output));
  return output;
}

// Computes a Householder reflection of the form
// H = I - tau v v.T
// such that
// H . ( x1  ) = ( x1   )
//     ( x2  ) = ( x2   )
//     ( ... ) = ( ...  )
//     ( xk  ) = ( beta )
//     ( ... )   ( 0    )
//     ( ... )   ( 0    )
// Unlike the usual formulation, we allow the caller to supply 'k' rather than
// only providing the relevant part of 'x' to maintain XLA's static shape
// invariant. In addition, the implementation supports batching.
// Pseudo-code, without batching:
//   alpha = x[k]
//   x_copy = np.copy(x)
//   x_copy[:k+1] = 0
//   xnorm = norm2(x_copy)
//   if xnorm == 0 and np.imag(alpha) == 0:
//     beta = alpha
//     tau = 0
//     v = np.zeros_like(x)
//   else:
//     beta = -np.sign(np.real(alpha)) * np.sqrt(alpha * np.conj(alpha) +
//                                               xnorm ** 2)
//     if np.issubdtype(x.dtype, np.complexfloating):
//       tau = (beta - np.real(alpha)) / beta + (-np.imag(alpha) / beta) * 1j
//     else:
//       tau = (beta - alpha) / beta
//     v = x / (alpha - beta)
//     v[k] = 1
//   return (v, tau, beta)
// TODO(phawkins): LAPACK's xLARFG implementation has code for handling
// overflows in the norm/beta calculations. Perhaps do the same here.
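//
// Worked example (illustrative only; real case, k = 0): for x = [3, 4],
// xnorm = 4, beta = -np.sign(3) * np.sqrt(9 + 16) = -5,
// tau = (-5 - 3) / -5 = 1.6, and v = [1, 4 / (3 - (-5))] = [1, 0.5].
// Applying H = I - tau v v.T to x indeed gives [-5, 0].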
Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
             const int64 m, XlaOp* v, XlaOp* tau, XlaOp* beta) {
  XlaBuilder* const builder = x.builder();
  TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
  const PrimitiveType type = x_shape.element_type();

  std::vector<int64> batch_dim_ids(batch_dims.size());
  std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0);
  const int64 minor_dim = batch_dims.size();

  XlaOp zero = ScalarLike(x, 0.0);

  // alpha = x[k]
  XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims);

  // Compute x[k+1:] (padded with zeros in elements 0..k)
  XlaOp iota = Iota(builder, S32, m);
  XlaOp x_after_k = Mul(x, ConvertElementType(Gt(iota, k), type),
                        /*broadcast_dimensions=*/{minor_dim});

  XlaOp sigma_is_zero;
  if (primitive_util::IsComplexType(type)) {
    // sigma = np.dot(x[k+1:], np.conj(x[k+1:]))
    // TODO(phawkins): this calculation may be numerically unstable.
    auto x_squared = Real(x_after_k * Conj(x_after_k));
    auto sigma =
        Reduce(x_squared, ScalarLike(x_squared, 0.0),
               CreateScalarAddComputation(
                   primitive_util::ComplexComponentType(type), builder),
               {minor_dim});
    // mu = np.sqrt(x[k] * np.conj(x[k]) + sigma)
    auto mu = Sqrt(Real(alpha * Conj(alpha)) + sigma);

    sigma_is_zero = Eq(sigma, ScalarLike(sigma, 0));
    sigma_is_zero = And(sigma_is_zero, Eq(Imag(alpha), ScalarLike(sigma, 0)));

    *beta = Select(Lt(Real(alpha), ScalarLike(sigma, 0)), ScalarLike(mu, 1),
                   ScalarLike(mu, -1)) *
            mu;
    *beta = Select(sigma_is_zero, Real(alpha), *beta);
    *tau = Complex((*beta - Real(alpha)) / *beta, -Imag(alpha) / *beta);
  } else {
    // sigma = np.dot(x[k+1:], x[k+1:])
    // TODO(phawkins): this calculation may be numerically unstable.
    auto sigma = Reduce(x_after_k * x_after_k, zero,
                        CreateScalarAddComputation(type, builder), {minor_dim});
    // mu = np.sqrt(x[k] * x[k] + sigma)
    auto mu = Sqrt(Square(alpha) + sigma);
    sigma_is_zero = Eq(sigma, zero);

    XlaOp one = ScalarLike(x, 1.0);
    *beta = Select(Lt(alpha, zero), one, -one) * mu;
    *beta = Select(sigma_is_zero, alpha, *beta);
    *tau = (*beta - alpha) / *beta;
  }
  *tau = Select(sigma_is_zero, ZerosLike(*tau), *tau);

  auto divisor =
      Select(sigma_is_zero, Broadcast(ScalarLike(alpha, 1), batch_dims),
             alpha - ConvertElementType(*beta, type));

  auto e_k = Broadcast(ConvertElementType(Eq(iota, k), type),
                       std::vector<int64>(batch_dims.size(), 1));

  // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
  // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
  *v = e_k + Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids);
  return Status::OK();
}

}  // namespace

// Householder QR decomposition. Algorithm 5.2.1 from Golub and Van
// Loan "Matrix Computations", 4th Edition. This is an unblocked implementation
// used as an inner routine of the blocked implementation.
// The algorithm is adapted slightly so the shapes inside the loop are static,
// at the cost of some redundant computation. Since this is used as an inner
// block kernel, it accumulates the Householder transformations (vs, taus)
// rather than the matrix q.
// Equivalent Python code, without batching:
//   def qr(a):
//     m = a.shape[0]
//     n = a.shape[1]
//     taus = np.zeros([n])
//     for j in xrange(min(m, n)):
//       v, tau, beta = house(a[:, j], j)
//       a[:, j+1:] -= np.conj(tau) * np.dot(v[:, np.newaxis],
//                        np.dot(np.conj(v[np.newaxis, :]), a[:, j+1:]))
//       # Form column j explicitly rather than relying on the precision of
//       # the Householder update.
//       a[j, j] = beta
//       a[j+1:, j] = v[j+1:]
//       taus[j] = tau
//     return (a, taus)
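//
// Note: following the LAPACK geqrf storage convention, column j of the
// returned `a` holds beta on the diagonal and v[j+1:] below it; the blocked
// driver below rebuilds Y as eye + tril(a, -1) and consumes `taus` through
// CompactWYRepresentation.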
StatusOr<QrExpander::QrResult> QrExpander::QrBlock(
    XlaOp a, PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
  const int num_dims = a_shape.rank();
  if (num_dims < 2) {
    return InvalidArgument("Argument to QR must have rank >= 2; got shape %s",
                           a_shape.ToString());
  }
  PrimitiveType type = a_shape.element_type();

  const int64 m = ShapeUtil::GetDimension(a_shape, -2);
  const int64 n = ShapeUtil::GetDimension(a_shape, -1);

  const int64 num_batch_dims = num_dims - 2;
  std::vector<int64> batch_dims(num_batch_dims);
  for (int i = 0; i < num_batch_dims; ++i) {
    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
  }

  std::vector<int64> batch_dim_indices(num_batch_dims);
  std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);

  auto qr_body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
                        XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
    auto a = values[0];
    auto taus = values[1];

    // v, tau, beta = house(a[:, j], j)
    auto x = DynamicSliceInMinorDims(a, {j}, {1});
    XlaOp v, tau, beta;
    TF_RETURN_IF_ERROR(House(Collapse(x, {num_dims - 2, num_dims - 1}), j,
                             batch_dims, m, &v, &tau, &beta));

    const int64 minor_dim = batch_dims.size();
    auto iota_mn = Iota(
        builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {m, n})),
        minor_dim + 1);
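    // iota_mn[..., row, col] == col; comparing it with j builds masks that
    // select entries in column j (Eq) or in the columns to its right (Lt).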

    std::vector<int64> shape = batch_dims;
    shape.push_back(1);
    shape.push_back(m);
    auto v_broadcast = Reshape(v, shape);
    // a[:, j+1:] -= np.conj(tau) * (v[:, np.newaxis] @
    //                               (np.conj(v[np.newaxis, :]) @ a[:, j+1:]))
    // We use masking rather than a loop-variant shape to handle the j+1:
    // indexing.
    auto vva = BatchDot(MaybeConjugate(v_broadcast, true),
                        Select(Lt(j, iota_mn), a, ZerosLike(a)), precision);
    vva = BatchDot(v_broadcast, true, vva, false, precision);
    a = a - Mul(MaybeConjugate(tau, true), vva,
                /*broadcast_dimensions=*/batch_dim_indices);

    // a[j, j] = beta
    // a[j+1:, j] = v[j+1:]
    auto iota = Reshape(Iota(a.builder(), S32, m), {m, 1});
    auto predecessor_mask = ConvertElementType(Lt(iota, j), type);
    auto mask = Broadcast(ConvertElementType(Eq(iota, j), type),
                          std::vector<int64>(batch_dims.size(), 1));
    auto successor_mask = Gt(Iota(a.builder(), S32, m), j);
    auto new_x = Mul(x, predecessor_mask,
                     /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) +
                 Mul(ConvertElementType(beta, type), mask,
                     /*broadcast_dimensions=*/batch_dim_indices);
    new_x = Add(
        new_x, Select(Broadcast(successor_mask, batch_dims), v, ZerosLike(v)),
        /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {minor_dim}));
    // Update a[:, j].
    std::vector<int64> dim_ids(num_dims);
    std::iota(dim_ids.begin(), dim_ids.end(), 0);
    new_x = BroadcastInDim(new_x, ConcatVectors(batch_dims, {m, n}),
                           /*broadcast_dimensions=*/dim_ids);
    a = Select(Eq(iota_mn, j), new_x, a);

    // taus[j] = tau
    std::vector<int64> tau_broadcast_dims(batch_dims.size());
    std::iota(tau_broadcast_dims.begin(), tau_broadcast_dims.end(), 0);

    auto iota_n = Iota(
        builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {n})),
        minor_dim);
    auto taus_zeros = ZerosLike(taus);
    auto taus_update = Select(
        Eq(iota_n, j),
        Add(taus_zeros, tau, /*broadcast_dimensions=*/tau_broadcast_dims),
        taus_zeros);
    taus = taus + taus_update;
    return std::vector<XlaOp>{a, taus};
  };

  auto taus = Zeros(
      builder,
      ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {std::min(m, n)})));

  TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(std::min(m, n), S32, qr_body_fn,
                                                {a, taus}, "qr", builder));

  QrResult result;
  result.a = values[0];
  result.taus = values[1];
  return result;
}

// Computes an upper triangular matrix T such that (I + Y @ T @ Y^t) is a
// product of the elementary Householder reflectors given by `vs` and `taus`,
// matching the block representation used in BuildQrDecomposition below.
//
// Schreiber, Robert, and Charles Van Loan. "A storage-efficient WY
// representation for products of Householder transformations." SIAM Journal on
// Scientific and Statistical Computing 10.1 (1989): 53-57.
//
//   def compact_wy(vs, taus):
//     m, n = vs.shape[-2:]
//     t = np.eye(n) * -taus
//     # We precompute np.conj(vs.T) @ vs, since we would rather perform one
//     # matrix-matrix multiplication than many matrix-vector products.
//     vtv = -taus[None, :] * np.triu(np.conj(vs.T) @ vs, 1) + np.eye(n)
//     for i in range(1, n):
//       t[:, i] = scipy.linalg.blas.strmm(t, vtv[:, i])
//     return t
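//
// For example (a hand-checked sketch, n = 2): writing u = np.conj(v0).T @ v1,
// the loop yields
//   T = [[-tau0, tau0 * tau1 * u],
//        [    0,           -tau1]],
// and I + Y @ T @ np.conj(Y).T multiplies out to
// (I - tau0 v0 v0^H) @ (I - tau1 v1 v1^H), the product of the two reflectors.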
StatusOr<XlaOp> QrExpander::CompactWYRepresentation(
    PrimitiveType type, absl::Span<const int64> batch_dims, XlaOp vs,
    XlaOp taus, int64 m, int64 n, PrecisionConfig::Precision precision) {
  XlaBuilder* builder = vs.builder();

  std::vector<int64> batch_dim_indices(batch_dims.size());
  std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
  int64 n_index = batch_dims.size() + 1;

  auto body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
                     XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
    // t has shape [..., n, n]
    auto t = values[0];
    const auto vtv = values[1];

    // yv has shape [..., n, 1]
    auto yv = DynamicSliceInMinorDims(vtv, {j}, {1});

    // z has shape [..., n, 1]
    auto z = BatchDot(t, yv, precision);

    t = DynamicUpdateSliceInMinorDims(t, z, {j});

    return std::vector<XlaOp>{t, vtv};
  };

  auto tau_scale = BroadcastInDim(-taus, ConcatVectors(batch_dims, {1, n}),
                                  ConcatVectors(batch_dim_indices, {n_index}));

  auto eye = Broadcast(IdentityMatrix(builder, type, n, n), batch_dims);
  auto t = eye;

  auto vtv = BatchDot(MaybeConjugate(vs, true), /*transpose_x=*/true, vs,
                      /*transpose_y=*/false, precision);
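  // Keep only the strict upper triangle, i.e. np.triu(vs^H @ vs, 1):
  // TriangleMask(vtv, 0) is true on and below the main diagonal, and those
  // entries are replaced with zeros.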
  vtv = Select(TriangleMask(vtv, 0), ZerosLike(vtv), vtv);
  vtv = (vtv + eye) * tau_scale;

  TF_ASSIGN_OR_RETURN(auto values,
                      ForEachIndex(n, S32, body_fn, {t, vtv}, "wy", builder));
  return values[0];
}

// Block Householder QR Factorization. Algorithm 5.2.2 of Golub and van Loan.
//   def qr_blocked(a, block_size):
//     m = a.shape[0]
//     n = a.shape[1]
//     q = np.eye(m)
//     for i in xrange(0, min(m, n), block_size):
//       k = min(block_size, min(m, n) - i)
//       (a, taus) = qr(a[i:, i:i+k])
//       y = np.eye(m - i, k) + np.tril(a, -1)
//       t = CompactWYRepresentation(y, taus, m - i, k)
//       a[i:, i+k:] += (y @ np.conj(t.T)) @ (np.conj(y.T) @ a[i:, i+k:])
//       q[:, i:] += (q[:, i:] @ y) @ np.conj((y @ np.conj(t.T)).T)
//     return (q, a)
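//
// The point of the WY form: each block step applies all k reflectors of the
// panel using a few matrix-matrix products rather than k rank-1 updates, so
// most of the work is done in large dot operations.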
StatusOr<XlaOp> QrExpander::BuildQrDecomposition(
    XlaOp a, int64 block_size, PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
  const int num_dims = a_shape.rank();
  if (num_dims < 2) {
    return InvalidArgument("Argument to QR must have rank >= 2; got shape %s",
                           a_shape.ToString());
  }
  PrimitiveType type = a_shape.element_type();

  const int64 m = ShapeUtil::GetDimension(a_shape, -2);
  const int64 n = ShapeUtil::GetDimension(a_shape, -1);
  const int64 p = std::min(m, n);

  if (block_size < 1) {
    return InvalidArgument("block_size argument to QR must be >= 1; got %d",
                           block_size);
  }

  const int64 num_batch_dims = num_dims - 2;
  std::vector<int64> batch_dims(num_batch_dims);
  for (int i = 0; i < num_batch_dims; ++i) {
    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
  }

  auto q = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims);
  for (int64 i = 0; i < p; i += block_size) {
    int64 k = std::min(block_size, p - i);

    auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k});
    TF_ASSIGN_OR_RETURN(auto qr_block, QrBlock(a_block, precision));
    auto y = Add(
        IdentityMatrix(builder, type, m - i, k),
        Select(TriangleMask(qr_block.a, -1), qr_block.a, ZerosLike(qr_block.a)),
        /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1});

    a = UpdateSliceInMinorDims(a, qr_block.a, {i, i});

    // Compute the I + Y @ T @ Y^t block representation of a product of
    // Householder matrices.
    TF_ASSIGN_OR_RETURN(
        auto t, CompactWYRepresentation(type, batch_dims, y, qr_block.taus,
                                        m - i, k, precision));

    // a[i:, i+k:] += (y @ np.conj(t.T)) @ (np.conj(y.T) @ a[i:, i+k:])
    auto yt = BatchDot(y, /*transpose_x=*/false, MaybeConjugate(t, true),
                       /*transpose_y=*/true, precision);
    auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n});
    auto a_update =
        BatchDot(MaybeConjugate(y, true), /*transpose_x=*/true, a_panel,
                 /*transpose_y=*/false, precision);
    a_update = BatchDot(yt, a_update, precision);
    a_panel = a_panel + a_update;
    a = UpdateSliceInMinorDims(a, a_panel, {i, i + k});

    // q[:, i:] += (q[:, i:] @ y) @ np.conj((y @ np.conj(t.T)).T)
    auto q_panel = SliceInMinorDims(q, {0, i}, {m, m});
    auto q_update = BatchDot(q_panel, y, precision);
    q_update =
        BatchDot(q_update, /*transpose_x=*/false, MaybeConjugate(yt, true),
                 /*transpose_y=*/true, precision);
    q_panel = q_panel + q_update;
    q = UpdateSliceInMinorDims(q, q_panel, {0, i});
  }

  return Tuple(builder, {q, UpperTriangle(a)});
}

bool QrExpander::InstructionMatchesPattern(HloInstruction* instruction) {
  return instruction->opcode() == HloOpcode::kCustomCall &&
         instruction->custom_call_target() == "QrDecomposition";
}

StatusOr<HloInstruction*> QrExpander::ExpandInstruction(
    HloInstruction* instruction) {
  const string name =
      absl::StrFormat("xla.qr_%s", instruction->operand(0)->shape().ToString());

  HloModule* module = instruction->parent()->parent();

  HloComputation*& computation =
      computation_cache_.emplace(name, nullptr).first->second;
  if (!computation) {
    // Builds a new expansion.
    //
    // TODO(b/62327888): We do something unusual here: we build the computation
    // using the XlaBuilder API, which is nominally an XLA client API. We do
    // this because the external APIs for building complicated computations
    // (XlaBuilder) are much more ergonomic than the internal ones. As it turns
    // out, XlaBuilder isn't really a client API—what it does is build a
    // HloModuleProto protocol buffer that we can then deserialize and clone
    // into our HloModule. Ideally we would avoid the protocol buffer step;
    // that is left as an exercise for future work.
    XlaBuilder builder(name);
    XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");
    TF_ASSIGN_OR_RETURN(
        XlaOp l, BuildQrDecomposition(a,
                                      /*block_size=*/128,
                                      /*precision=*/PrecisionConfig::HIGHEST));

    TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build(l));

    TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                        xla_computation.GetProgramShape());
    HloModuleConfig config(program_shape);
    TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
                                             xla_computation.proto(), config));
    HloCloneContext context(module);
    computation =
        module->DeepCloneComputation(new_module->entry_computation(), &context);
  }

  return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
      instruction->shape(), instruction->operands(), computation));
}

}  // namespace xla