1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
17 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
18 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
19 #include "tensorflow/compiler/xla/client/lib/constants.h"
20 #include "tensorflow/compiler/xla/client/lib/matrix.h"
21 #include "tensorflow/compiler/xla/client/xla_builder.h"
22 #include "tensorflow/compiler/xla/primitive_util.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/lib/core/errors.h"
25
26 namespace tensorflow {
27 namespace {
28
// Returns the number of elements on diagonal `diag_index` of a
// `num_rows` x `num_cols` matrix (positive index = superdiagonal,
// negative index = subdiagonal, 0 = main diagonal).
static inline int ComputeDiagLen(int diag_index, int num_rows, int num_cols) {
  // A subdiagonal loses rows; a superdiagonal loses columns. The diagonal
  // length is the smaller of the two remaining extents.
  const int row_extent = diag_index < 0 ? num_rows + diag_index : num_rows;
  const int col_extent = diag_index > 0 ? num_cols - diag_index : num_cols;
  return row_extent < col_extent ? row_extent : col_extent;
}
34
// Returns true when diagonal `diag_index` should be packed left-aligned,
// false when it should be right-aligned. The main diagonal (index 0) is
// left-aligned if either alignment flag requests it (it is full-length, so
// the distinction is moot for callers).
static inline bool IsLeftAligned(int diag_index, bool left_align_superdiagonal,
                                 bool left_align_subdiagonal) {
  if (diag_index > 0) return left_align_superdiagonal;
  if (diag_index < 0) return left_align_subdiagonal;
  return left_align_superdiagonal || left_align_subdiagonal;
}
41
42 // Reads the diagonal packing alignment.
ReadAlignment(OpKernelConstruction * context,bool * left_align_superdiagonal,bool * left_align_subdiagonal)43 void ReadAlignment(OpKernelConstruction* context,
44 bool* left_align_superdiagonal,
45 bool* left_align_subdiagonal) {
46 string align;
47 OP_REQUIRES_OK(context, context->GetAttr("align", &align));
48
49 *left_align_superdiagonal = align == "LEFT_LEFT" || align == "LEFT_RIGHT";
50 *left_align_subdiagonal = align == "LEFT_LEFT" || align == "RIGHT_LEFT";
51 }
52
// Reads or infers lower_diag_index and upper_diag_index from kernel's input
// parameter "k". Also validates their values.
// On validation failure, the error is recorded on `context` by the
// OP_REQUIRES* macros and the returned pair may hold default/partial values;
// callers rely on the context status being checked downstream.
std::pair<int64, int64> ProcessDiagIndex(XlaOpKernelContext* context) {
  int64 lower_diag_index = 0;
  int64 upper_diag_index = 0;
  TensorShape diag_index_shape = context->InputShape("k");

  // Wrapping OP_REQUIRES* macros with a function because they can "return;"
  // early (without values) which contradicts ProcessDiagIndex's signature.
  auto validate_diag_indices = [&]() {
    if (diag_index_shape.dims() == 0) {
      // Scalar "k": a single diagonal, so lower and upper indices coincide.
      OP_REQUIRES_OK(context,
                     context->ConstantInputAsIntScalar("k", &lower_diag_index));
      upper_diag_index = lower_diag_index;
    } else {
      // Vector "k": either [diag_index] or [lower_diag_index,
      // upper_diag_index]; any other length is rejected below.
      std::vector<int64> diag_index;
      OP_REQUIRES_OK(context,
                     context->ConstantInputAsIntVector("k", &diag_index));
      OP_REQUIRES(
          context, !diag_index.empty() && diag_index.size() <= 2,
          errors::InvalidArgument(
              "diag_index must have only one or two elements, received ",
              diag_index.size(), " elements."));
      lower_diag_index = diag_index[0];
      upper_diag_index =
          (diag_index.size() > 1) ? diag_index[1] : lower_diag_index;
    }
    // The band must be non-empty: lower bound cannot exceed upper bound.
    OP_REQUIRES(
        context, lower_diag_index <= upper_diag_index,
        errors::InvalidArgument(
            "lower_diag_index must not be larger than upper_diag_index: ",
            lower_diag_index, " > ", upper_diag_index));
  };
  validate_diag_indices();
  return {lower_diag_index, upper_diag_index};
}
89
// Makes sure lower_diag_index and upper_diag_index are consistent with the
// input matrix size. Each index must lie strictly inside (-num_rows,
// num_cols); index 0 is always accepted so that empty (size-0) matrices
// pass validation. Errors are recorded on `context` via OP_REQUIRES; the
// checks run in order, so the first failing condition determines the
// reported message.
void ValidateDiagIndexWithOutputMatrixSize(XlaOpKernelContext* context,
                                           const int64 lower_diag_index,
                                           const int64 upper_diag_index,
                                           const int64 num_rows,
                                           const int64 num_cols) {
  // `lower_diag_index == 0` condition is added to handle matrix shape = 0.
  OP_REQUIRES(context,
              (-num_rows < lower_diag_index && lower_diag_index < num_cols) ||
                  lower_diag_index == 0,
              errors::InvalidArgument(
                  "lower_diag_index is out of bound: ", lower_diag_index,
                  " It must be between ", -num_rows, " and ", num_cols));
  OP_REQUIRES(context,
              (-num_rows < upper_diag_index && upper_diag_index < num_cols) ||
                  upper_diag_index == 0,
              errors::InvalidArgument(
                  "upper_diag_index is out of bound: ", upper_diag_index,
                  " It must be between ", -num_rows, " and ", num_cols));
  // The band must be well-formed (same invariant ProcessDiagIndex enforces;
  // re-checked here because V1 ops bypass ProcessDiagIndex).
  OP_REQUIRES(context, lower_diag_index <= upper_diag_index,
              errors::InvalidArgument(
                  "lower_diag_index must not be larger than upper_diag_index: ",
                  lower_diag_index, " > ", upper_diag_index));
}
115
// Kernel to set matrix diagonals.
//
// Writes the diagonals packed in `diag` into `input` and returns the result.
// `input_shape` is the (possibly batched) shape of `input`; `diag_rank` is
// the rank of `diag`. The diagonal band is [lower_diag_index,
// upper_diag_index] with `num_diags` diagonals, each padded to
// `max_diag_len` in `diag`. `num_rows`/`num_cols` are the trailing matrix
// dims of `input`. The two alignment flags describe how short diagonals are
// packed inside `diag` (see ReadAlignment).
xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag,
                         const TensorShape& input_shape, const int64 diag_rank,
                         const int64 num_diags, const int64 lower_diag_index,
                         const int64 upper_diag_index, const int64 max_diag_len,
                         const int64 num_rows, const int64 num_cols,
                         const bool left_align_superdiagonal,
                         const bool left_align_subdiagonal) {
  // Creates a padding config.
  // `padding_config` is mutated per diagonal inside the loop below; it only
  // takes effect on iterations that reach the Pad call.
  const int input_rank = input_shape.dims();
  xla::PaddingConfig padding_config;
  padding_config = xla::MakeNoPaddingConfig(input_rank - 1);

  // Processes one diagonal at a time:
  // 1) Extracts a single diagonal (diag_slice).
  // 2) Broadcasts its contents to fill the whole matrix (diag_broadcast).
  // 3) Masks diag_broadcast to get the right diagonal shape.
  //
  // XLA can fuse multiple Broadcasts and Selects so this shouldn't be slow.
  //
  // For example,
  // diag = [[0, 2, 3], k = (-1, 1), num_cols = 4, and align="RIGHT_LEFT".
  //         [4, 5, 6],
  //         [7, 8, 9]]
  // The expected output is [[7, 4, 2, 0],
  //                         [0, 8, 5, 3],
  //                         [0, 0, 9, 6]].
  // The 1st diagonal is created by:
  // 1) Extracting diag_slice = [0, 2, 3] which is right-aligned.
  // 2) Padding the vector (in the same direction) to be as long as num_cols,
  //      diag_slice = [0, 0, 2, 3],
  //    then broadcasting diag_slice column-wise to a full matrix,
  //      diag_broadcast = [[0, 0, 2, 3],
  //                        [0, 0, 2, 3],
  //                        [0, 0, 2, 3]].
  //    The padding value can be anything because it will not appear in the
  //    results after masking. Here, we use zero.
  // 3) Masking diag_broadcast with a mask of the shape of the 1st diagonal.
  //      mask = [[0, 0, 1, 0],  --> output = [[x, x, 2, x],
  //              [0, 0, 0, 1],               [x, x, x, 3],
  //              [0, 0, 0, 0]]               [x, x, x, x]],
  // where x denotes the existing input contents.
  std::vector<int64> broadcast_dimensions(input_rank - 1);
  absl::c_iota(broadcast_dimensions, 0);
  auto output = input;
  for (int64 diag_index = lower_diag_index; diag_index <= upper_diag_index;
       ++diag_index) {
    // Extracts a single diagonal.
    auto diag_slice = diag;
    if (num_diags > 1) {
      // The result of SliceInDim has dims: [<batch_dim>, 1, max_diag_len].
      // We call Collapse to make the dims: [<batch_dim>, max_diag_len].
      // Diagonals are stored upper-first in `diag`, hence the index flip.
      const int64 mapped_diag_index = upper_diag_index - diag_index;
      diag_slice = xla::Collapse(
          xla::SliceInDim(diag, mapped_diag_index, mapped_diag_index + 1, 1,
                          diag_rank - 2),
          {diag_rank - 2, diag_rank - 1});
    }

    // Pad if necessary.
    // - If the diagonal has the longest length, i.e., min(num_rows, num_cols),
    //   no padding is necessary. It is broadcast column-wise if it is a sub-
    //   diagonal, row-wise if superdiagonal.
    // - Otherwise, pad and keep the old alignment (shorter diagonals in the
    //   input come pre-padded). max_len in the table refers to max_diag_len.
    //   -------------------------------------------------------------------
    //   | Diag  | Align | Broadcast   | padding_low     | padding_high    |
    //   -------------------------------------------------------------------
    //   | Super | Left  | Row-wise    | 0               | #rows - max_len |
    //   |       | Right | Column-wise | #cols - max_len | 0               |
    //   -------------------------------------------------------------------
    //   | Sub   | Left  | Column-wise | 0               | #cols - max_len |
    //   |       | Right | Row-wise    | #rows - max_len | 0               |
    //   -------------------------------------------------------------------
    if (num_cols - num_rows <= diag_index && diag_index <= 0) {
      broadcast_dimensions.back() = input_rank - 1;  // Column-wise.
    } else if (0 <= diag_index && diag_index <= num_cols - num_rows) {
      broadcast_dimensions.back() = input_rank - 2;  // Row-wise.
    } else {
      // Short diagonal: pad it out to the full row/column length so the
      // broadcast below produces a full matrix.
      int length_to_pad_to;
      if ((diag_index > 0 && left_align_superdiagonal) ||
          (diag_index < 0 && !left_align_subdiagonal)) {
        length_to_pad_to = num_rows;
        broadcast_dimensions.back() = input_rank - 2;  // Row-wise.
      } else {
        length_to_pad_to = num_cols;
        broadcast_dimensions.back() = input_rank - 1;  // Column-wise.
      }
      // Default to padding at the front (right-aligned); swap for
      // left-aligned diagonals so the padding goes to the back.
      int padding_low = length_to_pad_to - max_diag_len;
      int padding_high = 0;
      if (IsLeftAligned(diag_index, left_align_superdiagonal,
                        left_align_subdiagonal)) {
        std::swap(padding_low, padding_high);
      }
      padding_config.mutable_dimensions(input_rank - 2)
          ->set_edge_padding_low(padding_low);
      padding_config.mutable_dimensions(input_rank - 2)
          ->set_edge_padding_high(padding_high);

      // Padding value is arbitrary (masked out below); zero is used.
      const xla::XlaOp zero = xla::ScalarLike(input, 0);
      diag_slice = xla::Pad(diag_slice, zero, padding_config);
    }

    // Broadcast and mask.
    xla::XlaOp diag_broadcast = xla::BroadcastInDim(
        diag_slice, input_shape.dim_sizes(), broadcast_dimensions);
    const auto mask = xla::GetDiagonalMask(output, diag_index);
    output = xla::Select(mask, diag_broadcast, output);
  }
  return output;
}
227
228 } // namespace
229
// XLA kernel for MatrixDiag / MatrixDiagV2 / MatrixDiagV3: builds a
// (batched) matrix whose specified diagonal band comes from the input and
// whose remaining entries are `padding_value`.
class MatrixDiagOp : public XlaOpKernel {
 public:
  explicit MatrixDiagOp(OpKernelConstruction* context) : XlaOpKernel(context) {
    // MatrixDiagV3-specific.
    if (context->HasAttr("align")) {
      ReadAlignment(context, &left_align_superdiagonal_,
                    &left_align_subdiagonal_);
    }
  }

  void Compile(XlaOpKernelContext* context) override {
    OP_REQUIRES(
        context, context->num_inputs() >= kNumV1Inputs,
        errors::InvalidArgument("MatrixDiag op must have at least one input"));
    const TensorShape diag_shape = context->InputShape(0);
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape),
                errors::InvalidArgument("Expected >= 1 dims, got shape ",
                                        diag_shape.DebugString()));

    const DataType dtype = context->expected_output_dtype(0);
    const xla::XlaOp zero = XlaHelpers::Zero(context->builder(), dtype);

    // Initializes MatrixDiagV2-specific variables.
    // Input arguments providing the values of num_rows and num_cols can be
    // absent (-1) and will be inferred later.
    int64 lower_diag_index = 0;
    int64 upper_diag_index = 0;
    int64 num_rows = -1;
    int64 num_cols = -1;
    xla::XlaOp padding_value = zero;

    // MatrixDiag and MatrixDiagV2 both use this OpKernel. MatrixDiag only has
    // one input, so we have to check the number of inputs before reading
    // additional parameters for MatrixDiagV2.
    if (context->num_inputs() > kNumV1Inputs) {
      std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
      OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &num_rows));
      OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(3, &num_cols));
      padding_value = context->Input(4);
    }

    // More size validations.
    // Note: when num_diags > 1, diag_shape must be at least rank 2 for the
    // dim_size(diag_rank - 2) access below to be valid.
    const int64 diag_rank = diag_shape.dims();
    const int64 max_diag_len = diag_shape.dim_size(diag_rank - 1);
    const int64 num_diags = upper_diag_index - lower_diag_index + 1;
    OP_REQUIRES(
        context,
        num_diags == 1 || num_diags == diag_shape.dim_size(diag_rank - 2),
        errors::InvalidArgument(
            "The number of diagonals provided in the input does not "
            "match the lower_diag_index and upper_diag_index range."));
    // Smallest matrix that can hold the longest diagonal at the given band.
    const int64 min_num_rows =
        max_diag_len - std::min(upper_diag_index, int64{0});
    const int64 min_num_cols =
        max_diag_len + std::max(lower_diag_index, int64{0});
    OP_REQUIRES(context, num_rows == -1 || num_rows >= min_num_rows,
                errors::InvalidArgument("The number of rows is too small."));
    OP_REQUIRES(context, num_cols == -1 || num_cols >= min_num_cols,
                errors::InvalidArgument("The number of columns is too small."));

    // Infers num_rows and num_cols. If both are unknown, assume that the output
    // is square. Otherwise, use smallest possible values.
    if (num_rows == -1 && num_cols == -1) {
      num_rows = std::max(min_num_rows, min_num_cols);
      num_cols = num_rows;
    } else if (num_rows == -1) {
      num_rows = min_num_rows;
    } else if (num_cols == -1) {
      num_cols = min_num_cols;
    }

    // At least one of num_rows and num_cols must match its minimum length.
    // Otherwise, we'll have some incomplete diagonals.
    OP_REQUIRES(context, num_rows == min_num_rows || num_cols == min_num_cols,
                errors::InvalidArgument(
                    "The number of rows or columns is not consistent with "
                    "the specified d_lower, d_upper, and diagonal."));

    // Actual processing.
    // Initializes the output tensor with padding_value.
    // Output shape = diag batch dims + [num_rows, num_cols]: drop the
    // trailing [max_diag_len] (and [num_diags] if present) from diag_shape.
    TensorShape output_shape = diag_shape;
    output_shape.RemoveLastDims((num_diags == 1) ? 1 : 2);
    output_shape.AddDim(num_rows);
    output_shape.AddDim(num_cols);
    xla::XlaOp output = xla::Broadcast(padding_value, output_shape.dim_sizes());
    xla::XlaOp diag = context->Input(0);
    context->SetOutput(
        0, SetMatrixDiag(output, diag, output_shape, diag_rank, num_diags,
                         lower_diag_index, upper_diag_index, max_diag_len,
                         num_rows, num_cols, left_align_superdiagonal_,
                         left_align_subdiagonal_));
  }

 private:
  // Packing alignment for short diagonals (V3 "align" attr; defaults to
  // left/left for V1/V2, which have no such attribute).
  bool left_align_superdiagonal_ = true;
  bool left_align_subdiagonal_ = true;
  // MatrixDiag (V1) takes only the `diagonal` input.
  static constexpr int kNumV1Inputs = 1;
};
328
// V2/V3 extra inputs are registered as compile-time constants; the kernel
// reads them via ConstantInputAsIntScalar / ProcessDiagIndex at compile time.
REGISTER_XLA_OP(Name("MatrixDiag"), MatrixDiagOp);
REGISTER_XLA_OP(Name("MatrixDiagV2")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("num_rows")
                    .CompileTimeConstantInput("num_cols")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagOp);
REGISTER_XLA_OP(Name("MatrixDiagV3")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("num_rows")
                    .CompileTimeConstantInput("num_cols")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagOp);
342
// XLA kernel for MatrixDiagPart / V2 / V3: extracts a band of diagonals from
// a (batched) matrix, packing them into the output with the requested
// alignment and filling short diagonals with `padding_value`.
class MatrixDiagPartOp : public XlaOpKernel {
 public:
  explicit MatrixDiagPartOp(OpKernelConstruction* context)
      : XlaOpKernel(context),
        is_gpu_(context->device_type().type_string() == DEVICE_GPU_XLA_JIT) {
    // MatrixDiagPartV3-specific.
    if (context->HasAttr("align")) {
      ReadAlignment(context, &left_align_superdiagonal_,
                    &left_align_subdiagonal_);
    }
  }

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    const int input_rank = input_shape.dims();

    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input_shape.DebugString()));

    const DataType dtype = context->expected_output_dtype(0);
    const xla::XlaOp zero = XlaHelpers::Zero(context->builder(), dtype);

    // Initializes MatrixDiagPartV2-specific variables.
    int64 lower_diag_index = 0;
    int64 upper_diag_index = 0;
    xla::XlaOp padding_value = zero;

    // MatrixDiagPart and MatrixDiagPartV2 both use this OpKernel.
    // MatrixDiagPart only has one input, so we have to check the number of
    // inputs before reading additional parameters in MatrixDiagV2.
    if (context->num_inputs() > kNumV1Inputs) {
      std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
      padding_value = context->Input(2);
    }

    // Checks if diag sizes are consistent with input.
    const int64 num_rows = input_shape.dim_size(input_rank - 2);
    const int64 num_cols = input_shape.dim_size(input_rank - 1);
    ValidateDiagIndexWithOutputMatrixSize(context, lower_diag_index,
                                          upper_diag_index, num_rows, num_cols);

    // Creates output shape.
    // Output = batch dims (+ [num_diags] when extracting a band) +
    // [max_diag_len].
    TensorShape output_shape = input_shape;
    output_shape.RemoveLastDims(2);
    const int num_diags = upper_diag_index - lower_diag_index + 1;
    if (num_diags > 1) output_shape.AddDim(num_diags);
    const int32 max_diag_len =
        std::min(num_rows + std::min(upper_diag_index, int64{0}),
                 num_cols - std::max(lower_diag_index, int64{0}));
    output_shape.AddDim(max_diag_len);

    // Computes output.
    // The gather-based extraction is used on GPU; see xla::GetMatrixDiagonal*
    // for the two strategies.
    xla::XlaOp input = context->Input(0);
    std::vector<xla::XlaOp> diag_list;
    xla::PaddingConfig padding_config =
        xla::MakeNoPaddingConfig(input_rank - 1);
    if (num_diags == 1) {
      // Single diagonal: no packing or padding needed.
      context->SetOutput(
          0, is_gpu_ ? xla::GetMatrixDiagonalViaGather(input, upper_diag_index)
                     : xla::GetMatrixDiagonal(input, upper_diag_index));
      return;
    }
    // Diagonals are packed upper-first, matching SetMatrixDiag's layout.
    for (int diag_index = upper_diag_index; diag_index >= lower_diag_index;
         --diag_index) {
      xla::XlaOp single_diag =
          is_gpu_ ? xla::GetMatrixDiagonalViaGather(input, diag_index)
                  : xla::GetMatrixDiagonal(input, diag_index);
      const int64 diag_len = ComputeDiagLen(diag_index, num_rows, num_cols);
      const int64 padding_len = max_diag_len - diag_len;
      if (padding_len > 0) {
        // Pad on the side opposite the alignment so contents stay aligned.
        if (IsLeftAligned(diag_index, left_align_superdiagonal_,
                          left_align_subdiagonal_)) {
          padding_config.mutable_dimensions(input_rank - 2)
              ->set_edge_padding_low(0);
          padding_config.mutable_dimensions(input_rank - 2)
              ->set_edge_padding_high(padding_len);
        } else {
          padding_config.mutable_dimensions(input_rank - 2)
              ->set_edge_padding_low(padding_len);
          padding_config.mutable_dimensions(input_rank - 2)
              ->set_edge_padding_high(0);
        }
        single_diag = xla::Pad(single_diag, padding_value, padding_config);
      }
      diag_list.emplace_back(single_diag);
    }
    auto concat =
        xla::ConcatInDim(context->builder(), diag_list, input_rank - 2);
    context->SetOutput(0, xla::Reshape(concat, output_shape.dim_sizes()));
  }

 private:
  // Set at construction: whether we are compiling for the XLA GPU JIT.
  const bool is_gpu_;
  // Packing alignment for short diagonals (V3 "align" attr; left/left
  // defaults for V1/V2).
  bool left_align_superdiagonal_ = true;
  bool left_align_subdiagonal_ = true;
  // MatrixDiagPart (V1) takes only the `input` matrix.
  static constexpr int kNumV1Inputs = 1;
};
442
// V2/V3 extra inputs ("k", "padding_value") are registered as compile-time
// constants so the band and the output shape are known at compile time.
REGISTER_XLA_OP(Name("MatrixDiagPart"), MatrixDiagPartOp);
REGISTER_XLA_OP(Name("MatrixDiagPartV2")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagPartOp);
REGISTER_XLA_OP(Name("MatrixDiagPartV3")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagPartOp);
452
// XLA kernel for MatrixSetDiag / V2 / V3: overwrites the specified diagonal
// band of a (batched) matrix with values supplied in `diag`, leaving all
// other entries of the input untouched.
class MatrixSetDiagOp : public XlaOpKernel {
 public:
  explicit MatrixSetDiagOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {
    // MatrixSetDiagV3-specific.
    if (context->HasAttr("align")) {
      ReadAlignment(context, &left_align_superdiagonal_,
                    &left_align_subdiagonal_);
    }
  }

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    const TensorShape diag_shape = context->InputShape(1);
    const int input_rank = input_shape.dims();
    const int diag_rank = diag_shape.dims();

    // Preliminary validation of sizes.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input_shape.DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape),
                errors::InvalidArgument(
                    "diagonal must be at least 1-dim, received shape: ",
                    diag_shape.DebugString()));

    // MatrixSetDiag and MatrixSetDiagV2 both use this OpKernel. MatrixSetDiag
    // only has two inputs, so we have to check the number of inputs before
    // reading additional parameters in MatrixSetDiagV2.
    int64 lower_diag_index = 0;
    int64 upper_diag_index = 0;
    if (context->num_inputs() > kNumV1Inputs) {
      std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
    }

    // Checks if diag sizes are consistent with input.
    const int64 num_rows = input_shape.dim_size(input_rank - 2);
    const int64 num_cols = input_shape.dim_size(input_rank - 1);
    ValidateDiagIndexWithOutputMatrixSize(context, lower_diag_index,
                                          upper_diag_index, num_rows, num_cols);
    const Eigen::Index num_diags = upper_diag_index - lower_diag_index + 1;
    OP_REQUIRES(
        context,
        lower_diag_index == upper_diag_index ||
            (diag_shape.dim_size(input_rank - 2) == num_diags),
        errors::InvalidArgument("The number of diagonals provided in `diag` "
                                "is not consistent with `lower_diag_index` and "
                                "`upper_diag_index`"));

    // `diag` must exactly match the shape SetMatrixDiag expects: input batch
    // dims (+ [num_diags] for a band) + [max_diag_len].
    TensorShape expected_diag_shape = input_shape;
    expected_diag_shape.RemoveLastDims(2);
    if (num_diags > 1) expected_diag_shape.AddDim(num_diags);
    const int32 max_diag_len =
        std::min(num_rows + std::min(upper_diag_index, int64{0}),
                 num_cols - std::max(lower_diag_index, int64{0}));
    expected_diag_shape.AddDim(max_diag_len);
    OP_REQUIRES(
        context, expected_diag_shape == diag_shape,
        errors::InvalidArgument(
            "Either first dimensions of diagonal don't match input.shape[:-2], "
            "or diagonal.shape[:-1] is not equal to the longests diagonal in "
            "range [lower_diag_index:upper_diag_index].\nInput shape: ",
            input_shape.DebugString(),
            "\nDiagonal shape: ", diag_shape.DebugString(),
            "\nExpected diagonal shape: ", expected_diag_shape.DebugString()));

    // Actual processing.
    xla::XlaOp input = context->Input(0);
    xla::XlaOp diag = context->Input(1);
    context->SetOutput(
        0, SetMatrixDiag(input, diag, input_shape, diag_rank, num_diags,
                         lower_diag_index, upper_diag_index, max_diag_len,
                         num_rows, num_cols, left_align_superdiagonal_,
                         left_align_subdiagonal_));
  }

 private:
  // Packing alignment for short diagonals (V3 "align" attr; left/left
  // defaults for V1/V2).
  bool left_align_superdiagonal_ = true;
  bool left_align_subdiagonal_ = true;
  // MatrixSetDiag (V1) takes `input` and `diagonal`.
  static constexpr int kNumV1Inputs = 2;
  TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp);
};
536
// "k" (V2/V3) is registered as a compile-time constant so the diagonal band
// is known when the XLA computation is built.
REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp);
REGISTER_XLA_OP(Name("MatrixSetDiagV2").CompileTimeConstantInput("k"),
                MatrixSetDiagOp);
REGISTER_XLA_OP(Name("MatrixSetDiagV3").CompileTimeConstantInput("k"),
                MatrixSetDiagOp);
542
543 } // namespace tensorflow
544