/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/qr_expander.h"

#include <algorithm>
#include <memory>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

namespace {

std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
                                 absl::Span<const int64> ys) {
  std::vector<int64> output;
  output.reserve(xs.size() + ys.size());
  std::copy(xs.begin(), xs.end(), std::back_inserter(output));
  std::copy(ys.begin(), ys.end(), std::back_inserter(output));
  return output;
}

// Computes a Householder reflection of the form:
// H = I - tau v v.T.
// such that
// H . ( x1  ) = ( x1   )
//     ( x2  ) = ( x2   )
//     ( ... ) = ( ...  )
//     ( xk  ) = ( beta )
//     ( ... )   ( 0    )
//     ( ... )   ( 0    )
// Unlike the usual formulation, we allow the caller to supply 'k' rather than
// only providing the relevant part of 'x' to maintain XLA's static shape
// invariant. In addition, the implementation supports batching.
// Pseudo-code, without batching:
//   alpha = x[k]
//   x_copy = np.copy(x)
//   x_copy[:k+1] = 0
//   xnorm = norm2(x_copy)
//   if xnorm == 0 and np.imag(alpha) == 0:
//     beta = alpha
//     tau = 0
//     v = np.zeros_like(x)
//   else:
//     beta = -np.sign(np.real(alpha)) * np.sqrt(alpha * np.conj(alpha) + xnorm ** 2)
//     if np.issubdtype(x.dtype, np.complexfloating):
//       tau = (beta - np.real(alpha)) / beta + (-np.imag(alpha) / beta) * 1j
//     else:
//       tau = (beta - alpha) / beta
//     v = x_copy / (alpha - beta)
//   v[k] = 1
//   return (v, tau, beta)
// TODO(phawkins): LAPACK's xLARFG implementation has code for handling
// overflows in the norm/beta calculations. Perhaps do the same here.
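//
// For intuition, a worked example of the real case (illustrative only, with
// made-up numbers): for x = [3, 4] and k = 0 we get alpha = 3, sigma = 16,
// mu = sqrt(9 + 16) = 5, beta = -5, tau = (-5 - 3) / -5 = 1.6, and
// v = [1, 4 / (3 - (-5))] = [1, 0.5]. Then
//   H = I - tau v v.T = [[-0.6, -0.8], [-0.8, 0.6]]
// and H @ [3, 4] = [-5, 0] = [beta, 0], as required.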
Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
             const int64 m, XlaOp* v, XlaOp* tau, XlaOp* beta) {
  XlaBuilder* const builder = x.builder();
  TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
  const PrimitiveType type = x_shape.element_type();

  std::vector<int64> batch_dim_ids(batch_dims.size());
  std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0);
  const int64 minor_dim = batch_dims.size();

  XlaOp zero = ScalarLike(x, 0.0);

  // alpha = x[k]
  XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims);

  // Compute x[k+1:] (padded with zeros in elements 0..k)
  XlaOp iota = Iota(builder, S32, m);
  XlaOp x_after_k = Mul(x, ConvertElementType(Gt(iota, k), type),
                        /*broadcast_dimensions=*/{minor_dim});

  XlaOp sigma_is_zero;
  if (primitive_util::IsComplexType(type)) {
    // sigma = np.dot(x[k+1:], np.conj(x[k+1:]))
    // TODO(phawkins): this calculation may be numerically unstable.
    auto x_squared = Real(x_after_k * Conj(x_after_k));
    auto sigma =
        Reduce(x_squared, ScalarLike(x_squared, 0.0),
               CreateScalarAddComputation(
                   primitive_util::ComplexComponentType(type), builder),
               {minor_dim});
    // mu = np.sqrt(x[k]*np.conj(x[k]) + sigma)
    auto mu = Sqrt(Real(alpha * Conj(alpha)) + sigma);

    sigma_is_zero = Eq(sigma, ScalarLike(sigma, 0));
    sigma_is_zero = And(sigma_is_zero, Eq(Imag(alpha), ScalarLike(sigma, 0)));

    *beta = Select(Lt(Real(alpha), ScalarLike(sigma, 0)), ScalarLike(mu, 1),
                   ScalarLike(mu, -1)) *
            mu;
    *beta = Select(sigma_is_zero, Real(alpha), *beta);
    *tau = Complex((*beta - Real(alpha)) / *beta, -Imag(alpha) / *beta);
  } else {
    // sigma = np.dot(x[k+1:], x[k+1:])
    // TODO(phawkins): this calculation may be numerically unstable.
    auto sigma = Reduce(x_after_k * x_after_k, zero,
                        CreateScalarAddComputation(type, builder), {minor_dim});
    // mu = np.sqrt(x[k]*x[k] + sigma)
    auto mu = Sqrt(Square(alpha) + sigma);
    sigma_is_zero = Eq(sigma, zero);

    XlaOp one = ScalarLike(x, 1.0);
    *beta = Select(Lt(alpha, zero), one, -one) * mu;
    *beta = Select(sigma_is_zero, alpha, *beta);
    *tau = (*beta - alpha) / *beta;
  }
  *tau = Select(sigma_is_zero, ZerosLike(*tau), *tau);

  auto divisor =
      Select(sigma_is_zero, Broadcast(ScalarLike(alpha, 1), batch_dims),
             alpha - ConvertElementType(*beta, type));

  auto e_k = Broadcast(ConvertElementType(Eq(iota, k), type),
                       std::vector<int64>(batch_dims.size(), 1));

  // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
  // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
  *v = e_k + Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids);
  return Status::OK();
}

}  // namespace

// Householder QR decomposition. Algorithm 5.2.1 from Golub and Van
// Loan "Matrix Computations", 4th Edition. This is an unblocked implementation
// used as an inner routine of the blocked implementation.
// The algorithm is adapted slightly so the shapes inside the loop are static,
// at the cost of some redundant computation. Since this is used as an inner
// block kernel, it accumulates the Householder transformations (vs, taus)
// rather than the matrix q.
// Equivalent Python code, without batching:
// def qr(a):
//   m = a.shape[0]
//   n = a.shape[1]
//   taus = np.zeros([n])
//   for j in xrange(min(m, n)):
//     v, tau, beta = house(a[:, j], j)
//     a[:, j+1:] -= np.conj(tau) * np.dot(v[:, np.newaxis],
//                                np.dot(np.conj(v[np.newaxis, :]), a[:, j+1:]))
//     # Form column j explicitly rather than relying on the precision of the
//     # Householder update.
//     a[j, j] = beta
//     a[j+1:, j] = v[j+1:]
//     taus[j] = tau
//   return (a, taus)
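//
// Note qr() above returns the reflectors, not q itself. If q were needed, one
// way to recover it from (a, taus) in the real case is the following sketch
// (illustrative, not part of this file's contract):
//   q = np.eye(m)
//   for j in reversed(range(min(m, n))):
//     v = np.concatenate([np.zeros(j), [1.0], a[j+1:, j]])
//     q -= taus[j] * np.outer(v, v @ q)
//   # q is now the product H_0 @ H_1 @ ... of the Householder reflectors.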
StatusOr<QrExpander::QrResult> QrExpander::QrBlock(
    XlaOp a, PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
  const int num_dims = a_shape.rank();
  if (num_dims < 2) {
    return InvalidArgument("Argument to QR must have rank >= 2; got shape %s",
                           a_shape.ToString());
  }
  PrimitiveType type = a_shape.element_type();

  const int64 m = ShapeUtil::GetDimension(a_shape, -2);
  const int64 n = ShapeUtil::GetDimension(a_shape, -1);

  const int64 num_batch_dims = num_dims - 2;
  std::vector<int64> batch_dims(num_batch_dims);
  for (int i = 0; i < num_batch_dims; ++i) {
    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
  }

  std::vector<int64> batch_dim_indices(num_batch_dims);
  std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);

  auto qr_body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
                        XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
    auto a = values[0];
    auto taus = values[1];

    // v, tau, beta = house(a[:, j], j)
    auto x = DynamicSliceInMinorDims(a, {j}, {1});
    XlaOp v, tau, beta;
    TF_RETURN_IF_ERROR(House(Collapse(x, {num_dims - 2, num_dims - 1}), j,
                             batch_dims, m, &v, &tau, &beta));

    const int64 minor_dim = batch_dims.size();
    auto iota_mn = Iota(
        builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {m, n})),
        minor_dim + 1);

    std::vector<int64> shape = batch_dims;
    shape.push_back(1);
    shape.push_back(m);
    auto v_broadcast = Reshape(v, shape);
    // a[:, j+1:] -= np.conj(tau) * (v[:, np.newaxis] @
    //     (np.conj(v[np.newaxis, :]) @ a[:, j+1:]))
    // We use masking rather than a loop-variant shape to handle the j+1:
    // indexing.
    auto vva = BatchDot(MaybeConjugate(v_broadcast, true),
                        Select(Lt(j, iota_mn), a, ZerosLike(a)), precision);
    vva = BatchDot(v_broadcast, true, vva, false, precision);
    a = a - Mul(MaybeConjugate(tau, true), vva,
                /*broadcast_dimensions=*/batch_dim_indices);

    // a[j, j] = beta
    // a[j+1:, j] = v[j+1:]
    auto iota = Reshape(Iota(a.builder(), S32, m), {m, 1});
    auto predecessor_mask = ConvertElementType(Lt(iota, j), type);
    auto mask = Broadcast(ConvertElementType(Eq(iota, j), type),
                          std::vector<int64>(batch_dims.size(), 1));
    auto successor_mask = Gt(Iota(a.builder(), S32, m), j);
    auto new_x = Mul(x, predecessor_mask,
                     /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) +
                 Mul(ConvertElementType(beta, type), mask,
                     /*broadcast_dimensions=*/batch_dim_indices);
    new_x = Add(
        new_x, Select(Broadcast(successor_mask, batch_dims), v, ZerosLike(v)),
        /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {minor_dim}));
    // Update a[:, j]
    std::vector<int64> dim_ids(num_dims);
    std::iota(dim_ids.begin(), dim_ids.end(), 0);
    new_x = BroadcastInDim(new_x, ConcatVectors(batch_dims, {m, n}),
                           /*broadcast_dimensions=*/dim_ids);
    a = Select(Eq(iota_mn, j), new_x, a);

    // taus[j] = tau
    std::vector<int64> tau_broadcast_dims(batch_dims.size());
    std::iota(tau_broadcast_dims.begin(), tau_broadcast_dims.end(), 0);

    auto iota_n =
        Iota(builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {n})),
             minor_dim);
    auto taus_zeros = ZerosLike(taus);
    auto taus_update = Select(
        Eq(iota_n, j),
        Add(taus_zeros, tau, /*broadcast_dimensions=*/tau_broadcast_dims),
        taus_zeros);
    taus = taus + taus_update;
    return std::vector<XlaOp>{a, taus};
  };

  auto taus = Zeros(
      builder,
      ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {std::min(m, n)})));

  TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(std::min(m, n), S32, qr_body_fn,
                                                {a, taus}, "qr", builder));

  QrResult result;
  result.a = values[0];
  result.taus = values[1];
  return result;
}

// Computes an upper triangular matrix T such that (I + Y @ T @ Y^t) is a
// product of the elementary Householder reflectors given by `vs` and `taus`.
//
// Schreiber, Robert, and Charles Van Loan. "A storage-efficient WY
// representation for products of Householder transformations." SIAM Journal on
// Scientific and Statistical Computing 10.1 (1989): 53-57.
//
// def compact_wy(vs, taus):
//   m, n = vs.shape[-2:]
//   t = np.eye(n) * -taus
//   # We premultiply Y.T @ vs, since we would prefer to compute a single matrix
//   # multiplication to many matrix-vector products.
//   vtv = -taus[None, :] * np.triu(np.conj(vs.T) @ vs, 1) + np.eye(n)
//   for i in range(1, n):
//     t[:, i] = scipy.linalg.blas.strmm(1.0, t, vtv[:, i])
//   return t
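//
// As a sanity check (illustrative, not part of the paper or this file's
// contract): with t = compact_wy(vs, taus), the block reflector
//   np.eye(m) + vs @ t @ np.conj(vs.T)
// should equal the product of the individual reflectors
//   (I - taus[0] * v_0 @ np.conj(v_0.T)) @ ... @ (I - taus[n-1] * ...),
// where v_j denotes the j-th column of vs.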
StatusOr<XlaOp> QrExpander::CompactWYRepresentation(
    PrimitiveType type, absl::Span<const int64> batch_dims, XlaOp vs,
    XlaOp taus, int64 m, int64 n, PrecisionConfig::Precision precision) {
  XlaBuilder* builder = vs.builder();

  std::vector<int64> batch_dim_indices(batch_dims.size());
  std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
  int64 n_index = batch_dims.size() + 1;

  auto body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
                     XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
    // t has shape [..., n, n]
    auto t = values[0];
    const auto vtv = values[1];

    // yv has shape [..., n, 1]
    auto yv = DynamicSliceInMinorDims(vtv, {j}, {1});

    // z has shape [..., n, 1]
    auto z = BatchDot(t, yv, precision);

    t = DynamicUpdateSliceInMinorDims(t, z, {j});

    return std::vector<XlaOp>{t, vtv};
  };

  auto tau_scale = BroadcastInDim(-taus, ConcatVectors(batch_dims, {1, n}),
                                  ConcatVectors(batch_dim_indices, {n_index}));

  auto eye = Broadcast(IdentityMatrix(builder, type, n, n), batch_dims);
  auto t = eye;

  auto vtv = BatchDot(MaybeConjugate(vs, true), /*transpose_x=*/true, vs,
                      /*transpose_y=*/false, precision);
  vtv = Select(TriangleMask(vtv, 0), ZerosLike(vtv), vtv);
  vtv = (vtv + eye) * tau_scale;

  TF_ASSIGN_OR_RETURN(auto values,
                      ForEachIndex(n, S32, body_fn, {t, vtv}, "wy", builder));
  return values[0];
}

// Block Householder QR Factorization. Algorithm 5.2.2 of Golub and Van Loan.
// def qr_blocked(a, block_size):
//   m = a.shape[0]
//   n = a.shape[1]
//   q = np.eye(m)
//   for i in xrange(0, min(m, n), block_size):
//     k = min(block_size, min(m, n) - i)
//     (a, taus) = qr(a[i:, i:i+k])
//     y = np.eye(m - i, k) + np.tril(a, -1)
//     t = CompactWYRepresentation(y, taus, m - i, k)
//     a[i:, i+k:] += (y @ np.conj(t.T)) @ (np.conj(y.T) @ a[i:, i+k:])
//     q[:, i:] += (q[:, i:] @ y) @ np.conj((y @ np.conj(t.T)).T)
//   return (q, a)
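//
// For example (illustrative numbers): with m = n = 512 and block_size = 128,
// the loop below runs for i = 0, 128, 256, and 384; each iteration factors a
// (512 - i) x 128 panel with QrBlock and then applies the accumulated block
// reflector to the trailing columns of a and to q with a few large matrix
// multiplies, instead of one rank-1 update per column.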
StatusOr<XlaOp> QrExpander::BuildQrDecomposition(
    XlaOp a, int64 block_size, PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
  const int num_dims = a_shape.rank();
  if (num_dims < 2) {
    return InvalidArgument("Argument to QR must have rank >= 2; got shape %s",
                           a_shape.ToString());
  }
  PrimitiveType type = a_shape.element_type();

  const int64 m = ShapeUtil::GetDimension(a_shape, -2);
  const int64 n = ShapeUtil::GetDimension(a_shape, -1);
  const int64 p = std::min(m, n);

  if (block_size < 1) {
    return InvalidArgument("block_size argument to QR must be >= 1; got %d",
                           block_size);
  }

  const int64 num_batch_dims = num_dims - 2;
  std::vector<int64> batch_dims(num_batch_dims);
  for (int i = 0; i < num_batch_dims; ++i) {
    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
  }

  auto q = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims);
  for (int64 i = 0; i < p; i += block_size) {
    int64 k = std::min(block_size, p - i);

    auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k});
    TF_ASSIGN_OR_RETURN(auto qr_block, QrBlock(a_block, precision));
    auto y = Add(
        IdentityMatrix(builder, type, m - i, k),
        Select(TriangleMask(qr_block.a, -1), qr_block.a, ZerosLike(qr_block.a)),
        /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1});

    a = UpdateSliceInMinorDims(a, qr_block.a, {i, i});

    // Compute the I + Y @ T @ Y^t block representation of a product of
    // Householder matrices.
    TF_ASSIGN_OR_RETURN(
        auto t, CompactWYRepresentation(type, batch_dims, y, qr_block.taus,
                                        m - i, k, precision));

    // a[i:, i+k:] += (y @ np.conj(t.T)) @ (np.conj(y.T) @ a[i:, i+k:])
    auto yt = BatchDot(y, /*transpose_x=*/false, MaybeConjugate(t, true),
                       /*transpose_y=*/true, precision);
    auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n});
    auto a_update =
        BatchDot(MaybeConjugate(y, true), /*transpose_x=*/true, a_panel,
                 /*transpose_y=*/false, precision);
    a_update = BatchDot(yt, a_update, precision);
    a_panel = a_panel + a_update;
    a = UpdateSliceInMinorDims(a, a_panel, {i, i + k});

    // q[:, i:] += (q[:, i:] @ y) @ np.conj((y @ np.conj(t.T)).T)
    auto q_panel = SliceInMinorDims(q, {0, i}, {m, m});
    auto q_update = BatchDot(q_panel, y, precision);
    q_update =
        BatchDot(q_update, /*transpose_x=*/false, MaybeConjugate(yt, true),
                 /*transpose_y=*/true, precision);
    q_panel = q_panel + q_update;
    q = UpdateSliceInMinorDims(q, q_panel, {0, i});
  }

  return Tuple(builder, {q, UpperTriangle(a)});
}

bool QrExpander::InstructionMatchesPattern(HloInstruction* instruction) {
  return instruction->opcode() == HloOpcode::kCustomCall &&
         instruction->custom_call_target() == "QrDecomposition";
}

StatusOr<HloInstruction*> QrExpander::ExpandInstruction(
    HloInstruction* instruction) {
  const string name =
      absl::StrFormat("xla.qr_%s", instruction->operand(0)->shape().ToString());

  HloModule* module = instruction->parent()->parent();

  HloComputation*& computation =
      computation_cache_.emplace(name, nullptr).first->second;
  if (!computation) {
    // Builds a new expansion.
    //
    // TODO(b/62327888): We do something unusual here: we build the computation
    // using the XlaBuilder API, which is nominally an XLA client API. We do
    // this because the external APIs for building complicated computations
    // (XlaBuilder) are much more ergonomic than the internal ones. As it turns
    // out, XlaBuilder isn't really a client API: what it does is build an
    // HloModuleProto protocol buffer that we can then deserialize and clone
    // into our HloModule. Ideally we would avoid the protocol buffer step;
    // that is left as an exercise for future work.
    XlaBuilder builder(name);
    XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");
    TF_ASSIGN_OR_RETURN(
        XlaOp l, BuildQrDecomposition(a,
                                      /*block_size=*/128,
                                      /*precision=*/PrecisionConfig::HIGHEST));

    TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build(l));

    TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                        xla_computation.GetProgramShape());
    HloModuleConfig config(program_shape);
    TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
                                             xla_computation.proto(), config));
    HloCloneContext context(module);
    computation =
        module->DeepCloneComputation(new_module->entry_computation(), &context);
  }

  return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
      instruction->shape(), instruction->operands(), computation));
}

}  // namespace xla