/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.
#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/linalg/matrix_diag_op.h"

#include <algorithm>
#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
class MatrixDiagPartOp : public OpKernel {
 public:
  explicit MatrixDiagPartOp(OpKernelConstruction* context)
      : OpKernel(context) {
    // MatrixDiagPartV3-specific.
    if (context->HasAttr("align")) {
      functor::ReadAlignment(context, &left_align_superdiagonal_,
                             &left_align_subdiagonal_);
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);

    // MatrixDiagPart and MatrixDiagPartV2 both use this OpKernel.
    // MatrixDiagPart only has one input, so we have to check the number of
    // inputs before reading additional parameters in MatrixDiagPartV2.
    int32 lower_diag_index = 0;
    int32 upper_diag_index = 0;
    T padding_value(0);

    // MatrixDiagPartV2-specific.
    if (context->num_inputs() > kNumV1Inputs) {
      auto& diag_index = context->input(1);
      OP_REQUIRES(context,
                  TensorShapeUtils::IsScalar(diag_index.shape()) ||
                      TensorShapeUtils::IsVector(diag_index.shape()),
                  errors::InvalidArgument(
                      "diag_index must be a scalar or vector, received shape: ",
                      diag_index.shape().DebugString()));
      lower_diag_index = diag_index.flat<int32>()(0);
      upper_diag_index = lower_diag_index;
      if (TensorShapeUtils::IsVector(diag_index.shape())) {
        auto diag_index_size = diag_index.dim_size(0);
        OP_REQUIRES(
            context, 0 < diag_index_size && diag_index_size <= 2,
            errors::InvalidArgument(
                "diag_index must have only one or two elements, received ",
                diag_index_size, " elements."));
        if (diag_index_size > 1) {
          upper_diag_index = diag_index.flat<int32>()(1);
        }
      }
      padding_value = context->input(2).flat<T>()(0);
    }
    const TensorShape& input_shape = input.shape();

    // Preliminary validation of sizes.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input.shape().DebugString()));

    // Make sure lower_diag_index and upper_diag_index are valid.
    const int rank = input_shape.dims();
    const Eigen::Index num_rows = input_shape.dim_size(rank - 2);
    const Eigen::Index num_cols = input_shape.dim_size(rank - 1);
    OP_REQUIRES(  // Checks lower_diag_index == 0 for when matrix shape = 0.
        context,
        (-num_rows < lower_diag_index && lower_diag_index < num_cols) ||
            lower_diag_index == 0,
        errors::InvalidArgument(
            "lower_diag_index is out of bound: ", lower_diag_index,
            ". It must be between ", -num_rows, " and ", num_cols));
    OP_REQUIRES(
        context,
        (-num_rows < upper_diag_index && upper_diag_index < num_cols) ||
            upper_diag_index == 0,
        errors::InvalidArgument(
            "upper_diag_index is out of bound: ", upper_diag_index,
            ". It must be between ", -num_rows, " and ", num_cols));
    OP_REQUIRES(
        context, lower_diag_index <= upper_diag_index,
        errors::InvalidArgument(
            "lower_diag_index must not be larger than upper_diag_index: ",
            lower_diag_index, " > ", upper_diag_index));

    TensorShape output_shape;
    for (int i = 0; i < rank - 2; ++i) {
      output_shape.AddDim(input_shape.dim_size(i));
    }
    const Eigen::Index num_diags = upper_diag_index - lower_diag_index + 1;
    if (num_diags > 1) output_shape.AddDim(num_diags);
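    // max_diag_len is the length of the longest diagonal in the band
    // [lower_diag_index, upper_diag_index]. E.g., for a 4x5 input with the
    // band k = [-1, 2], max_diag_len = min(4 + 0, 5 - 0) = 4.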
    const int32 max_diag_len =
        std::min(num_rows + std::min(upper_diag_index, 0),
                 num_cols - std::max(lower_diag_index, 0));
    output_shape.AddDim(max_diag_len);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &output));
    auto output_reshaped = output->flat<T>();
    auto input_reshaped = input.flat_inner_dims<T, 3>();
    functor::MatrixDiagPart<Device, T>::Compute(
        context, context->eigen_device<Device>(), input_reshaped,
        output_reshaped, lower_diag_index, upper_diag_index, max_diag_len,
        padding_value, left_align_superdiagonal_, left_align_subdiagonal_);
  }

 private:
  bool left_align_superdiagonal_ = true;
  bool left_align_subdiagonal_ = true;
  static constexpr int kNumV1Inputs = 1;
  TF_DISALLOW_COPY_AND_ASSIGN(MatrixDiagPartOp);
};

template <typename Device, typename T>
class MatrixDiagOp : public OpKernel {
 public:
  explicit MatrixDiagOp(OpKernelConstruction* context) : OpKernel(context) {
    // MatrixDiagV3-specific.
    if (context->HasAttr("align")) {
      functor::ReadAlignment(context, &left_align_superdiagonal_,
                             &left_align_subdiagonal_);
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& diagonal = context->input(0);

    // MatrixDiag and MatrixDiagV2 both use this OpKernel. MatrixDiag only has
    // one input, so we have to check the number of inputs before reading
    // additional parameters in MatrixDiagV2.
    int32 lower_diag_index = 0;
    int32 upper_diag_index = 0;
    int32 num_rows = -1;
    int32 num_cols = -1;
    T padding_value(0);

    // MatrixDiagOpV2-specific.
    if (context->num_inputs() > kNumV1Inputs) {
      auto& diag_index = context->input(1);
      OP_REQUIRES(context,
                  TensorShapeUtils::IsScalar(diag_index.shape()) ||
                      TensorShapeUtils::IsVector(diag_index.shape()),
                  errors::InvalidArgument(
                      "diag_index must be a scalar or vector, received shape: ",
                      diag_index.shape().DebugString()));
      lower_diag_index = diag_index.flat<int32>()(0);
      upper_diag_index = lower_diag_index;
      if (TensorShapeUtils::IsVector(diag_index.shape())) {
        auto diag_index_size = diag_index.dim_size(0);
        OP_REQUIRES(
            context, 0 < diag_index_size && diag_index_size <= 2,
            errors::InvalidArgument(
                "diag_index must have only one or two elements, received ",
                diag_index_size, " elements."));
        if (diag_index_size > 1) {
          upper_diag_index = diag_index.flat<int32>()(1);
        }
      }
      num_rows = context->input(2).flat<int32>()(0);
      num_cols = context->input(3).flat<int32>()(0);
      padding_value = context->input(4).flat<T>()(0);
    }

    // Size validations.
    const TensorShape& diagonal_shape = diagonal.shape();
    const int diag_rank = diagonal_shape.dims();
    const Eigen::Index num_diags = upper_diag_index - lower_diag_index + 1;
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diagonal_shape),
                errors::InvalidArgument(
                    "diagonal must be at least 1-dim, received shape: ",
                    diagonal.shape().DebugString()));
    OP_REQUIRES(
        context, lower_diag_index <= upper_diag_index,
        errors::InvalidArgument(
            "lower_diag_index must not be larger than upper_diag_index: ",
            lower_diag_index, " > ", upper_diag_index));
    OP_REQUIRES(context,
                lower_diag_index == upper_diag_index ||
                    diagonal_shape.dim_size(diag_rank - 2) == num_diags,
                errors::InvalidArgument(
                    "The number of diagonals provided in the input does not "
                    "match the lower_diag_index and upper_diag_index range."));

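    // The longest provided diagonal (the one in the band closest to the main
    // diagonal) must fit in the output, which needs at least
    // max_diag_len - min(upper_diag_index, 0) rows and
    // max_diag_len + max(lower_diag_index, 0) columns. E.g., diagonals of
    // length 3 with the band k = [-2, -1] need at least a 4x3 output.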
    const Eigen::Index max_diag_len = diagonal_shape.dim_size(diag_rank - 1);
    const int32 min_num_rows = max_diag_len - std::min(upper_diag_index, 0);
    const int32 min_num_cols = max_diag_len + std::max(lower_diag_index, 0);
    OP_REQUIRES(context, num_rows == -1 || num_rows >= min_num_rows,
                errors::InvalidArgument("The number of rows is too small."));
    OP_REQUIRES(context, num_cols == -1 || num_cols >= min_num_cols,
                errors::InvalidArgument("The number of columns is too small."));

    // If both num_rows and num_cols are unknown, assume that output is square.
    // Otherwise, use smallest possible values.
    if (num_rows == -1 && num_cols == -1) {
      num_rows = std::max(min_num_rows, min_num_cols);
      num_cols = num_rows;
    } else if (num_rows == -1) {
      num_rows = min_num_rows;
    } else if (num_cols == -1) {
      num_cols = min_num_cols;
    }
    OP_REQUIRES(context, num_rows == min_num_rows || num_cols == min_num_cols,
                errors::InvalidArgument(
                    "The number of rows or columns is not consistent with "
                    "the specified d_lower, d_upper, and diagonal."));

    TensorShape output_shape = diagonal_shape;
    if (num_diags == 1) {  // Output has rank `rank+1`.
      output_shape.set_dim(diag_rank - 1, num_rows);
      output_shape.AddDim(num_cols);
    } else {  // Output has rank `rank`.
      output_shape.set_dim(diag_rank - 2, num_rows);
      output_shape.set_dim(diag_rank - 1, num_cols);
    }

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &output));
    auto output_reshaped = output->flat_inner_dims<T, 3>();
    auto diag_reshaped = diagonal.flat<T>();
    functor::MatrixDiag<Device, T>::Compute(
        context, context->eigen_device<Device>(), diag_reshaped,
        output_reshaped, lower_diag_index, upper_diag_index, max_diag_len,
        padding_value, left_align_superdiagonal_, left_align_subdiagonal_);
  }

 private:
  bool left_align_superdiagonal_ = true;
  bool left_align_subdiagonal_ = true;
  static constexpr int kNumV1Inputs = 1;
  TF_DISALLOW_COPY_AND_ASSIGN(MatrixDiagOp);
};

#define REGISTER_MATRIX_DIAG(type) \
  REGISTER_KERNEL_BUILDER( \
      Name("MatrixDiag").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      MatrixDiagOp<CPUDevice, type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("MatrixDiagV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      MatrixDiagOp<CPUDevice, type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("MatrixDiagV3").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      MatrixDiagOp<CPUDevice, type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("MatrixDiagPart").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      MatrixDiagPartOp<CPUDevice, type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("MatrixDiagPartV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      MatrixDiagPartOp<CPUDevice, type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("MatrixDiagPartV3").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      MatrixDiagPartOp<CPUDevice, type>);

TF_CALL_POD_TYPES(REGISTER_MATRIX_DIAG);
#undef REGISTER_MATRIX_DIAG

// Registration of the deprecated kernel.
// Delete after 10mar2017.
#define REGISTER_BATCH_MATRIX_DIAG(type) \
  REGISTER_KERNEL_BUILDER( \
      Name("BatchMatrixDiag").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      MatrixDiagOp<CPUDevice, type>); \
  REGISTER_KERNEL_BUILDER(Name("BatchMatrixDiagPart") \
                              .Device(DEVICE_CPU) \
                              .TypeConstraint<type>("T"), \
                          MatrixDiagPartOp<CPUDevice, type>);
TF_CALL_POD_TYPES(REGISTER_BATCH_MATRIX_DIAG);
#undef REGISTER_BATCH_MATRIX_DIAG

// Implementation of the functor specialization for CPU.
namespace functor {

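// Reads the "align" attribute (e.g. "LEFT_RIGHT" or "RIGHT_LEFT"). The first
// token controls whether superdiagonals are left-aligned and the second
// whether subdiagonals are left-aligned; any other token right-aligns the
// corresponding diagonals. Left-aligned diagonals are padded on the right,
// right-aligned ones on the left.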
void ReadAlignment(OpKernelConstruction* context,
                   bool* left_align_superdiagonal,
                   bool* left_align_subdiagonal) {
  string align;
  OP_REQUIRES_OK(context, context->GetAttr("align", &align));

  *left_align_superdiagonal = align == "LEFT_LEFT" || align == "LEFT_RIGHT";
  *left_align_subdiagonal = align == "LEFT_LEFT" || align == "RIGHT_LEFT";
}

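// Returns the length of diagonal `diag_index` in a num_rows x num_cols
// matrix, and the offset of its first element within a packed row of
// max_diag_len entries: 0 if the diagonal is left-aligned, and
// max_diag_len - diag_len if it is right-aligned. E.g., diag_index = 1 in a
// 4x4 matrix with max_diag_len = 4 has diag_len = 3 and content_offset = 1
// when right-aligned.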
std::pair<int, int> ComputeDiagLenAndContentOffset(
    int diag_index, int max_diag_len, int num_rows, int num_cols,
    bool left_align_superdiagonal, bool left_align_subdiagonal) {
  const bool left_align = (diag_index >= 0 && left_align_superdiagonal) ||
                          (diag_index <= 0 && left_align_subdiagonal);
  const int diag_len = std::min(num_rows + std::min(0, diag_index),
                                num_cols - std::max(0, diag_index));
  const int content_offset = (left_align) ? 0 : (max_diag_len - diag_len);
  return {diag_len, content_offset};
}

template <typename T>
struct MatrixDiag<CPUDevice, T> {
  static void Compute(OpKernelContext* context, const CPUDevice& device,
                      typename TTypes<T>::ConstTensor& diag,
                      typename TTypes<T, 3>::Tensor& output,
                      const Eigen::Index lower_diag_index,
                      const Eigen::Index upper_diag_index,
                      const Eigen::Index max_diag_len, const T padding_value,
                      const bool left_align_superdiagonal,
                      const bool left_align_subdiagonal) {
    // The factor of 10 in cost_per_batch comes from an existing heuristic.
    // TODO(penporn): Tune for the best constant in cost_per_batch.
    const Eigen::Index num_batches = output.dimension(0);
    const Eigen::Index num_rows = output.dimension(1);
    const Eigen::Index num_cols = output.dimension(2);
    const Eigen::Index cost_per_batch = 10 * num_rows * num_cols;

    auto compute_shard = [&output, &num_rows, &num_cols, &diag,
                          &lower_diag_index, &upper_diag_index, &max_diag_len,
                          &padding_value, &left_align_superdiagonal,
                          &left_align_subdiagonal](Eigen::Index begin,
                                                   Eigen::Index end) {
      const int num_diags = upper_diag_index - lower_diag_index + 1;
      const int diag_elements_in_batch = num_diags * max_diag_len;
      Eigen::Index diag_batch_base_index = begin * diag_elements_in_batch;
      for (Eigen::Index batch = begin; batch < end; ++batch) {
        for (Eigen::Index i = 0; i < output.dimension(1); ++i) {
          for (Eigen::Index j = 0; j < output.dimension(2); ++j) {
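            // Element (i, j) lies on diagonal j - i. Diagonals are packed in
            // the input with the uppermost diagonal first, so this element
            // comes from packed row upper_diag_index - diag_index.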
            const int diag_index = j - i;
            const int diag_index_in_input = upper_diag_index - diag_index;
            int diag_len, content_offset;
            std::tie(diag_len, content_offset) = ComputeDiagLenAndContentOffset(
                diag_index, max_diag_len, num_rows, num_cols,
                left_align_superdiagonal, left_align_subdiagonal);
            const int index_in_the_diagonal =
                j - std::max<Eigen::Index>(diag_index, 0) + content_offset;
            if (lower_diag_index <= diag_index &&
                diag_index <= upper_diag_index) {
              output(batch, i, j) = diag(diag_batch_base_index +
                                         diag_index_in_input * max_diag_len +
                                         index_in_the_diagonal);
            } else {
              output(batch, i, j) = padding_value;
            }
          }
        }
        diag_batch_base_index += diag_elements_in_batch;
      }
    };
    auto thread_pool =
        context->device()->tensorflow_cpu_worker_threads()->workers;
    thread_pool->ParallelFor(num_batches, cost_per_batch,
                             std::move(compute_shard));
  }
};

template <typename T>
struct MatrixDiagPart<CPUDevice, T> {
  static void Compute(OpKernelContext* context, const CPUDevice& device,
                      typename TTypes<T, 3>::ConstTensor& input,
                      typename TTypes<T>::Tensor& output,
                      const Eigen::Index lower_diag_index,
                      const Eigen::Index upper_diag_index,
                      const Eigen::Index max_diag_len, const T padding_value,
                      const bool left_align_superdiagonal,
                      const bool left_align_subdiagonal) {
    // The factor of 10 in cost_per_batch comes from an existing heuristic.
    // TODO(penporn): Tune for the best constant in cost_per_batch.
    const Eigen::Index num_diags = upper_diag_index - lower_diag_index + 1;
    const Eigen::Index output_elements_in_batch = num_diags * max_diag_len;
    const Eigen::Index cost_per_batch = 10 * output_elements_in_batch;
    const Eigen::Index num_batches = input.dimension(0);
    const Eigen::Index num_rows = input.dimension(1);
    const Eigen::Index num_cols = input.dimension(2);

    auto compute_shard = [&output, &input, &num_rows, &num_cols,
                          &upper_diag_index, &max_diag_len, &num_diags,
                          &output_elements_in_batch, &padding_value,
                          &left_align_superdiagonal, &left_align_subdiagonal](
                             Eigen::Index begin, Eigen::Index end) {
      Eigen::Index output_base_index = begin * output_elements_in_batch;
      for (Eigen::Index batch = begin; batch < end; ++batch) {
        for (Eigen::Index m = 0; m < num_diags; ++m) {
          const Eigen::Index diag_index = upper_diag_index - m;
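          // (y_offset, x_offset) is the position of the first element of
          // diagonal `diag_index` in the input matrix: superdiagonals start
          // in row 0, subdiagonals in column 0.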
          Eigen::Index y_offset = std::max<Eigen::Index>(0, -diag_index);
          Eigen::Index x_offset = std::max<Eigen::Index>(0, diag_index);
          int diag_len, content_offset;
          std::tie(diag_len, content_offset) = ComputeDiagLenAndContentOffset(
              diag_index, max_diag_len, num_rows, num_cols,
              left_align_superdiagonal, left_align_subdiagonal);

          // Fills the diagonal.
          for (Eigen::Index n = 0; n < diag_len; ++n) {
            output(output_base_index + content_offset + n) =
                input(batch, n + y_offset, n + x_offset);
          }

          // Padding.
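          // Left-aligned diagonals are padded on the right (entries
          // [diag_len, max_diag_len)); right-aligned diagonals are padded on
          // the left (entries [0, content_offset)).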
          const bool left_align = (content_offset == 0);
          const Eigen::Index padding_start = (left_align) ? diag_len : 0;
          const Eigen::Index padding_end =
              (left_align) ? max_diag_len : content_offset;
          for (Eigen::Index n = padding_start; n < padding_end; ++n) {
            output(output_base_index + n) = padding_value;
          }
          output_base_index += max_diag_len;
        }
      }
    };
    auto thread_pool =
        context->device()->tensorflow_cpu_worker_threads()->workers;
    thread_pool->ParallelFor(num_batches, cost_per_batch,
                             std::move(compute_shard));
  }
};

}  // namespace functor

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T) \
  template <> \
  void MatrixDiag<GPUDevice, T>::Compute( \
      OpKernelContext* context, const GPUDevice& device, \
      typename TTypes<T>::ConstTensor& diag, \
      typename TTypes<T, 3>::Tensor& output, \
      const Eigen::Index lower_diag_index, \
      const Eigen::Index upper_diag_index, const Eigen::Index max_diag_len, \
      const T padding_value, const bool left_align_superdiagonal, \
      const bool left_align_subdiagonal); \
  extern template struct MatrixDiag<GPUDevice, T>; \
  template <> \
  void MatrixDiagPart<GPUDevice, T>::Compute( \
      OpKernelContext* context, const GPUDevice& device, \
      typename TTypes<T, 3>::ConstTensor& input, \
      typename TTypes<T>::Tensor& output, const Eigen::Index lower_diag_index, \
      const Eigen::Index upper_diag_index, const Eigen::Index max_diag_len, \
      const T padding_value, const bool left_align_superdiagonal, \
      const bool left_align_subdiagonal); \
  extern template struct MatrixDiagPart<GPUDevice, T>;

TF_CALL_GPU_ALL_TYPES(DECLARE_GPU_SPEC);

}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_MATRIX_DIAG_GPU(type) \
  REGISTER_KERNEL_BUILDER( \
      Name("MatrixDiag").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      MatrixDiagOp<GPUDevice, type>); \
  REGISTER_KERNEL_BUILDER(Name("MatrixDiagV2") \
                              .Device(DEVICE_GPU) \
                              .TypeConstraint<type>("T") \
                              .HostMemory("k") \
                              .HostMemory("num_rows") \
                              .HostMemory("num_cols") \
                              .HostMemory("padding_value"), \
                          MatrixDiagOp<GPUDevice, type>); \
  REGISTER_KERNEL_BUILDER(Name("MatrixDiagV3") \
                              .Device(DEVICE_GPU) \
                              .TypeConstraint<type>("T") \
                              .HostMemory("k") \
                              .HostMemory("num_rows") \
                              .HostMemory("num_cols") \
                              .HostMemory("padding_value"), \
                          MatrixDiagOp<GPUDevice, type>); \
  REGISTER_KERNEL_BUILDER( \
      Name("MatrixDiagPart").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      MatrixDiagPartOp<GPUDevice, type>); \
  REGISTER_KERNEL_BUILDER(Name("MatrixDiagPartV2") \
                              .Device(DEVICE_GPU) \
                              .TypeConstraint<type>("T") \
                              .HostMemory("k") \
                              .HostMemory("padding_value"), \
                          MatrixDiagPartOp<GPUDevice, type>); \
  REGISTER_KERNEL_BUILDER(Name("MatrixDiagPartV3") \
                              .Device(DEVICE_GPU) \
                              .TypeConstraint<type>("T") \
                              .HostMemory("k") \
                              .HostMemory("padding_value"), \
                          MatrixDiagPartOp<GPUDevice, type>);

TF_CALL_GPU_ALL_TYPES(REGISTER_MATRIX_DIAG_GPU);
#undef REGISTER_MATRIX_DIAG_GPU

// Registration of the deprecated kernel.
// Delete after 10mar2017.
#define REGISTER_BATCH_MATRIX_DIAG_GPU(type) \
  REGISTER_KERNEL_BUILDER( \
      Name("BatchMatrixDiag").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      MatrixDiagOp<GPUDevice, type>); \
  REGISTER_KERNEL_BUILDER(Name("BatchMatrixDiagPart") \
                              .Device(DEVICE_GPU) \
                              .TypeConstraint<type>("T"), \
                          MatrixDiagPartOp<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_BATCH_MATRIX_DIAG_GPU);
#undef REGISTER_BATCH_MATRIX_DIAG_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow