/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.
#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/linalg/matrix_diag_op.h"

#include <algorithm>
#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

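// Extracts the band of diagonals [lower_diag_index, upper_diag_index] from a
// (batched) matrix, padding shorter diagonals with padding_value. For example,
// with the default k = 0, MatrixDiagPart([[1, 2], [3, 4]]) yields [1, 4].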
template <typename Device, typename T>
class MatrixDiagPartOp : public OpKernel {
 public:
  explicit MatrixDiagPartOp(OpKernelConstruction* context)
      : OpKernel(context) {
    // MatrixDiagPartV3-specific.
    if (context->HasAttr("align")) {
      functor::ReadAlignment(context, &left_align_superdiagonal_,
                             &left_align_subdiagonal_);
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);

    // MatrixDiagPart and MatrixDiagPartV2 both use this OpKernel.
    // MatrixDiagPart only has one input, so we have to check the number of
    // inputs before reading additional parameters in MatrixDiagPartV2.
    int32 lower_diag_index = 0;
    int32 upper_diag_index = 0;
    T padding_value(0);

    // MatrixDiagPartV2-specific.
    if (context->num_inputs() > kNumV1Inputs) {
      auto& diag_index = context->input(1);
      OP_REQUIRES(context,
                  TensorShapeUtils::IsScalar(diag_index.shape()) ||
                      TensorShapeUtils::IsVector(diag_index.shape()),
                  errors::InvalidArgument(
                      "diag_index must be a scalar or vector, received shape: ",
                      diag_index.shape().DebugString()));
      lower_diag_index = diag_index.flat<int32>()(0);
      upper_diag_index = lower_diag_index;
      if (TensorShapeUtils::IsVector(diag_index.shape())) {
        auto diag_index_size = diag_index.dim_size(0);
        OP_REQUIRES(
            context, 0 < diag_index_size && diag_index_size <= 2,
            errors::InvalidArgument(
                "diag_index must have only one or two elements, received ",
                diag_index_size, " elements."));
        if (diag_index_size > 1) {
          upper_diag_index = diag_index.flat<int32>()(1);
        }
      }
      padding_value = context->input(2).flat<T>()(0);
    }
    const TensorShape& input_shape = input.shape();

    // Preliminary validation of sizes.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input.shape().DebugString()));

    // Make sure lower_diag_index and upper_diag_index are valid.
    const int rank = input_shape.dims();
    const Eigen::Index num_rows = input_shape.dim_size(rank - 2);
    const Eigen::Index num_cols = input_shape.dim_size(rank - 1);
    OP_REQUIRES(  // Checks lower_diag_index == 0 for when matrix shape = 0.
        context,
        (-num_rows < lower_diag_index && lower_diag_index < num_cols) ||
            lower_diag_index == 0,
        errors::InvalidArgument(
            "lower_diag_index is out of bound: ", lower_diag_index,
            ". It must be between ", -num_rows, " and ", num_cols));
    OP_REQUIRES(context,
                (-num_rows < upper_diag_index && upper_diag_index < num_cols) ||
                    upper_diag_index == 0,
                errors::InvalidArgument(
                    "upper_diag_index is out of bound: ", upper_diag_index,
                    ". It must be between ", -num_rows, " and ", num_cols));
    OP_REQUIRES(
        context, lower_diag_index <= upper_diag_index,
        errors::InvalidArgument(
            "lower_diag_index must not be larger than upper_diag_index: ",
            lower_diag_index, " > ", upper_diag_index));

    TensorShape output_shape;
    for (int i = 0; i < rank - 2; ++i) {
      output_shape.AddDim(input_shape.dim_size(i));
    }
    const Eigen::Index num_diags = upper_diag_index - lower_diag_index + 1;
    if (num_diags > 1) output_shape.AddDim(num_diags);
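    // Each diagonal k has min(num_rows + min(k, 0), num_cols - max(k, 0))
    // elements, so the longest diagonal in the requested band is the one
    // closest to the main diagonal.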
    const int32 max_diag_len =
        std::min(num_rows + std::min(upper_diag_index, 0),
                 num_cols - std::max(lower_diag_index, 0));
    output_shape.AddDim(max_diag_len);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    auto output_reshaped = output->flat<T>();
    auto input_reshaped = input.flat_inner_dims<T, 3>();
    functor::MatrixDiagPart<Device, T>::Compute(
        context, context->eigen_device<Device>(), input_reshaped,
        output_reshaped, lower_diag_index, upper_diag_index, max_diag_len,
        padding_value, left_align_superdiagonal_, left_align_subdiagonal_);
  }

 private:
  bool left_align_superdiagonal_ = true;
  bool left_align_subdiagonal_ = true;
  static constexpr int kNumV1Inputs = 1;
  TF_DISALLOW_COPY_AND_ASSIGN(MatrixDiagPartOp);
};

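// Assembles (batched) matrices whose band [lower_diag_index, upper_diag_index]
// is taken from the input diagonals and whose remaining entries are
// padding_value. For example, with the default k = 0, MatrixDiag([1, 2])
// yields [[1, 0], [0, 2]].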
template <typename Device, typename T>
class MatrixDiagOp : public OpKernel {
 public:
  explicit MatrixDiagOp(OpKernelConstruction* context) : OpKernel(context) {
    // MatrixDiagV3-specific.
    if (context->HasAttr("align")) {
      functor::ReadAlignment(context, &left_align_superdiagonal_,
                             &left_align_subdiagonal_);
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& diagonal = context->input(0);

    // MatrixDiag and MatrixDiagV2 both use this OpKernel. MatrixDiag only has
    // one input, so we have to check the number of inputs before reading
    // additional parameters in MatrixDiagV2.
    int32 lower_diag_index = 0;
    int32 upper_diag_index = 0;
    int32 num_rows = -1;
    int32 num_cols = -1;
    T padding_value(0);

    // MatrixDiagV2-specific.
    if (context->num_inputs() > kNumV1Inputs) {
      auto& diag_index = context->input(1);
      OP_REQUIRES(context,
                  TensorShapeUtils::IsScalar(diag_index.shape()) ||
                      TensorShapeUtils::IsVector(diag_index.shape()),
                  errors::InvalidArgument(
                      "diag_index must be a scalar or vector, received shape: ",
                      diag_index.shape().DebugString()));
      lower_diag_index = diag_index.flat<int32>()(0);
      upper_diag_index = lower_diag_index;
      if (TensorShapeUtils::IsVector(diag_index.shape())) {
        auto diag_index_size = diag_index.dim_size(0);
        OP_REQUIRES(
            context, 0 < diag_index_size && diag_index_size <= 2,
            errors::InvalidArgument(
                "diag_index must have only one or two elements, received ",
                diag_index_size, " elements."));
        if (diag_index_size > 1) {
          upper_diag_index = diag_index.flat<int32>()(1);
        }
      }
      num_rows = context->input(2).flat<int32>()(0);
      num_cols = context->input(3).flat<int32>()(0);
      padding_value = context->input(4).flat<T>()(0);
    }

    // Size validations.
    const TensorShape& diagonal_shape = diagonal.shape();
    const int diag_rank = diagonal_shape.dims();
    const Eigen::Index num_diags = upper_diag_index - lower_diag_index + 1;
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diagonal_shape),
                errors::InvalidArgument(
                    "diagonal must be at least 1-dim, received shape: ",
                    diagonal.shape().DebugString()));
    OP_REQUIRES(
        context, lower_diag_index <= upper_diag_index,
        errors::InvalidArgument(
            "lower_diag_index must not be larger than upper_diag_index: ",
            lower_diag_index, " > ", upper_diag_index));
    OP_REQUIRES(context,
                lower_diag_index == upper_diag_index ||
                    diagonal_shape.dim_size(diag_rank - 2) == num_diags,
                errors::InvalidArgument(
                    "The number of diagonals provided in the input does not "
                    "match the lower_diag_index and upper_diag_index range."));

    const Eigen::Index max_diag_len = diagonal_shape.dim_size(diag_rank - 1);
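    // Diagonals of length max_diag_len fit in a num_rows x num_cols matrix
    // only if its longest diagonal in the band,
    // min(num_rows + min(upper_diag_index, 0),
    //     num_cols - max(lower_diag_index, 0)),
    // is at least max_diag_len, which gives the minimum sizes below.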
    const int32 min_num_rows = max_diag_len - std::min(upper_diag_index, 0);
    const int32 min_num_cols = max_diag_len + std::max(lower_diag_index, 0);
    OP_REQUIRES(context, num_rows == -1 || num_rows >= min_num_rows,
                errors::InvalidArgument("The number of rows is too small."));
    OP_REQUIRES(context, num_cols == -1 || num_cols >= min_num_cols,
                errors::InvalidArgument("The number of columns is too small."));

    // If both num_rows and num_cols are unknown, assume that output is square.
    // Otherwise, use smallest possible values.
    if (num_rows == -1 && num_cols == -1) {
      num_rows = std::max(min_num_rows, min_num_cols);
      num_cols = num_rows;
    } else if (num_rows == -1) {
      num_rows = min_num_rows;
    } else if (num_cols == -1) {
      num_cols = min_num_cols;
    }
    OP_REQUIRES(context, num_rows == min_num_rows || num_cols == min_num_cols,
                errors::InvalidArgument(
                    "The number of rows or columns is not consistent with "
                    "the specified d_lower, d_upper, and diagonal."));

    TensorShape output_shape = diagonal_shape;
    if (num_diags == 1) {  // Output has rank `rank+1`.
      output_shape.set_dim(diag_rank - 1, num_rows);
      output_shape.AddDim(num_cols);
    } else {  // Output has rank `rank`.
      output_shape.set_dim(diag_rank - 2, num_rows);
      output_shape.set_dim(diag_rank - 1, num_cols);
    }

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    auto output_reshaped = output->flat_inner_dims<T, 3>();
    auto diag_reshaped = diagonal.flat<T>();
    functor::MatrixDiag<Device, T>::Compute(
        context, context->eigen_device<Device>(), diag_reshaped,
        output_reshaped, lower_diag_index, upper_diag_index, max_diag_len,
        padding_value, left_align_superdiagonal_, left_align_subdiagonal_);
  }

 private:
  bool left_align_superdiagonal_ = true;
  bool left_align_subdiagonal_ = true;
  static constexpr int kNumV1Inputs = 1;
  TF_DISALLOW_COPY_AND_ASSIGN(MatrixDiagOp);
};

#define REGISTER_MATRIX_DIAG(type)                                           \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("MatrixDiag").Device(DEVICE_CPU).TypeConstraint<type>("T"),       \
      MatrixDiagOp<CPUDevice, type>);                                        \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("MatrixDiagV2").Device(DEVICE_CPU).TypeConstraint<type>("T"),     \
      MatrixDiagOp<CPUDevice, type>);                                        \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("MatrixDiagV3").Device(DEVICE_CPU).TypeConstraint<type>("T"),     \
      MatrixDiagOp<CPUDevice, type>);                                        \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("MatrixDiagPart").Device(DEVICE_CPU).TypeConstraint<type>("T"),   \
      MatrixDiagPartOp<CPUDevice, type>);                                    \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("MatrixDiagPartV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      MatrixDiagPartOp<CPUDevice, type>);                                    \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("MatrixDiagPartV3").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      MatrixDiagPartOp<CPUDevice, type>);

TF_CALL_POD_TYPES(REGISTER_MATRIX_DIAG);
#undef REGISTER_MATRIX_DIAG

// Registration of the deprecated kernel.
// Delete after 10mar2017.
#define REGISTER_BATCH_MATRIX_DIAG(type)                                    \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("BatchMatrixDiag").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      MatrixDiagOp<CPUDevice, type>);                                       \
  REGISTER_KERNEL_BUILDER(Name("BatchMatrixDiagPart")                       \
                              .Device(DEVICE_CPU)                           \
                              .TypeConstraint<type>("T"),                   \
                          MatrixDiagPartOp<CPUDevice, type>);
TF_CALL_POD_TYPES(REGISTER_BATCH_MATRIX_DIAG);
#undef REGISTER_BATCH_MATRIX_DIAG

// Implementation of the functor specialization for CPU.
namespace functor {

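// Reads the "align" attribute: the first half selects the alignment of
// superdiagonals and the second half that of subdiagonals, so "LEFT_RIGHT"
// left-aligns superdiagonals and right-aligns subdiagonals.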
void ReadAlignment(OpKernelConstruction* context,
                   bool* left_align_superdiagonal,
                   bool* left_align_subdiagonal) {
  string align;
  OP_REQUIRES_OK(context, context->GetAttr("align", &align));

  *left_align_superdiagonal = align == "LEFT_LEFT" || align == "LEFT_RIGHT";
  *left_align_subdiagonal = align == "LEFT_LEFT" || align == "RIGHT_LEFT";
}

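// Returns the length of diagonal `diag_index` in a num_rows x num_cols matrix
// and the offset of its first element within a row of max_diag_len slots:
// 0 when the diagonal is left-aligned, max_diag_len - diag_len when it is
// right-aligned.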
std::pair<int, int> ComputeDiagLenAndContentOffset(
    int diag_index, int max_diag_len, int num_rows, int num_cols,
    bool left_align_superdiagonal, bool left_align_subdiagonal) {
  const bool left_align = (diag_index >= 0 && left_align_superdiagonal) ||
                          (diag_index <= 0 && left_align_subdiagonal);
  const int diag_len = std::min(num_rows + std::min(0, diag_index),
                                num_cols - std::max(0, diag_index));
  const int content_offset = (left_align) ? 0 : (max_diag_len - diag_len);
  return {diag_len, content_offset};
}

template <typename T>
struct MatrixDiag<CPUDevice, T> {
  static void Compute(OpKernelContext* context, const CPUDevice& device,
                      typename TTypes<T>::ConstTensor& diag,
                      typename TTypes<T, 3>::Tensor& output,
                      const Eigen::Index lower_diag_index,
                      const Eigen::Index upper_diag_index,
                      const Eigen::Index max_diag_len, const T padding_value,
                      const bool left_align_superdiagonal,
                      const bool left_align_subdiagonal) {
    // The 10 in cost_per_batch comes from an existing heuristic.
    // TODO(penporn): Tune for the best constant in cost_per_batch.
    const Eigen::Index num_batches = output.dimension(0);
    const Eigen::Index num_rows = output.dimension(1);
    const Eigen::Index num_cols = output.dimension(2);
    const Eigen::Index cost_per_batch = 10 * num_rows * num_cols;

    auto compute_shard = [&output, &num_rows, &num_cols, &diag,
                          &lower_diag_index, &upper_diag_index, &max_diag_len,
                          &padding_value, &left_align_superdiagonal,
                          &left_align_subdiagonal](Eigen::Index begin,
                                                   Eigen::Index end) {
      const int num_diags = upper_diag_index - lower_diag_index + 1;
      const int diag_elements_in_batch = num_diags * max_diag_len;
      Eigen::Index diag_batch_base_index = begin * diag_elements_in_batch;
      for (Eigen::Index batch = begin; batch < end; ++batch) {
        for (Eigen::Index i = 0; i < output.dimension(1); ++i) {
          for (Eigen::Index j = 0; j < output.dimension(2); ++j) {
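            // Element (i, j) lies on diagonal j - i. The input stores the
            // diagonals from upper_diag_index down to lower_diag_index, each
            // occupying max_diag_len entries per batch.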
            const int diag_index = j - i;
            const int diag_index_in_input = upper_diag_index - diag_index;
            int diag_len, content_offset;
            std::tie(diag_len, content_offset) = ComputeDiagLenAndContentOffset(
                diag_index, max_diag_len, num_rows, num_cols,
                left_align_superdiagonal, left_align_subdiagonal);
            const int index_in_the_diagonal =
                j - std::max<Eigen::Index>(diag_index, 0) + content_offset;
            if (lower_diag_index <= diag_index &&
                diag_index <= upper_diag_index) {
              output(batch, i, j) = diag(diag_batch_base_index +
                                         diag_index_in_input * max_diag_len +
                                         index_in_the_diagonal);
            } else {
              output(batch, i, j) = padding_value;
            }
          }
        }
        diag_batch_base_index += diag_elements_in_batch;
      }
    };
    auto thread_pool =
        context->device()->tensorflow_cpu_worker_threads()->workers;
    thread_pool->ParallelFor(num_batches, cost_per_batch,
                             std::move(compute_shard));
  }
};

template <typename T>
struct MatrixDiagPart<CPUDevice, T> {
  static void Compute(OpKernelContext* context, const CPUDevice& device,
                      typename TTypes<T, 3>::ConstTensor& input,
                      typename TTypes<T>::Tensor& output,
                      const Eigen::Index lower_diag_index,
                      const Eigen::Index upper_diag_index,
                      const Eigen::Index max_diag_len, const T padding_value,
                      const bool left_align_superdiagonal,
                      const bool left_align_subdiagonal) {
    // The 10 in cost_per_batch comes from an existing heuristic.
    // TODO(penporn): Tune for the best constant in cost_per_batch.
    const Eigen::Index num_diags = upper_diag_index - lower_diag_index + 1;
    const Eigen::Index output_elements_in_batch = num_diags * max_diag_len;
    const Eigen::Index cost_per_batch = 10 * output_elements_in_batch;
    const Eigen::Index num_batches = input.dimension(0);
    const Eigen::Index num_rows = input.dimension(1);
    const Eigen::Index num_cols = input.dimension(2);

    auto compute_shard = [&output, &input, &num_rows, &num_cols,
                          &upper_diag_index, &max_diag_len, &num_diags,
                          &output_elements_in_batch, &padding_value,
                          &left_align_superdiagonal, &left_align_subdiagonal](
                             Eigen::Index begin, Eigen::Index end) {
      Eigen::Index output_base_index = begin * output_elements_in_batch;
      for (Eigen::Index batch = begin; batch < end; ++batch) {
        for (Eigen::Index m = 0; m < num_diags; ++m) {
          const Eigen::Index diag_index = upper_diag_index - m;
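          // Diagonal `diag_index` starts at row max(0, -diag_index) and
          // column max(0, diag_index) of the input matrix.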
          Eigen::Index y_offset = std::max<Eigen::Index>(0, -diag_index);
          Eigen::Index x_offset = std::max<Eigen::Index>(0, diag_index);
          int diag_len, content_offset;
          std::tie(diag_len, content_offset) = ComputeDiagLenAndContentOffset(
              diag_index, max_diag_len, num_rows, num_cols,
              left_align_superdiagonal, left_align_subdiagonal);

          // Fills the diagonal.
          for (Eigen::Index n = 0; n < diag_len; ++n) {
            output(output_base_index + content_offset + n) =
                input(batch, n + y_offset, n + x_offset);
          }

          // Padding.
          const bool left_align = (content_offset == 0);
          const Eigen::Index padding_start = (left_align) ? diag_len : 0;
          const Eigen::Index padding_end =
              (left_align) ? max_diag_len : content_offset;
          for (Eigen::Index n = padding_start; n < padding_end; ++n) {
            output(output_base_index + n) = padding_value;
          }
          output_base_index += max_diag_len;
        }
      }
    };
    auto thread_pool =
        context->device()->tensorflow_cpu_worker_threads()->workers;
    thread_pool->ParallelFor(num_batches, cost_per_batch,
                             std::move(compute_shard));
  }
};

}  // namespace functor

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                    \
  template <>                                                                  \
  void MatrixDiag<GPUDevice, T>::Compute(                                      \
      OpKernelContext* context, const GPUDevice& device,                       \
      typename TTypes<T>::ConstTensor& diag,                                   \
      typename TTypes<T, 3>::Tensor& output,                                   \
      const Eigen::Index lower_diag_index,                                     \
      const Eigen::Index upper_diag_index, const Eigen::Index max_diag_len,    \
      const T padding_value, const bool left_align_superdiagonal,              \
      const bool left_align_subdiagonal);                                      \
  extern template struct MatrixDiag<GPUDevice, T>;                             \
  template <>                                                                  \
  void MatrixDiagPart<GPUDevice, T>::Compute(                                  \
      OpKernelContext* context, const GPUDevice& device,                       \
      typename TTypes<T, 3>::ConstTensor& input,                               \
      typename TTypes<T>::Tensor& output, const Eigen::Index lower_diag_index, \
      const Eigen::Index upper_diag_index, const Eigen::Index max_diag_len,    \
      const T padding_value, const bool left_align_superdiagonal,              \
      const bool left_align_subdiagonal);                                      \
  extern template struct MatrixDiagPart<GPUDevice, T>;

TF_CALL_GPU_ALL_TYPES(DECLARE_GPU_SPEC);

}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_MATRIX_DIAG_GPU(type)                                     \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("MatrixDiag").Device(DEVICE_GPU).TypeConstraint<type>("T"),     \
      MatrixDiagOp<GPUDevice, type>);                                      \
  REGISTER_KERNEL_BUILDER(Name("MatrixDiagV2")                             \
                              .Device(DEVICE_GPU)                          \
                              .TypeConstraint<type>("T")                   \
                              .HostMemory("k")                             \
                              .HostMemory("num_rows")                      \
                              .HostMemory("num_cols")                      \
                              .HostMemory("padding_value"),                \
                          MatrixDiagOp<GPUDevice, type>);                  \
  REGISTER_KERNEL_BUILDER(Name("MatrixDiagV3")                             \
                              .Device(DEVICE_GPU)                          \
                              .TypeConstraint<type>("T")                   \
                              .HostMemory("k")                             \
                              .HostMemory("num_rows")                      \
                              .HostMemory("num_cols")                      \
                              .HostMemory("padding_value"),                \
                          MatrixDiagOp<GPUDevice, type>);                  \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("MatrixDiagPart").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      MatrixDiagPartOp<GPUDevice, type>);                                  \
  REGISTER_KERNEL_BUILDER(Name("MatrixDiagPartV2")                         \
                              .Device(DEVICE_GPU)                          \
                              .TypeConstraint<type>("T")                   \
                              .HostMemory("k")                             \
                              .HostMemory("padding_value"),                \
                          MatrixDiagPartOp<GPUDevice, type>);              \
  REGISTER_KERNEL_BUILDER(Name("MatrixDiagPartV3")                         \
                              .Device(DEVICE_GPU)                          \
                              .TypeConstraint<type>("T")                   \
                              .HostMemory("k")                             \
                              .HostMemory("padding_value"),                \
                          MatrixDiagPartOp<GPUDevice, type>);

TF_CALL_GPU_ALL_TYPES(REGISTER_MATRIX_DIAG_GPU);
#undef REGISTER_MATRIX_DIAG_GPU

// Registration of the deprecated kernel.
// Delete after 10mar2017.
#define REGISTER_BATCH_MATRIX_DIAG_GPU(type)                                \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("BatchMatrixDiag").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      MatrixDiagOp<GPUDevice, type>);                                       \
  REGISTER_KERNEL_BUILDER(Name("BatchMatrixDiagPart")                       \
                              .Device(DEVICE_GPU)                           \
                              .TypeConstraint<type>("T"),                   \
                          MatrixDiagPartOp<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_BATCH_MATRIX_DIAG_GPU);
#undef REGISTER_BATCH_MATRIX_DIAG_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow