1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
17 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
18 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
19 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
20 #include "tensorflow/compiler/xla/client/lib/constants.h"
21 #include "tensorflow/compiler/xla/client/lib/matrix.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/primitive_util.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/lib/core/errors.h"
26
27 namespace tensorflow {
28 namespace {
29
// Returns the number of elements on diagonal `diag_index` of a
// num_rows x num_cols matrix (negative indices are subdiagonals).
static inline int ComputeDiagLen(int diag_index, int num_rows, int num_cols) {
  // A subdiagonal loses rows; a superdiagonal loses columns. The diagonal
  // length is whichever remaining extent is smaller.
  const int rows_avail = diag_index < 0 ? num_rows + diag_index : num_rows;
  const int cols_avail = diag_index > 0 ? num_cols - diag_index : num_cols;
  return std::min(rows_avail, cols_avail);
}
35
// Returns true if diagonal `diag_index` should be packed left-aligned.
// The main diagonal (index 0) counts as both super- and subdiagonal, so
// either alignment flag can left-align it.
static inline bool IsLeftAligned(int diag_index, bool left_align_superdiagonal,
                                 bool left_align_subdiagonal) {
  if (diag_index >= 0 && left_align_superdiagonal) {
    return true;
  }
  return diag_index <= 0 && left_align_subdiagonal;
}
42
43 // Reads the diagonal packing alignment.
ReadAlignment(OpKernelConstruction * context,bool * left_align_superdiagonal,bool * left_align_subdiagonal)44 void ReadAlignment(OpKernelConstruction* context,
45 bool* left_align_superdiagonal,
46 bool* left_align_subdiagonal) {
47 string align;
48 OP_REQUIRES_OK(context, context->GetAttr("align", &align));
49
50 *left_align_superdiagonal = align == "LEFT_LEFT" || align == "LEFT_RIGHT";
51 *left_align_subdiagonal = align == "LEFT_LEFT" || align == "RIGHT_LEFT";
52 }
53
54 // Reads or infers lower_diag_index and upper_diag_index from kernel's input
55 // parameter "k". Also validates their values.
ProcessDiagIndex(XlaOpKernelContext * context)56 std::pair<int64_t, int64_t> ProcessDiagIndex(XlaOpKernelContext* context) {
57 int64_t lower_diag_index = 0;
58 int64_t upper_diag_index = 0;
59 TensorShape diag_index_shape = context->InputShape("k");
60
61 // Wrapping OP_REQUIRES* macros with a function because they can "return;"
62 // early (without values) which contradicts ProcessDiagIndex's signature.
63 auto validate_diag_indices = [&]() {
64 if (diag_index_shape.dims() == 0) {
65 OP_REQUIRES_OK(context,
66 context->ConstantInputAsIntScalar("k", &lower_diag_index));
67 upper_diag_index = lower_diag_index;
68 } else {
69 std::vector<int64_t> diag_index;
70 OP_REQUIRES_OK(context,
71 context->ConstantInputAsIntVector("k", &diag_index));
72 OP_REQUIRES(
73 context, !diag_index.empty() && diag_index.size() <= 2,
74 errors::InvalidArgument(
75 "diag_index must have only one or two elements, received ",
76 diag_index.size(), " elements."));
77 lower_diag_index = diag_index[0];
78 upper_diag_index =
79 (diag_index.size() > 1) ? diag_index[1] : lower_diag_index;
80 }
81 OP_REQUIRES(
82 context, lower_diag_index <= upper_diag_index,
83 errors::InvalidArgument(
84 "lower_diag_index must not be larger than upper_diag_index: ",
85 lower_diag_index, " > ", upper_diag_index));
86 };
87 validate_diag_indices();
88 return {lower_diag_index, upper_diag_index};
89 }
90
91 // Makes sure lower_diag_index and upper_diag_index are consistent with the
92 // input matrix size.
ValidateDiagIndexWithOutputMatrixSize(XlaOpKernelContext * context,const int64_t lower_diag_index,const int64_t upper_diag_index,const int64_t num_rows,const int64_t num_cols)93 void ValidateDiagIndexWithOutputMatrixSize(XlaOpKernelContext* context,
94 const int64_t lower_diag_index,
95 const int64_t upper_diag_index,
96 const int64_t num_rows,
97 const int64_t num_cols) {
98 // `lower_diag_index == 0` condition is added to handle matrix shape = 0.
99 OP_REQUIRES(context,
100 (-num_rows < lower_diag_index && lower_diag_index < num_cols) ||
101 lower_diag_index == 0,
102 errors::InvalidArgument(
103 "lower_diag_index is out of bound: ", lower_diag_index,
104 " It must be between ", -num_rows, " and ", num_cols));
105 OP_REQUIRES(context,
106 (-num_rows < upper_diag_index && upper_diag_index < num_cols) ||
107 upper_diag_index == 0,
108 errors::InvalidArgument(
109 "upper_diag_index is out of bound: ", upper_diag_index,
110 " It must be between ", -num_rows, " and ", num_cols));
111 OP_REQUIRES(context, lower_diag_index <= upper_diag_index,
112 errors::InvalidArgument(
113 "lower_diag_index must not be larger than upper_diag_index: ",
114 lower_diag_index, " > ", upper_diag_index));
115 }
116
// Kernel to set matrix diagonals.
//
// Writes the rows of `diag` onto the diagonals in the range
// [lower_diag_index, upper_diag_index] of `input` and returns the result.
// `diag` provides one row of length max_diag_len per diagonal (on top of any
// batch dimensions); diagonals shorter than max_diag_len are expected to
// arrive pre-padded according to the two alignment flags.
xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag,
                         const TensorShape& input_shape,
                         const int64_t diag_rank, const int64_t num_diags,
                         const int64_t lower_diag_index,
                         const int64_t upper_diag_index,
                         const int64_t max_diag_len, const int64_t num_rows,
                         const int64_t num_cols,
                         const bool left_align_superdiagonal,
                         const bool left_align_subdiagonal) {
  // Creates a padding config. The config has one entry per dimension of a
  // single extracted diagonal slice (input rank minus the column dim); the
  // row-dim entry is overwritten below whenever padding is needed.
  const int input_rank = input_shape.dims();
  xla::PaddingConfig padding_config;
  padding_config = xla::MakeNoPaddingConfig(input_rank - 1);

  // Processes one diagonal at a time:
  // 1) Extracts a single diagonal (diag_slice).
  // 2) Broadcasts its contents to fill the whole matrix (diag_broadcast).
  // 3) Masks diag_broadcast to get the right diagonal shape.
  //
  // XLA can fuse multiple Broadcasts and Selects so this shouldn't be slow.
  //
  // For example,
  //   diag = [[0, 2, 3], k = (-1, 1), num_cols = 4, and align="RIGHT_LEFT".
  //           [4, 5, 6],
  //           [7, 8, 9]]
  // The expected output is [[7, 4, 2, 0],
  //                         [0, 8, 5, 3],
  //                         [0, 0, 9, 6]].
  // The 1st diagonal is created by:
  // 1) Extracting diag_slice = [0, 2, 3] which is right-aligned.
  // 2) Padding the vector (in the same direction) to be as long as num_cols,
  //      diag_slice = [0, 0, 2, 3],
  //    then broadcasting diag_slice column-wise to a full matrix,
  //      diag_broadcast = [[0, 0, 2, 3],
  //                        [0, 0, 2, 3],
  //                        [0, 0, 2, 3]].
  //    The padding value can be anything because it will not appear in the
  //    results after masking. Here, we use zero.
  // 3) Masking diag_broadcast with a mask of the shape of the 1st diagonal.
  //      mask = [[0, 0, 1, 0],  --> output = [[x, x, 2, x],
  //              [0, 0, 0, 1],               [x, x, x, 3],
  //              [0, 0, 0, 0]]               [x, x, x, x]],
  //    where x denotes the existing input contents.
  //
  // Batch dimensions map one-to-one in the broadcast; only the final entry
  // (the dimension the slice is replicated along) changes per diagonal.
  std::vector<int64_t> broadcast_dimensions(input_rank - 1);
  absl::c_iota(broadcast_dimensions, 0);
  auto output = input;
  for (int64_t diag_index = lower_diag_index; diag_index <= upper_diag_index;
       ++diag_index) {
    // Extracts a single diagonal. When there is only one diagonal, `diag`
    // already has the slice shape and is used as-is.
    auto diag_slice = diag;
    if (num_diags > 1) {
      // The result of SliceInDim has dims: [<batch_dim>, 1, max_diag_len].
      // We call Collapse to make the dims: [<batch_dim>, max_diag_len].
      // Diagonals are stored upper-first, hence the index reversal.
      const int64_t mapped_diag_index = upper_diag_index - diag_index;
      diag_slice = xla::Collapse(
          xla::SliceInDim(diag, mapped_diag_index, mapped_diag_index + 1, 1,
                          diag_rank - 2),
          {diag_rank - 2, diag_rank - 1});
    }

    // Pad if necessary.
    // - If the diagonal has the longest length, i.e., min(num_rows, num_cols),
    //   no padding is necessary. It is broadcast column-wise if it is a sub-
    //   diagonal, row-wise if superdiagonal.
    // - Otherwise, pad and keep the old alignment (shorter diagonals in the
    //   input come pre-padded). max_len in the table refers to max_diag_len.
    //   -------------------------------------------------------------------
    //   | Diag  | Align | Broadcast      | padding_low | padding_high      |
    //   -------------------------------------------------------------------
    //   | Super | Left  | Row-wise       | 0           | #rows - max_len   |
    //   |       | Right | Column-wise    | #cols - max_len | 0             |
    //   -------------------------------------------------------------------
    //   | Sub   | Left  | Column-wise    | 0           | #cols - max_len   |
    //   |       | Right | Row-wise       | #rows - max_len | 0             |
    //   -------------------------------------------------------------------
    if (num_cols - num_rows <= diag_index && diag_index <= 0) {
      broadcast_dimensions.back() = input_rank - 1;  // Column-wise.
    } else if (0 <= diag_index && diag_index <= num_cols - num_rows) {
      broadcast_dimensions.back() = input_rank - 2;  // Row-wise.
    } else {
      int length_to_pad_to;
      if ((diag_index > 0 && left_align_superdiagonal) ||
          (diag_index < 0 && !left_align_subdiagonal)) {
        length_to_pad_to = num_rows;
        broadcast_dimensions.back() = input_rank - 2;  // Row-wise.
      } else {
        length_to_pad_to = num_cols;
        broadcast_dimensions.back() = input_rank - 1;  // Column-wise.
      }
      // Left-aligned slices get their padding on the high side, right-aligned
      // ones on the low side (see the table above).
      int padding_low = length_to_pad_to - max_diag_len;
      int padding_high = 0;
      if (IsLeftAligned(diag_index, left_align_superdiagonal,
                        left_align_subdiagonal)) {
        std::swap(padding_low, padding_high);
      }
      padding_config.mutable_dimensions(input_rank - 2)
          ->set_edge_padding_low(padding_low);
      padding_config.mutable_dimensions(input_rank - 2)
          ->set_edge_padding_high(padding_high);

      // Padding value is arbitrary (masked out below); zero is used.
      const xla::XlaOp zero = xla::ScalarLike(input, 0);
      diag_slice = xla::Pad(diag_slice, zero, padding_config);
    }

    // Broadcast and mask: keep diag_broadcast on the current diagonal, the
    // accumulated output everywhere else.
    xla::XlaOp diag_broadcast = xla::BroadcastInDim(
        diag_slice, input_shape.dim_sizes(), broadcast_dimensions);
    const auto mask = xla::GetDiagonalMask(output, diag_index);
    output = xla::Select(mask, diag_broadcast, output);
  }
  return output;
}
230
231 } // namespace
232
// XLA kernel shared by MatrixDiag, MatrixDiagV2, and MatrixDiagV3: builds a
// (batched) matrix whose diagonals [lower_diag_index, upper_diag_index] come
// from the input and whose remaining entries are filled with padding_value.
class MatrixDiagOp : public XlaOpKernel {
 public:
  explicit MatrixDiagOp(OpKernelConstruction* context) : XlaOpKernel(context) {
    // MatrixDiagV3-specific. V1/V2 have no "align" attribute and keep the
    // left/left defaults declared below.
    if (context->HasAttr("align")) {
      ReadAlignment(context, &left_align_superdiagonal_,
                    &left_align_subdiagonal_);
    }
  }

  void Compile(XlaOpKernelContext* context) override {
    OP_REQUIRES(
        context, context->num_inputs() >= kNumV1Inputs,
        errors::InvalidArgument("MatrixDiag op must have at least one input"));
    const TensorShape diag_shape = context->InputShape(0);
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape),
                errors::InvalidArgument(
                    "diagonal must be at least 1-dim, received shape: ",
                    diag_shape.DebugString()));

    const DataType dtype = context->expected_output_dtype(0);
    const xla::XlaOp zero = XlaHelpers::Zero(context->builder(), dtype);

    // Initializes MatrixDiagV2-specific variables.
    // Input arguments providing the values of num_rows and num_cols can be
    // absent (-1) and will be inferred later.
    int64_t lower_diag_index = 0;
    int64_t upper_diag_index = 0;
    int64_t num_rows = -1;
    int64_t num_cols = -1;
    xla::XlaOp padding_value = zero;

    // MatrixDiag and MatrixDiagV2 both use this OpKernel. MatrixDiag only has
    // one input, so we have to check the number of inputs before reading
    // additional parameters for MatrixDiagV2.
    if (context->num_inputs() > kNumV1Inputs) {
      std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
      OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &num_rows));
      OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(3, &num_cols));
      padding_value = context->Input(4);
    }

    // More size validations.
    const int64_t diag_rank = diag_shape.dims();
    const int64_t max_diag_len = diag_shape.dim_size(diag_rank - 1);
    const int64_t num_diags = upper_diag_index - lower_diag_index + 1;
    OP_REQUIRES(
        context,
        num_diags == 1 || num_diags == diag_shape.dim_size(diag_rank - 2),
        errors::InvalidArgument(
            "The number of diagonals provided in the input does not "
            "match the lower_diag_index and upper_diag_index range."));
    // Smallest matrix that can hold a length-max_diag_len diagonal at the
    // outermost requested indices.
    const int64_t min_num_rows =
        max_diag_len - std::min(upper_diag_index, int64_t{0});
    const int64_t min_num_cols =
        max_diag_len + std::max(lower_diag_index, int64_t{0});
    OP_REQUIRES(context, num_rows == -1 || num_rows >= min_num_rows,
                errors::InvalidArgument("The number of rows is too small."));
    OP_REQUIRES(context, num_cols == -1 || num_cols >= min_num_cols,
                errors::InvalidArgument("The number of columns is too small."));

    // Infers num_rows and num_cols. If both are unknown, assume that the output
    // is square. Otherwise, use smallest possible values.
    if (num_rows == -1 && num_cols == -1) {
      num_rows = std::max(min_num_rows, min_num_cols);
      num_cols = num_rows;
    } else if (num_rows == -1) {
      num_rows = min_num_rows;
    } else if (num_cols == -1) {
      num_cols = min_num_cols;
    }

    // At least one of num_rows and num_cols must match its minimum length.
    // Otherwise, we'll have some incomplete diagonals.
    OP_REQUIRES(context, num_rows == min_num_rows || num_cols == min_num_cols,
                errors::InvalidArgument(
                    "The number of rows or columns is not consistent with "
                    "the specified d_lower, d_upper, and diagonal."));

    // Actual processing.
    // Initializes the output tensor with padding_value.
    // Output shape: diag batch dims + [num_rows, num_cols]. When num_diags is
    // 1 the diag input has no per-diagonal dim, so only one dim is dropped.
    TensorShape output_shape = diag_shape;
    output_shape.RemoveLastDims((num_diags == 1) ? 1 : 2);
    output_shape.AddDim(num_rows);
    output_shape.AddDim(num_cols);
    xla::XlaOp output = xla::Broadcast(padding_value, output_shape.dim_sizes());
    xla::XlaOp diag = context->Input(0);
    context->SetOutput(
        0, SetMatrixDiag(output, diag, output_shape, diag_rank, num_diags,
                         lower_diag_index, upper_diag_index, max_diag_len,
                         num_rows, num_cols, left_align_superdiagonal_,
                         left_align_subdiagonal_));
  }

 private:
  bool left_align_superdiagonal_ = true;
  bool left_align_subdiagonal_ = true;
  // MatrixDiag (V1) has exactly one input (the diagonal values).
  static constexpr int kNumV1Inputs = 1;
};
332
// MatrixDiag (V1) is handled by the generic MLIR-based kernel; V2/V3 use
// MatrixDiagOp above, and require "k", the matrix dimensions, and the
// padding value to be compile-time constants.
REGISTER_XLA_OP(Name("MatrixDiag"), MlirXlaOpKernel);
REGISTER_XLA_OP(Name("MatrixDiagV2")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("num_rows")
                    .CompileTimeConstantInput("num_cols")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagOp);
REGISTER_XLA_OP(Name("MatrixDiagV3")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("num_rows")
                    .CompileTimeConstantInput("num_cols")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagOp);
346
// XLA kernel shared by MatrixDiagPart, MatrixDiagPartV2, and
// MatrixDiagPartV3: extracts the diagonals in the range
// [lower_diag_index, upper_diag_index] from a (batched) matrix, padding
// diagonals shorter than the longest one with padding_value.
class MatrixDiagPartOp : public XlaOpKernel {
 public:
  explicit MatrixDiagPartOp(OpKernelConstruction* context)
      : XlaOpKernel(context),
        is_gpu_(context->device_type().type_string() == DEVICE_GPU_XLA_JIT) {
    // MatrixDiagPartV3-specific. Earlier versions have no "align" attribute
    // and keep the left/left defaults declared below.
    if (context->HasAttr("align")) {
      ReadAlignment(context, &left_align_superdiagonal_,
                    &left_align_subdiagonal_);
    }
  }

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    const int input_rank = input_shape.dims();

    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input_shape.DebugString()));

    const DataType dtype = context->expected_output_dtype(0);
    const xla::XlaOp zero = XlaHelpers::Zero(context->builder(), dtype);

    // Initializes MatrixDiagPartV2-specific variables.
    int64_t lower_diag_index = 0;
    int64_t upper_diag_index = 0;
    xla::XlaOp padding_value = zero;

    // MatrixDiagPart and MatrixDiagPartV2 both use this OpKernel.
    // MatrixDiagPart only has one input, so we have to check the number of
    // inputs before reading additional parameters in MatrixDiagV2.
    if (context->num_inputs() > kNumV1Inputs) {
      std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
      padding_value = context->Input(2);
    }

    // Checks if diag sizes are consistent with input.
    const int64_t num_rows = input_shape.dim_size(input_rank - 2);
    const int64_t num_cols = input_shape.dim_size(input_rank - 1);
    ValidateDiagIndexWithOutputMatrixSize(context, lower_diag_index,
                                          upper_diag_index, num_rows, num_cols);

    // Creates output shape: batch dims + [num_diags (if > 1), max_diag_len].
    TensorShape output_shape = input_shape;
    output_shape.RemoveLastDims(2);
    const int num_diags = upper_diag_index - lower_diag_index + 1;
    if (num_diags > 1) output_shape.AddDim(num_diags);
    const int32_t max_diag_len =
        std::min(num_rows + std::min(upper_diag_index, int64_t{0}),
                 num_cols - std::max(lower_diag_index, int64_t{0}));
    output_shape.AddDim(max_diag_len);

    // Computes output. On GPU the gather-based extraction is used instead of
    // the generic one (presumably faster there — see xla::matrix helpers).
    xla::XlaOp input = context->Input(0);
    std::vector<xla::XlaOp> diag_list;
    xla::PaddingConfig padding_config =
        xla::MakeNoPaddingConfig(input_rank - 1);
    if (num_diags == 1) {
      // Single-diagonal fast path: no padding or concatenation needed.
      context->SetOutput(
          0, is_gpu_ ? xla::GetMatrixDiagonalViaGather(input, upper_diag_index)
                     : xla::GetMatrixDiagonal(input, upper_diag_index));
      return;
    }
    // Extracts diagonals from the uppermost down, padding each to
    // max_diag_len on the side dictated by the alignment flags.
    for (int diag_index = upper_diag_index; diag_index >= lower_diag_index;
         --diag_index) {
      xla::XlaOp single_diag =
          is_gpu_ ? xla::GetMatrixDiagonalViaGather(input, diag_index)
                  : xla::GetMatrixDiagonal(input, diag_index);
      const int64_t diag_len = ComputeDiagLen(diag_index, num_rows, num_cols);
      const int64_t padding_len = max_diag_len - diag_len;
      if (padding_len > 0) {
        if (IsLeftAligned(diag_index, left_align_superdiagonal_,
                          left_align_subdiagonal_)) {
          padding_config.mutable_dimensions(input_rank - 2)
              ->set_edge_padding_low(0);
          padding_config.mutable_dimensions(input_rank - 2)
              ->set_edge_padding_high(padding_len);
        } else {
          padding_config.mutable_dimensions(input_rank - 2)
              ->set_edge_padding_low(padding_len);
          padding_config.mutable_dimensions(input_rank - 2)
              ->set_edge_padding_high(0);
        }
        single_diag = xla::Pad(single_diag, padding_value, padding_config);
      }
      diag_list.emplace_back(single_diag);
    }
    // Concatenate along a new diagonal dimension, then reshape to the final
    // output shape computed above.
    auto concat =
        xla::ConcatInDim(context->builder(), diag_list, input_rank - 2);
    context->SetOutput(0, xla::Reshape(concat, output_shape.dim_sizes()));
  }

 private:
  // True when compiling for the XLA GPU JIT device.
  const bool is_gpu_;
  bool left_align_superdiagonal_ = true;
  bool left_align_subdiagonal_ = true;
  // MatrixDiagPart (V1) has exactly one input (the matrix).
  static constexpr int kNumV1Inputs = 1;
};
446
// All MatrixDiagPart versions share MatrixDiagPartOp; V2/V3 additionally
// require "k" and "padding_value" to be compile-time constants.
REGISTER_XLA_OP(Name("MatrixDiagPart"), MatrixDiagPartOp);
REGISTER_XLA_OP(Name("MatrixDiagPartV2")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagPartOp);
REGISTER_XLA_OP(Name("MatrixDiagPartV3")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagPartOp);
456
// XLA kernel shared by MatrixSetDiag, MatrixSetDiagV2, and MatrixSetDiagV3:
// returns a copy of the input matrix with the diagonals in
// [lower_diag_index, upper_diag_index] replaced by the values in `diag`.
class MatrixSetDiagOp : public XlaOpKernel {
 public:
  explicit MatrixSetDiagOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {
    // MatrixSetDiagV3-specific. V1/V2 have no "align" attribute and keep the
    // left/left defaults declared below.
    if (context->HasAttr("align")) {
      ReadAlignment(context, &left_align_superdiagonal_,
                    &left_align_subdiagonal_);
    }
  }

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    const TensorShape diag_shape = context->InputShape(1);
    const int input_rank = input_shape.dims();
    const int diag_rank = diag_shape.dims();

    // Preliminary validation of sizes.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input_shape.DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape),
                errors::InvalidArgument(
                    "diagonal must be at least 1-dim, received shape: ",
                    diag_shape.DebugString()));

    // MatrixSetDiag and MatrixSetDiagV2 both use this OpKernel. MatrixSetDiag
    // only has two inputs, so we have to check the number of inputs before
    // reading additional parameters in MatrixSetDiagV2.
    int64_t lower_diag_index = 0;
    int64_t upper_diag_index = 0;
    if (context->num_inputs() > kNumV1Inputs) {
      std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
    }

    // Checks if diag sizes are consistent with input.
    const int64_t num_rows = input_shape.dim_size(input_rank - 2);
    const int64_t num_cols = input_shape.dim_size(input_rank - 1);
    ValidateDiagIndexWithOutputMatrixSize(context, lower_diag_index,
                                          upper_diag_index, num_rows, num_cols);
    const Eigen::Index num_diags = upper_diag_index - lower_diag_index + 1;
    // NOTE(review): this indexes diag_shape with input_rank - 2 (not
    // diag_rank - 2); it appears to rely on the exact expected-shape check
    // below to reject any rank mismatch — confirm.
    OP_REQUIRES(
        context,
        lower_diag_index == upper_diag_index ||
            (diag_shape.dim_size(input_rank - 2) == num_diags),
        errors::InvalidArgument("The number of diagonals provided in `diag` "
                                "is not consistent with `lower_diag_index` and "
                                "`upper_diag_index`"));

    // Expected diag shape: input batch dims + [num_diags (if > 1),
    // max_diag_len], where max_diag_len is the longest diagonal in range.
    TensorShape expected_diag_shape = input_shape;
    expected_diag_shape.RemoveLastDims(2);
    if (num_diags > 1) expected_diag_shape.AddDim(num_diags);
    const int32_t max_diag_len =
        std::min(num_rows + std::min(upper_diag_index, int64_t{0}),
                 num_cols - std::max(lower_diag_index, int64_t{0}));
    expected_diag_shape.AddDim(max_diag_len);
    OP_REQUIRES(
        context, expected_diag_shape == diag_shape,
        errors::InvalidArgument(
            "Either first dimensions of diagonal don't match input.shape[:-2], "
            "or diagonal.shape[:-1] is not equal to the longests diagonal in "
            "range [lower_diag_index:upper_diag_index].\nInput shape: ",
            input_shape.DebugString(),
            "\nDiagonal shape: ", diag_shape.DebugString(),
            "\nExpected diagonal shape: ", expected_diag_shape.DebugString()));

    // Actual processing.
    xla::XlaOp input = context->Input(0);
    xla::XlaOp diag = context->Input(1);
    context->SetOutput(
        0, SetMatrixDiag(input, diag, input_shape, diag_rank, num_diags,
                         lower_diag_index, upper_diag_index, max_diag_len,
                         num_rows, num_cols, left_align_superdiagonal_,
                         left_align_subdiagonal_));
  }

 private:
  bool left_align_superdiagonal_ = true;
  bool left_align_subdiagonal_ = true;
  // MatrixSetDiag (V1) has exactly two inputs (the matrix and the diagonal).
  static constexpr int kNumV1Inputs = 2;
  TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp);
};
540
// All MatrixSetDiag versions share MatrixSetDiagOp; V2/V3 additionally
// require "k" to be a compile-time constant.
REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp);
REGISTER_XLA_OP(Name("MatrixSetDiagV2").CompileTimeConstantInput("k"),
                MatrixSetDiagOp);
REGISTER_XLA_OP(Name("MatrixSetDiagV3").CompileTimeConstantInput("k"),
                MatrixSetDiagOp);
546
547 } // namespace tensorflow
548