• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
17 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
18 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
19 #include "tensorflow/compiler/xla/client/lib/constants.h"
20 #include "tensorflow/compiler/xla/client/lib/matrix.h"
21 #include "tensorflow/compiler/xla/client/xla_builder.h"
22 #include "tensorflow/compiler/xla/primitive_util.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 
26 namespace tensorflow {
27 namespace {
28 
// Returns the number of elements on diagonal `diag_index` of a matrix with
// `num_rows` rows and `num_cols` columns. Subdiagonals (negative indices)
// shorten the usable row count; superdiagonals (positive indices) shorten the
// usable column count.
static inline int ComputeDiagLen(int diag_index, int num_rows, int num_cols) {
  const int usable_rows = diag_index < 0 ? num_rows + diag_index : num_rows;
  const int usable_cols = diag_index > 0 ? num_cols - diag_index : num_cols;
  return std::min(usable_rows, usable_cols);
}
34 
// Decides whether the diagonal at `diag_index` is packed to the left.
// Superdiagonals follow `left_align_superdiagonal`, subdiagonals follow
// `left_align_subdiagonal`; the main diagonal (index 0) is full length, so it
// counts as left-aligned if either flag is set.
static inline bool IsLeftAligned(int diag_index, bool left_align_superdiagonal,
                                 bool left_align_subdiagonal) {
  if (diag_index > 0) return left_align_superdiagonal;
  if (diag_index < 0) return left_align_subdiagonal;
  return left_align_superdiagonal || left_align_subdiagonal;
}
41 
42 // Reads the diagonal packing alignment.
ReadAlignment(OpKernelConstruction * context,bool * left_align_superdiagonal,bool * left_align_subdiagonal)43 void ReadAlignment(OpKernelConstruction* context,
44                    bool* left_align_superdiagonal,
45                    bool* left_align_subdiagonal) {
46   string align;
47   OP_REQUIRES_OK(context, context->GetAttr("align", &align));
48 
49   *left_align_superdiagonal = align == "LEFT_LEFT" || align == "LEFT_RIGHT";
50   *left_align_subdiagonal = align == "LEFT_LEFT" || align == "RIGHT_LEFT";
51 }
52 
// Reads or infers lower_diag_index and upper_diag_index from kernel's input
// parameter "k". Also validates their values.
//
// "k" may be a scalar (single diagonal: lower == upper) or a vector of one or
// two elements ([lower] or [lower, upper]). On validation failure the
// OP_REQUIRES* macros record the error on `context` and this function returns
// the default {0, 0}.
std::pair<int64, int64> ProcessDiagIndex(XlaOpKernelContext* context) {
  int64 lower_diag_index = 0;
  int64 upper_diag_index = 0;
  TensorShape diag_index_shape = context->InputShape("k");

  // Wrapping OP_REQUIRES* macros with a function because they can "return;"
  // early (without values) which contradicts ProcessDiagIndex's signature.
  auto validate_diag_indices = [&]() {
    if (diag_index_shape.dims() == 0) {
      // Scalar "k": one diagonal, so lower and upper coincide.
      OP_REQUIRES_OK(context,
                     context->ConstantInputAsIntScalar("k", &lower_diag_index));
      upper_diag_index = lower_diag_index;
    } else {
      // Vector "k": [lower] or [lower, upper].
      std::vector<int64> diag_index;
      OP_REQUIRES_OK(context,
                     context->ConstantInputAsIntVector("k", &diag_index));
      OP_REQUIRES(
          context, !diag_index.empty() && diag_index.size() <= 2,
          errors::InvalidArgument(
              "diag_index must have only one or two elements, received ",
              diag_index.size(), " elements."));
      lower_diag_index = diag_index[0];
      upper_diag_index =
          (diag_index.size() > 1) ? diag_index[1] : lower_diag_index;
    }
    OP_REQUIRES(
        context, lower_diag_index <= upper_diag_index,
        errors::InvalidArgument(
            "lower_diag_index must not be larger than upper_diag_index: ",
            lower_diag_index, " > ", upper_diag_index));
  };
  validate_diag_indices();
  return {lower_diag_index, upper_diag_index};
}
89 
// Makes sure lower_diag_index and upper_diag_index are consistent with the
// input matrix size.
//
// Both indices must lie strictly inside (-num_rows, num_cols); index 0 is
// always accepted so that empty (zero-sized) matrices pass. Errors are
// recorded on `context` via OP_REQUIRES, which also returns early from this
// function.
void ValidateDiagIndexWithOutputMatrixSize(XlaOpKernelContext* context,
                                           const int64 lower_diag_index,
                                           const int64 upper_diag_index,
                                           const int64 num_rows,
                                           const int64 num_cols) {
  // `lower_diag_index == 0` condition is added to handle matrix shape = 0.
  OP_REQUIRES(context,
              (-num_rows < lower_diag_index && lower_diag_index < num_cols) ||
                  lower_diag_index == 0,
              errors::InvalidArgument(
                  "lower_diag_index is out of bound: ", lower_diag_index,
                  " It must be between ", -num_rows, " and ", num_cols));
  OP_REQUIRES(context,
              (-num_rows < upper_diag_index && upper_diag_index < num_cols) ||
                  upper_diag_index == 0,
              errors::InvalidArgument(
                  "upper_diag_index is out of bound: ", upper_diag_index,
                  " It must be between ", -num_rows, " and ", num_cols));
  OP_REQUIRES(context, lower_diag_index <= upper_diag_index,
              errors::InvalidArgument(
                  "lower_diag_index must not be larger than upper_diag_index: ",
                  lower_diag_index, " > ", upper_diag_index));
}
115 
// Kernel to set matrix diagonals.
//
// Writes the contents of `diag` onto the band [lower_diag_index,
// upper_diag_index] of `input` and returns the result. When num_diags > 1,
// `diag` holds one row per diagonal along dimension (diag_rank - 2), ordered
// from the uppermost diagonal down (the slice index used below is
// upper_diag_index - diag_index); each row is pre-padded to max_diag_len
// according to the alignment flags.
xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag,
                         const TensorShape& input_shape, const int64 diag_rank,
                         const int64 num_diags, const int64 lower_diag_index,
                         const int64 upper_diag_index, const int64 max_diag_len,
                         const int64 num_rows, const int64 num_cols,
                         const bool left_align_superdiagonal,
                         const bool left_align_subdiagonal) {
  // Creates a padding config.
  // Note: the config is reused across loop iterations; the padded dimension's
  // low/high values are overwritten on every path that uses it.
  const int input_rank = input_shape.dims();
  xla::PaddingConfig padding_config;
  padding_config = xla::MakeNoPaddingConfig(input_rank - 1);

  // Processes one diagonal at a time:
  // 1) Extracts a single diagonal (diag_slice).
  // 2) Broadcasts its contents to fill the whole matrix (diag_broadcast).
  // 3) Masks diag_broadcast to get the right diagonal shape.
  //
  // XLA can fuse multiple Broadcasts and Selects so this shouldn't be slow.
  //
  // For example,
  //   diag = [[0, 2, 3], k = (-1, 1), num_cols = 4, and align="RIGHT_LEFT".
  //           [4, 5, 6],
  //           [7, 8, 9]]
  // The expected output is [[7, 4, 2, 0],
  //                         [0, 8, 5, 3],
  //                         [0, 0, 9, 6]].
  // The 1st diagonal is created by:
  // 1) Extracting diag_slice = [0, 2, 3] which is right-aligned.
  // 2) Padding the vector (in the same direction) to be as long as num_cols,
  //      diag_slice = [0, 0, 2, 3],
  //    then broadcasting diag_slice column-wise to a full matrix,
  //      diag_broadcast = [[0, 0, 2, 3],
  //                        [0, 0, 2, 3],
  //                        [0, 0, 2, 3]].
  //    The padding value can be anything because it will not appear in the
  //    results after masking. Here, we use zero.
  // 3) Masking diag_broadcast with a mask of the shape of the 1st diagonal.
  //      mask = [[0, 0, 1, 0],  -->  output = [[x, x, 2, x],
  //              [0, 0, 0, 1],                 [x, x, x, 3],
  //              [0, 0, 0, 0]]                 [x, x, x, x]],
  //    where x denotes the existing input contents.
  std::vector<int64> broadcast_dimensions(input_rank - 1);
  absl::c_iota(broadcast_dimensions, 0);
  auto output = input;
  for (int64 diag_index = lower_diag_index; diag_index <= upper_diag_index;
       ++diag_index) {
    // Extracts a single diagonal.
    auto diag_slice = diag;
    if (num_diags > 1) {
      // The result of SliceInDim has dims: [<batch_dim>, 1, max_diag_len].
      // We call Collapse to make the dims: [<batch_dim>, max_diag_len].
      const int64 mapped_diag_index = upper_diag_index - diag_index;
      diag_slice = xla::Collapse(
          xla::SliceInDim(diag, mapped_diag_index, mapped_diag_index + 1, 1,
                          diag_rank - 2),
          {diag_rank - 2, diag_rank - 1});
    }

    // Pad if necessary.
    // - If the diagonal has the longest length, i.e., min(num_rows, num_cols),
    //   no padding is necessary. It is broadcast column-wise if it is a sub-
    //   diagonal, row-wise if superdiagonal.
    // - Otherwise, pad and keep the old alignment (shorter diagonals in the
    //   input come pre-padded). max_len in the table refers to max_diag_len.
    //   -------------------------------------------------------------------
    //   | Diag  | Align | Broadcast   |   padding_low   |   padding_high  |
    //   -------------------------------------------------------------------
    //   | Super | Left  | Row-wise    |        0        | #rows - max_len |
    //   |       | Right | Column-wise | #cols - max_len |        0        |
    //   -------------------------------------------------------------------
    //   | Sub   | Left  | Column-wise |        0        | #cols - max_len |
    //   |       | Right | Row-wise    | #rows - max_len |        0        |
    //   -------------------------------------------------------------------
    if (num_cols - num_rows <= diag_index && diag_index <= 0) {
      broadcast_dimensions.back() = input_rank - 1;  // Column-wise.
    } else if (0 <= diag_index && diag_index <= num_cols - num_rows) {
      broadcast_dimensions.back() = input_rank - 2;  // Row-wise.
    } else {
      int length_to_pad_to;
      if ((diag_index > 0 && left_align_superdiagonal) ||
          (diag_index < 0 && !left_align_subdiagonal)) {
        length_to_pad_to = num_rows;
        broadcast_dimensions.back() = input_rank - 2;  // Row-wise.
      } else {
        length_to_pad_to = num_cols;
        broadcast_dimensions.back() = input_rank - 1;  // Column-wise.
      }
      int padding_low = length_to_pad_to - max_diag_len;
      int padding_high = 0;
      if (IsLeftAligned(diag_index, left_align_superdiagonal,
                        left_align_subdiagonal)) {
        // Left-aligned diagonals are padded on the high side instead.
        std::swap(padding_low, padding_high);
      }
      padding_config.mutable_dimensions(input_rank - 2)
          ->set_edge_padding_low(padding_low);
      padding_config.mutable_dimensions(input_rank - 2)
          ->set_edge_padding_high(padding_high);

      // The pad value is arbitrary (masked out below); zero is used.
      const xla::XlaOp zero = xla::ScalarLike(input, 0);
      diag_slice = xla::Pad(diag_slice, zero, padding_config);
    }

    // Broadcast and mask.
    xla::XlaOp diag_broadcast = xla::BroadcastInDim(
        diag_slice, input_shape.dim_sizes(), broadcast_dimensions);
    const auto mask = xla::GetDiagonalMask(output, diag_index);
    output = xla::Select(mask, diag_broadcast, output);
  }
  return output;
}
227 
228 }  // namespace
229 
// XLA kernel for MatrixDiag / MatrixDiagV2 / MatrixDiagV3: builds a matrix
// (or batch of matrices) whose selected diagonals are taken from the `diag`
// input and whose remaining entries are `padding_value`.
class MatrixDiagOp : public XlaOpKernel {
 public:
  explicit MatrixDiagOp(OpKernelConstruction* context) : XlaOpKernel(context) {
    // MatrixDiagV3-specific.
    if (context->HasAttr("align")) {
      ReadAlignment(context, &left_align_superdiagonal_,
                    &left_align_subdiagonal_);
    }
  }

  void Compile(XlaOpKernelContext* context) override {
    OP_REQUIRES(
        context, context->num_inputs() >= kNumV1Inputs,
        errors::InvalidArgument("MatrixDiag op must have at least one input"));
    const TensorShape diag_shape = context->InputShape(0);
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape),
                errors::InvalidArgument("Expected >= 1 dims, got shape ",
                                        diag_shape.DebugString()));

    const DataType dtype = context->expected_output_dtype(0);
    const xla::XlaOp zero = XlaHelpers::Zero(context->builder(), dtype);

    // Initializes MatrixDiagV2-specific variables.
    // Input arguments providing the values of num_rows and num_cols can be
    // absent (-1) and will be inferred later.
    int64 lower_diag_index = 0;
    int64 upper_diag_index = 0;
    int64 num_rows = -1;
    int64 num_cols = -1;
    xla::XlaOp padding_value = zero;

    // MatrixDiag and MatrixDiagV2 both use this OpKernel. MatrixDiag only has
    // one input, so we have to check the number of inputs before reading
    // additional parameters for MatrixDiagV2.
    if (context->num_inputs() > kNumV1Inputs) {
      std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
      OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &num_rows));
      OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(3, &num_cols));
      padding_value = context->Input(4);
    }

    // More size validations.
    const int64 diag_rank = diag_shape.dims();
    const int64 max_diag_len = diag_shape.dim_size(diag_rank - 1);
    const int64 num_diags = upper_diag_index - lower_diag_index + 1;
    OP_REQUIRES(
        context,
        num_diags == 1 || num_diags == diag_shape.dim_size(diag_rank - 2),
        errors::InvalidArgument(
            "The number of diagonals provided in the input does not "
            "match the lower_diag_index and upper_diag_index range."));
    // Smallest matrix that can hold every requested diagonal at full
    // max_diag_len.
    const int64 min_num_rows =
        max_diag_len - std::min(upper_diag_index, int64{0});
    const int64 min_num_cols =
        max_diag_len + std::max(lower_diag_index, int64{0});
    OP_REQUIRES(context, num_rows == -1 || num_rows >= min_num_rows,
                errors::InvalidArgument("The number of rows is too small."));
    OP_REQUIRES(context, num_cols == -1 || num_cols >= min_num_cols,
                errors::InvalidArgument("The number of columns is too small."));

    // Infers num_rows and num_cols. If both are unknown, assume that the output
    // is square. Otherwise, use smallest possible values.
    if (num_rows == -1 && num_cols == -1) {
      num_rows = std::max(min_num_rows, min_num_cols);
      num_cols = num_rows;
    } else if (num_rows == -1) {
      num_rows = min_num_rows;
    } else if (num_cols == -1) {
      num_cols = min_num_cols;
    }

    // At least one of num_rows and num_cols must match its minimum length.
    // Otherwise, we'll have some incomplete diagonals.
    OP_REQUIRES(context, num_rows == min_num_rows || num_cols == min_num_cols,
                errors::InvalidArgument(
                    "The number of rows or columns is not consistent with "
                    "the specified d_lower, d_upper, and diagonal."));

    // Actual processing.
    // Initializes the output tensor with padding_value, then writes the
    // requested diagonals on top of it.
    TensorShape output_shape = diag_shape;
    output_shape.RemoveLastDims((num_diags == 1) ? 1 : 2);
    output_shape.AddDim(num_rows);
    output_shape.AddDim(num_cols);
    xla::XlaOp output = xla::Broadcast(padding_value, output_shape.dim_sizes());
    xla::XlaOp diag = context->Input(0);
    context->SetOutput(
        0, SetMatrixDiag(output, diag, output_shape, diag_rank, num_diags,
                         lower_diag_index, upper_diag_index, max_diag_len,
                         num_rows, num_cols, left_align_superdiagonal_,
                         left_align_subdiagonal_));
  }

 private:
  bool left_align_superdiagonal_ = true;
  bool left_align_subdiagonal_ = true;
  // MatrixDiag (V1) has a single input; more inputs imply V2/V3.
  static constexpr int kNumV1Inputs = 1;
};
328 
REGISTER_XLA_OP(Name("MatrixDiag"), MatrixDiagOp);
// V2/V3 read "k", "num_rows", "num_cols", and "padding_value" at compile
// time, so those inputs must be compile-time constants.
REGISTER_XLA_OP(Name("MatrixDiagV2")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("num_rows")
                    .CompileTimeConstantInput("num_cols")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagOp);
REGISTER_XLA_OP(Name("MatrixDiagV3")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("num_rows")
                    .CompileTimeConstantInput("num_cols")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagOp);
342 
// XLA kernel for MatrixDiagPart / V2 / V3: extracts the band
// [lower_diag_index, upper_diag_index] from a matrix (or batch of matrices),
// padding shorter diagonals with `padding_value` per the alignment flags.
class MatrixDiagPartOp : public XlaOpKernel {
 public:
  explicit MatrixDiagPartOp(OpKernelConstruction* context)
      : XlaOpKernel(context),
        is_gpu_(context->device_type().type_string() == DEVICE_GPU_XLA_JIT) {
    // MatrixDiagPartV3-specific.
    if (context->HasAttr("align")) {
      ReadAlignment(context, &left_align_superdiagonal_,
                    &left_align_subdiagonal_);
    }
  }

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    const int input_rank = input_shape.dims();

    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input_shape.DebugString()));

    const DataType dtype = context->expected_output_dtype(0);
    const xla::XlaOp zero = XlaHelpers::Zero(context->builder(), dtype);

    // Initializes MatrixDiagPartV2-specific variables.
    int64 lower_diag_index = 0;
    int64 upper_diag_index = 0;
    xla::XlaOp padding_value = zero;

    // MatrixDiagPart and MatrixDiagPartV2 both use this OpKernel.
    // MatrixDiagPart only has one input, so we have to check the number of
    // inputs before reading additional parameters in MatrixDiagV2.
    if (context->num_inputs() > kNumV1Inputs) {
      std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
      padding_value = context->Input(2);
    }

    // Checks if diag sizes are consistent with input.
    const int64 num_rows = input_shape.dim_size(input_rank - 2);
    const int64 num_cols = input_shape.dim_size(input_rank - 1);
    ValidateDiagIndexWithOutputMatrixSize(context, lower_diag_index,
                                          upper_diag_index, num_rows, num_cols);

    // Creates output shape: [<batch dims>, num_diags (if > 1), max_diag_len].
    TensorShape output_shape = input_shape;
    output_shape.RemoveLastDims(2);
    const int num_diags = upper_diag_index - lower_diag_index + 1;
    if (num_diags > 1) output_shape.AddDim(num_diags);
    const int32 max_diag_len =
        std::min(num_rows + std::min(upper_diag_index, int64{0}),
                 num_cols - std::max(lower_diag_index, int64{0}));
    output_shape.AddDim(max_diag_len);

    // Computes output.
    // The gather-based extraction is used on GPU; see xla/client/lib/matrix
    // for the two GetMatrixDiagonal* variants.
    xla::XlaOp input = context->Input(0);
    std::vector<xla::XlaOp> diag_list;
    xla::PaddingConfig padding_config =
        xla::MakeNoPaddingConfig(input_rank - 1);
    if (num_diags == 1) {
      // Single diagonal: no padding or concatenation needed.
      context->SetOutput(
          0, is_gpu_ ? xla::GetMatrixDiagonalViaGather(input, upper_diag_index)
                     : xla::GetMatrixDiagonal(input, upper_diag_index));
      return;
    }
    // Multiple diagonals: extract from uppermost down, pad each to
    // max_diag_len on the side dictated by the alignment, then concatenate.
    for (int diag_index = upper_diag_index; diag_index >= lower_diag_index;
         --diag_index) {
      xla::XlaOp single_diag =
          is_gpu_ ? xla::GetMatrixDiagonalViaGather(input, diag_index)
                  : xla::GetMatrixDiagonal(input, diag_index);
      const int64 diag_len = ComputeDiagLen(diag_index, num_rows, num_cols);
      const int64 padding_len = max_diag_len - diag_len;
      if (padding_len > 0) {
        if (IsLeftAligned(diag_index, left_align_superdiagonal_,
                          left_align_subdiagonal_)) {
          padding_config.mutable_dimensions(input_rank - 2)
              ->set_edge_padding_low(0);
          padding_config.mutable_dimensions(input_rank - 2)
              ->set_edge_padding_high(padding_len);
        } else {
          padding_config.mutable_dimensions(input_rank - 2)
              ->set_edge_padding_low(padding_len);
          padding_config.mutable_dimensions(input_rank - 2)
              ->set_edge_padding_high(0);
        }
        single_diag = xla::Pad(single_diag, padding_value, padding_config);
      }
      diag_list.emplace_back(single_diag);
    }
    auto concat =
        xla::ConcatInDim(context->builder(), diag_list, input_rank - 2);
    context->SetOutput(0, xla::Reshape(concat, output_shape.dim_sizes()));
  }

 private:
  const bool is_gpu_;
  bool left_align_superdiagonal_ = true;
  bool left_align_subdiagonal_ = true;
  // MatrixDiagPart (V1) has a single input; more inputs imply V2/V3.
  static constexpr int kNumV1Inputs = 1;
};
442 
REGISTER_XLA_OP(Name("MatrixDiagPart"), MatrixDiagPartOp);
// V2/V3 read "k" and "padding_value" at compile time, so those inputs must be
// compile-time constants.
REGISTER_XLA_OP(Name("MatrixDiagPartV2")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagPartOp);
REGISTER_XLA_OP(Name("MatrixDiagPartV3")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagPartOp);
452 
// XLA kernel for MatrixSetDiag / V2 / V3: returns a copy of the input matrix
// (or batch of matrices) with the band [lower_diag_index, upper_diag_index]
// replaced by the values in the `diagonal` input.
class MatrixSetDiagOp : public XlaOpKernel {
 public:
  explicit MatrixSetDiagOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {
    // MatrixSetDiagV3-specific.
    if (context->HasAttr("align")) {
      ReadAlignment(context, &left_align_superdiagonal_,
                    &left_align_subdiagonal_);
    }
  }

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    const TensorShape diag_shape = context->InputShape(1);
    const int input_rank = input_shape.dims();
    const int diag_rank = diag_shape.dims();

    // Preliminary validation of sizes.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input_shape.DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape),
                errors::InvalidArgument(
                    "diagonal must be at least 1-dim, received shape: ",
                    diag_shape.DebugString()));

    // MatrixSetDiag and MatrixSetDiagV2 both use this OpKernel. MatrixSetDiag
    // only has two inputs, so we have to check the number of inputs before
    // reading additional parameters in MatrixSetDiagV2.
    int64 lower_diag_index = 0;
    int64 upper_diag_index = 0;
    if (context->num_inputs() > kNumV1Inputs) {
      std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
    }

    // Checks if diag sizes are consistent with input.
    const int64 num_rows = input_shape.dim_size(input_rank - 2);
    const int64 num_cols = input_shape.dim_size(input_rank - 1);
    ValidateDiagIndexWithOutputMatrixSize(context, lower_diag_index,
                                          upper_diag_index, num_rows, num_cols);
    const Eigen::Index num_diags = upper_diag_index - lower_diag_index + 1;
    OP_REQUIRES(
        context,
        lower_diag_index == upper_diag_index ||
            (diag_shape.dim_size(input_rank - 2) == num_diags),
        errors::InvalidArgument("The number of diagonals provided in `diag` "
                                "is not consistent with `lower_diag_index` and "
                                "`upper_diag_index`"));

    // The diagonal input must be exactly
    // [<batch dims>, num_diags (if > 1), max_diag_len].
    TensorShape expected_diag_shape = input_shape;
    expected_diag_shape.RemoveLastDims(2);
    if (num_diags > 1) expected_diag_shape.AddDim(num_diags);
    const int32 max_diag_len =
        std::min(num_rows + std::min(upper_diag_index, int64{0}),
                 num_cols - std::max(lower_diag_index, int64{0}));
    expected_diag_shape.AddDim(max_diag_len);
    OP_REQUIRES(
        context, expected_diag_shape == diag_shape,
        errors::InvalidArgument(
            "Either first dimensions of diagonal don't match input.shape[:-2], "
            "or diagonal.shape[:-1] is not equal to the longests diagonal in "
            "range [lower_diag_index:upper_diag_index].\nInput shape: ",
            input_shape.DebugString(),
            "\nDiagonal shape: ", diag_shape.DebugString(),
            "\nExpected diagonal shape: ", expected_diag_shape.DebugString()));

    // Actual processing.
    xla::XlaOp input = context->Input(0);
    xla::XlaOp diag = context->Input(1);
    context->SetOutput(
        0, SetMatrixDiag(input, diag, input_shape, diag_rank, num_diags,
                         lower_diag_index, upper_diag_index, max_diag_len,
                         num_rows, num_cols, left_align_superdiagonal_,
                         left_align_subdiagonal_));
  }

 private:
  bool left_align_superdiagonal_ = true;
  bool left_align_subdiagonal_ = true;
  // MatrixSetDiag (V1) has two inputs (input, diagonal); a third implies
  // V2/V3 with the "k" input.
  static constexpr int kNumV1Inputs = 2;
  TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp);
};
536 
REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp);
// V2/V3 read "k" at compile time, so it must be a compile-time constant.
REGISTER_XLA_OP(Name("MatrixSetDiagV2").CompileTimeConstantInput("k"),
                MatrixSetDiagOp);
REGISTER_XLA_OP(Name("MatrixSetDiagV3").CompileTimeConstantInput("k"),
                MatrixSetDiagOp);
542 
543 }  // namespace tensorflow
544