/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace {

// Computes the length of the diagonal at diag_index in a matrix with
// num_rows rows and num_cols columns.
static inline int ComputeDiagLen(int diag_index, int num_rows, int num_cols) {
  return std::min(num_rows + std::min(0, diag_index),
                  num_cols - std::max(0, diag_index));
}
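// For example, ComputeDiagLen(0, 3, 4) == 3 (the main diagonal of a 3x4
// matrix), ComputeDiagLen(1, 3, 4) == min(3, 3) == 3, and
// ComputeDiagLen(-2, 3, 4) == min(1, 4) == 1.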

// Checks if a diagonal is to be aligned left or right.
static inline bool IsLeftAligned(int diag_index, bool left_align_superdiagonal,
                                 bool left_align_subdiagonal) {
  return (diag_index >= 0 && left_align_superdiagonal) ||
         (diag_index <= 0 && left_align_subdiagonal);
}
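// For example, IsLeftAligned(1, false, true) is false and
// IsLeftAligned(-1, false, true) is true (the "RIGHT_LEFT" packing); the main
// diagonal (diag_index == 0) counts as left-aligned whenever either flag is
// set.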

// Reads the diagonal packing alignment.
void ReadAlignment(OpKernelConstruction* context,
                   bool* left_align_superdiagonal,
                   bool* left_align_subdiagonal) {
  string align;
  OP_REQUIRES_OK(context, context->GetAttr("align", &align));

  *left_align_superdiagonal = align == "LEFT_LEFT" || align == "LEFT_RIGHT";
  *left_align_subdiagonal = align == "LEFT_LEFT" || align == "RIGHT_LEFT";
}
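// ReadAlignment maps the four attribute values as follows: "LEFT_LEFT"
// left-aligns both superdiagonals and subdiagonals, "RIGHT_RIGHT" right-aligns
// both, "LEFT_RIGHT" left-aligns only superdiagonals, and "RIGHT_LEFT"
// left-aligns only subdiagonals.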

// Reads or infers lower_diag_index and upper_diag_index from the kernel's
// input parameter "k". Also validates their values.
std::pair<int64_t, int64_t> ProcessDiagIndex(XlaOpKernelContext* context) {
  int64_t lower_diag_index = 0;
  int64_t upper_diag_index = 0;
  TensorShape diag_index_shape = context->InputShape("k");

  // Wrapping OP_REQUIRES* macros with a lambda because they can "return;"
  // early (without values), which contradicts ProcessDiagIndex's signature.
  auto validate_diag_indices = [&]() {
    if (diag_index_shape.dims() == 0) {
      OP_REQUIRES_OK(context,
                     context->ConstantInputAsIntScalar("k", &lower_diag_index));
      upper_diag_index = lower_diag_index;
    } else {
      std::vector<int64_t> diag_index;
      OP_REQUIRES_OK(context,
                     context->ConstantInputAsIntVector("k", &diag_index));
      OP_REQUIRES(
          context, !diag_index.empty() && diag_index.size() <= 2,
          errors::InvalidArgument(
              "diag_index must have only one or two elements, received ",
              diag_index.size(), " elements."));
      lower_diag_index = diag_index[0];
      upper_diag_index =
          (diag_index.size() > 1) ? diag_index[1] : lower_diag_index;
    }
    OP_REQUIRES(
        context, lower_diag_index <= upper_diag_index,
        errors::InvalidArgument(
            "lower_diag_index must not be larger than upper_diag_index: ",
            lower_diag_index, " > ", upper_diag_index));
  };
  validate_diag_indices();
  return {lower_diag_index, upper_diag_index};
}
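// For example, ProcessDiagIndex returns {1, 1} for a scalar k = 1 (just the
// first superdiagonal) and {-1, 2} for k = [-1, 2] (the band from the first
// subdiagonal up to the second superdiagonal).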

// Makes sure lower_diag_index and upper_diag_index are consistent with the
// input matrix size.
void ValidateDiagIndexWithOutputMatrixSize(XlaOpKernelContext* context,
                                           const int64_t lower_diag_index,
                                           const int64_t upper_diag_index,
                                           const int64_t num_rows,
                                           const int64_t num_cols) {
  // The `lower_diag_index == 0` condition is added to handle matrices with a
  // zero-sized dimension.
  OP_REQUIRES(context,
              (-num_rows < lower_diag_index && lower_diag_index < num_cols) ||
                  lower_diag_index == 0,
              errors::InvalidArgument(
                  "lower_diag_index is out of bounds: ", lower_diag_index,
                  " It must be between ", -num_rows, " and ", num_cols));
  OP_REQUIRES(context,
              (-num_rows < upper_diag_index && upper_diag_index < num_cols) ||
                  upper_diag_index == 0,
              errors::InvalidArgument(
                  "upper_diag_index is out of bounds: ", upper_diag_index,
                  " It must be between ", -num_rows, " and ", num_cols));
  OP_REQUIRES(context, lower_diag_index <= upper_diag_index,
              errors::InvalidArgument(
                  "lower_diag_index must not be larger than upper_diag_index: ",
                  lower_diag_index, " > ", upper_diag_index));
}
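// For a 3x4 matrix, for instance, ValidateDiagIndexWithOutputMatrixSize
// accepts diagonal indices -2 through 3 (and always index 0, which covers
// empty matrices).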

// Kernel to set matrix diagonals.
xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag,
                         const TensorShape& input_shape,
                         const int64_t diag_rank, const int64_t num_diags,
                         const int64_t lower_diag_index,
                         const int64_t upper_diag_index,
                         const int64_t max_diag_len, const int64_t num_rows,
                         const int64_t num_cols,
                         const bool left_align_superdiagonal,
                         const bool left_align_subdiagonal) {
  // Creates a padding config.
  const int input_rank = input_shape.dims();
  xla::PaddingConfig padding_config;
  padding_config = xla::MakeNoPaddingConfig(input_rank - 1);

  // Processes one diagonal at a time:
  // 1) Extracts a single diagonal (diag_slice).
  // 2) Broadcasts its contents to fill the whole matrix (diag_broadcast).
  // 3) Masks diag_broadcast to get the right diagonal shape.
  //
  // XLA can fuse multiple Broadcasts and Selects so this shouldn't be slow.
  //
  // For example,
  //   diag = [[0, 2, 3],  k = (0, 2), num_cols = 4, and align="RIGHT_LEFT".
  //           [4, 5, 6],
  //           [7, 8, 9]]
  // The expected output is [[7, 4, 2, 0],
  //                         [0, 8, 5, 3],
  //                         [0, 0, 9, 6]].
  // The 1st diagonal is created by:
  // 1) Extracting diag_slice = [0, 2, 3], which is right-aligned.
  // 2) Padding the vector (in the same direction) to be as long as num_cols,
  //      diag_slice = [0, 0, 2, 3],
  //    then broadcasting diag_slice column-wise to a full matrix,
  //      diag_broadcast = [[0, 0, 2, 3],
  //                        [0, 0, 2, 3],
  //                        [0, 0, 2, 3]].
  //    The padding value can be anything because it will not appear in the
  //    results after masking. Here, we use zero.
  // 3) Masking diag_broadcast with a mask of the shape of the 1st diagonal.
  //      mask = [[0, 0, 1, 0],  -->  output = [[x, x, 2, x],
  //              [0, 0, 0, 1],                 [x, x, x, 3],
  //              [0, 0, 0, 0]]                 [x, x, x, x]],
  //    where x denotes the existing input contents.
  std::vector<int64_t> broadcast_dimensions(input_rank - 1);
  absl::c_iota(broadcast_dimensions, 0);
  auto output = input;
  for (int64_t diag_index = lower_diag_index; diag_index <= upper_diag_index;
       ++diag_index) {
    // Extracts a single diagonal.
    auto diag_slice = diag;
    if (num_diags > 1) {
      // The result of SliceInDim has dims: [<batch_dim>, 1, max_diag_len].
      // We call Collapse to make the dims: [<batch_dim>, max_diag_len].
      const int64_t mapped_diag_index = upper_diag_index - diag_index;
      diag_slice = xla::Collapse(
          xla::SliceInDim(diag, mapped_diag_index, mapped_diag_index + 1, 1,
                          diag_rank - 2),
          {diag_rank - 2, diag_rank - 1});
    }

    // Pad if necessary.
    // - If the diagonal has the longest length, i.e., min(num_rows, num_cols),
    //   no padding is necessary. It is broadcast column-wise if it is a
    //   subdiagonal, row-wise if it is a superdiagonal.
    // - Otherwise, pad and keep the old alignment (shorter diagonals in the
    //   input come pre-padded). max_len in the table refers to max_diag_len.
    //   -------------------------------------------------------------------
    //   | Diag  | Align | Broadcast   |   padding_low   |   padding_high  |
    //   -------------------------------------------------------------------
    //   | Super | Left  | Row-wise    |        0        | #rows - max_len |
    //   |       | Right | Column-wise | #cols - max_len |        0        |
    //   -------------------------------------------------------------------
    //   | Sub   | Left  | Column-wise |        0        | #cols - max_len |
    //   |       | Right | Row-wise    | #rows - max_len |        0        |
    //   -------------------------------------------------------------------
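    //   In the example above, the k = 2 diagonal is a right-aligned
    //   superdiagonal, so the second table row applies: it is padded on the
    //   low side by #cols - max_len = 4 - 3 = 1 and broadcast column-wise.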
    if (num_cols - num_rows <= diag_index && diag_index <= 0) {
      broadcast_dimensions.back() = input_rank - 1;  // Column-wise.
    } else if (0 <= diag_index && diag_index <= num_cols - num_rows) {
      broadcast_dimensions.back() = input_rank - 2;  // Row-wise.
    } else {
      int length_to_pad_to;
      if ((diag_index > 0 && left_align_superdiagonal) ||
          (diag_index < 0 && !left_align_subdiagonal)) {
        length_to_pad_to = num_rows;
        broadcast_dimensions.back() = input_rank - 2;  // Row-wise.
      } else {
        length_to_pad_to = num_cols;
        broadcast_dimensions.back() = input_rank - 1;  // Column-wise.
      }
      int padding_low = length_to_pad_to - max_diag_len;
      int padding_high = 0;
      if (IsLeftAligned(diag_index, left_align_superdiagonal,
                        left_align_subdiagonal)) {
        std::swap(padding_low, padding_high);
      }
      padding_config.mutable_dimensions(input_rank - 2)
          ->set_edge_padding_low(padding_low);
      padding_config.mutable_dimensions(input_rank - 2)
          ->set_edge_padding_high(padding_high);

      const xla::XlaOp zero = xla::ScalarLike(input, 0);
      diag_slice = xla::Pad(diag_slice, zero, padding_config);
    }

    // Broadcast and mask.
    xla::XlaOp diag_broadcast = xla::BroadcastInDim(
        diag_slice, input_shape.dim_sizes(), broadcast_dimensions);
    const auto mask = xla::GetDiagonalMask(output, diag_index);
    output = xla::Select(mask, diag_broadcast, output);
  }
  return output;
}

}  // namespace

class MatrixDiagOp : public XlaOpKernel {
 public:
  explicit MatrixDiagOp(OpKernelConstruction* context) : XlaOpKernel(context) {
    // MatrixDiagV3-specific.
    if (context->HasAttr("align")) {
      ReadAlignment(context, &left_align_superdiagonal_,
                    &left_align_subdiagonal_);
    }
  }

  void Compile(XlaOpKernelContext* context) override {
    OP_REQUIRES(
        context, context->num_inputs() >= kNumV1Inputs,
        errors::InvalidArgument("MatrixDiag op must have at least one input"));
    const TensorShape diag_shape = context->InputShape(0);
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape),
                errors::InvalidArgument(
                    "diagonal must be at least 1-dim, received shape: ",
                    diag_shape.DebugString()));

    const DataType dtype = context->expected_output_dtype(0);
    const xla::XlaOp zero = XlaHelpers::Zero(context->builder(), dtype);

    // Initializes MatrixDiagV2-specific variables.
    // Input arguments providing the values of num_rows and num_cols can be
    // absent (-1) and will be inferred later.
    int64_t lower_diag_index = 0;
    int64_t upper_diag_index = 0;
    int64_t num_rows = -1;
    int64_t num_cols = -1;
    xla::XlaOp padding_value = zero;

    // MatrixDiag and MatrixDiagV2 both use this OpKernel. MatrixDiag only has
    // one input, so we have to check the number of inputs before reading
    // additional parameters for MatrixDiagV2.
    if (context->num_inputs() > kNumV1Inputs) {
      std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
      OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &num_rows));
      OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(3, &num_cols));
      padding_value = context->Input(4);
    }

    // More size validations.
    const int64_t diag_rank = diag_shape.dims();
    const int64_t max_diag_len = diag_shape.dim_size(diag_rank - 1);
    const int64_t num_diags = upper_diag_index - lower_diag_index + 1;
    OP_REQUIRES(
        context,
        num_diags == 1 || num_diags == diag_shape.dim_size(diag_rank - 2),
        errors::InvalidArgument(
            "The number of diagonals provided in the input does not "
            "match the lower_diag_index and upper_diag_index range."));
    const int64_t min_num_rows =
        max_diag_len - std::min(upper_diag_index, int64_t{0});
    const int64_t min_num_cols =
        max_diag_len + std::max(lower_diag_index, int64_t{0});
    OP_REQUIRES(context, num_rows == -1 || num_rows >= min_num_rows,
                errors::InvalidArgument("The number of rows is too small."));
    OP_REQUIRES(context, num_cols == -1 || num_cols >= min_num_cols,
                errors::InvalidArgument("The number of columns is too small."));

    // Infers num_rows and num_cols. If both are unknown, assume that the
    // output is square. Otherwise, use the smallest possible values.
    if (num_rows == -1 && num_cols == -1) {
      num_rows = std::max(min_num_rows, min_num_cols);
      num_cols = num_rows;
    } else if (num_rows == -1) {
      num_rows = min_num_rows;
    } else if (num_cols == -1) {
      num_cols = min_num_cols;
    }
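    // For example, if diag is a length-3 vector and k = 1, then
    // min_num_rows = 3 and min_num_cols = 4; with both sizes left
    // unspecified, the output is inferred to be a square 4x4 matrix.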

    // At least one of num_rows and num_cols must match its minimum length.
    // Otherwise, we'll have some incomplete diagonals.
    OP_REQUIRES(context, num_rows == min_num_rows || num_cols == min_num_cols,
                errors::InvalidArgument(
                    "The number of rows or columns is not consistent with "
                    "the specified d_lower, d_upper, and diagonal."));

    // Actual processing.
    // Initializes the output tensor with padding_value.
    TensorShape output_shape = diag_shape;
    output_shape.RemoveLastDims((num_diags == 1) ? 1 : 2);
    output_shape.AddDim(num_rows);
    output_shape.AddDim(num_cols);
    xla::XlaOp output = xla::Broadcast(padding_value, output_shape.dim_sizes());
    xla::XlaOp diag = context->Input(0);
    context->SetOutput(
        0, SetMatrixDiag(output, diag, output_shape, diag_rank, num_diags,
                         lower_diag_index, upper_diag_index, max_diag_len,
                         num_rows, num_cols, left_align_superdiagonal_,
                         left_align_subdiagonal_));
  }

 private:
  bool left_align_superdiagonal_ = true;
  bool left_align_subdiagonal_ = true;
  static constexpr int kNumV1Inputs = 1;
};

REGISTER_XLA_OP(Name("MatrixDiag"), MlirXlaOpKernel);
REGISTER_XLA_OP(Name("MatrixDiagV2")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("num_rows")
                    .CompileTimeConstantInput("num_cols")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagOp);
REGISTER_XLA_OP(Name("MatrixDiagV3")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("num_rows")
                    .CompileTimeConstantInput("num_cols")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagOp);

class MatrixDiagPartOp : public XlaOpKernel {
 public:
  explicit MatrixDiagPartOp(OpKernelConstruction* context)
      : XlaOpKernel(context),
        is_gpu_(context->device_type().type_string() == DEVICE_GPU_XLA_JIT) {
    // MatrixDiagPartV3-specific.
    if (context->HasAttr("align")) {
      ReadAlignment(context, &left_align_superdiagonal_,
                    &left_align_subdiagonal_);
    }
  }

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    const int input_rank = input_shape.dims();

    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input_shape.DebugString()));

    const DataType dtype = context->expected_output_dtype(0);
    const xla::XlaOp zero = XlaHelpers::Zero(context->builder(), dtype);

    // Initializes MatrixDiagPartV2-specific variables.
    int64_t lower_diag_index = 0;
    int64_t upper_diag_index = 0;
    xla::XlaOp padding_value = zero;

    // MatrixDiagPart and MatrixDiagPartV2 both use this OpKernel.
    // MatrixDiagPart only has one input, so we have to check the number of
    // inputs before reading additional parameters for MatrixDiagPartV2.
    if (context->num_inputs() > kNumV1Inputs) {
      std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
      padding_value = context->Input(2);
    }

    // Checks if diag sizes are consistent with input.
    const int64_t num_rows = input_shape.dim_size(input_rank - 2);
    const int64_t num_cols = input_shape.dim_size(input_rank - 1);
    ValidateDiagIndexWithOutputMatrixSize(context, lower_diag_index,
                                          upper_diag_index, num_rows, num_cols);

    // Creates output shape.
    TensorShape output_shape = input_shape;
    output_shape.RemoveLastDims(2);
    const int num_diags = upper_diag_index - lower_diag_index + 1;
    if (num_diags > 1) output_shape.AddDim(num_diags);
    const int32_t max_diag_len =
        std::min(num_rows + std::min(upper_diag_index, int64_t{0}),
                 num_cols - std::max(lower_diag_index, int64_t{0}));
    output_shape.AddDim(max_diag_len);
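    // For example, for a [2, 3, 4] input with k = (-1, 1), there are three
    // diagonals and max_diag_len = min(3 + min(1, 0), 4 - max(-1, 0)) = 3,
    // so the output shape is [2, 3, 3].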

    // Computes output.
    xla::XlaOp input = context->Input(0);
    std::vector<xla::XlaOp> diag_list;
    xla::PaddingConfig padding_config =
        xla::MakeNoPaddingConfig(input_rank - 1);
    if (num_diags == 1) {
      context->SetOutput(
          0, is_gpu_ ? xla::GetMatrixDiagonalViaGather(input, upper_diag_index)
                     : xla::GetMatrixDiagonal(input, upper_diag_index));
      return;
    }
    for (int diag_index = upper_diag_index; diag_index >= lower_diag_index;
         --diag_index) {
      xla::XlaOp single_diag =
          is_gpu_ ? xla::GetMatrixDiagonalViaGather(input, diag_index)
                  : xla::GetMatrixDiagonal(input, diag_index);
      const int64_t diag_len = ComputeDiagLen(diag_index, num_rows, num_cols);
      const int64_t padding_len = max_diag_len - diag_len;
      if (padding_len > 0) {
        if (IsLeftAligned(diag_index, left_align_superdiagonal_,
                          left_align_subdiagonal_)) {
          padding_config.mutable_dimensions(input_rank - 2)
              ->set_edge_padding_low(0);
          padding_config.mutable_dimensions(input_rank - 2)
              ->set_edge_padding_high(padding_len);
        } else {
          padding_config.mutable_dimensions(input_rank - 2)
              ->set_edge_padding_low(padding_len);
          padding_config.mutable_dimensions(input_rank - 2)
              ->set_edge_padding_high(0);
        }
        single_diag = xla::Pad(single_diag, padding_value, padding_config);
      }
      diag_list.emplace_back(single_diag);
    }
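    // Each padded diagonal has shape [..., max_diag_len]. Concatenating
    // num_diags of them along the last dimension yields
    // [..., num_diags * max_diag_len], which the Reshape below folds into
    // [..., num_diags, max_diag_len].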
    auto concat =
        xla::ConcatInDim(context->builder(), diag_list, input_rank - 2);
    context->SetOutput(0, xla::Reshape(concat, output_shape.dim_sizes()));
  }

 private:
  const bool is_gpu_;
  bool left_align_superdiagonal_ = true;
  bool left_align_subdiagonal_ = true;
  static constexpr int kNumV1Inputs = 1;
};

REGISTER_XLA_OP(Name("MatrixDiagPart"), MatrixDiagPartOp);
REGISTER_XLA_OP(Name("MatrixDiagPartV2")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagPartOp);
REGISTER_XLA_OP(Name("MatrixDiagPartV3")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagPartOp);

class MatrixSetDiagOp : public XlaOpKernel {
 public:
  explicit MatrixSetDiagOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {
    // MatrixSetDiagV3-specific.
    if (context->HasAttr("align")) {
      ReadAlignment(context, &left_align_superdiagonal_,
                    &left_align_subdiagonal_);
    }
  }

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    const TensorShape diag_shape = context->InputShape(1);
    const int input_rank = input_shape.dims();
    const int diag_rank = diag_shape.dims();

    // Preliminary validation of sizes.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input_shape.DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape),
                errors::InvalidArgument(
                    "diagonal must be at least 1-dim, received shape: ",
                    diag_shape.DebugString()));

    // MatrixSetDiag and MatrixSetDiagV2 both use this OpKernel. MatrixSetDiag
    // only has two inputs, so we have to check the number of inputs before
    // reading additional parameters in MatrixSetDiagV2.
    int64_t lower_diag_index = 0;
    int64_t upper_diag_index = 0;
    if (context->num_inputs() > kNumV1Inputs) {
      std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
    }

    // Checks if diag sizes are consistent with input.
    const int64_t num_rows = input_shape.dim_size(input_rank - 2);
    const int64_t num_cols = input_shape.dim_size(input_rank - 1);
    ValidateDiagIndexWithOutputMatrixSize(context, lower_diag_index,
                                          upper_diag_index, num_rows, num_cols);
    const Eigen::Index num_diags = upper_diag_index - lower_diag_index + 1;
    OP_REQUIRES(
        context,
        lower_diag_index == upper_diag_index ||
            (diag_shape.dim_size(input_rank - 2) == num_diags),
        errors::InvalidArgument("The number of diagonals provided in `diag` "
                                "is not consistent with `lower_diag_index` and "
                                "`upper_diag_index`"));

    TensorShape expected_diag_shape = input_shape;
    expected_diag_shape.RemoveLastDims(2);
    if (num_diags > 1) expected_diag_shape.AddDim(num_diags);
    const int32_t max_diag_len =
        std::min(num_rows + std::min(upper_diag_index, int64_t{0}),
                 num_cols - std::max(lower_diag_index, int64_t{0}));
    expected_diag_shape.AddDim(max_diag_len);
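    // For example, for a [2, 3, 4] input with k = (0, 1), expected_diag_shape
    // is [2, 2, 3]: two diagonals, each stored with max_diag_len = 3.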
    OP_REQUIRES(
        context, expected_diag_shape == diag_shape,
        errors::InvalidArgument(
            "Either the first dimensions of diagonal don't match "
            "input.shape[:-2], or diagonal.shape[:-1] is not equal to the "
            "longest diagonal in range [lower_diag_index:upper_diag_index]."
            "\nInput shape: ",
            input_shape.DebugString(),
            "\nDiagonal shape: ", diag_shape.DebugString(),
            "\nExpected diagonal shape: ", expected_diag_shape.DebugString()));

    // Actual processing.
    xla::XlaOp input = context->Input(0);
    xla::XlaOp diag = context->Input(1);
    context->SetOutput(
        0, SetMatrixDiag(input, diag, input_shape, diag_rank, num_diags,
                         lower_diag_index, upper_diag_index, max_diag_len,
                         num_rows, num_cols, left_align_superdiagonal_,
                         left_align_subdiagonal_));
  }

 private:
  bool left_align_superdiagonal_ = true;
  bool left_align_subdiagonal_ = true;
  static constexpr int kNumV1Inputs = 2;
  TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp);
};

REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp);
REGISTER_XLA_OP(Name("MatrixSetDiagV2").CompileTimeConstantInput("k"),
                MatrixSetDiagOp);
REGISTER_XLA_OP(Name("MatrixSetDiagV3").CompileTimeConstantInput("k"),
                MatrixSetDiagOp);

}  // namespace tensorflow