1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_DIAG_OP_H_ 17 #define TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_DIAG_OP_H_ 18 19 // Generator definition for MatrixDiagOp, must be compilable by nvcc. 20 21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 22 #include "tensorflow/core/framework/op_kernel.h" 23 #include "tensorflow/core/framework/tensor.h" 24 #include "tensorflow/core/framework/tensor_types.h" 25 #include "tensorflow/core/platform/types.h" 26 27 namespace tensorflow { 28 namespace functor { 29 30 // Reads the diagonal packing alignment. 31 void ReadAlignment(OpKernelConstruction* context, 32 bool* left_align_superdiagonal, 33 bool* left_align_subdiagonal); 34 35 // Calculates diagonal length and content offset (from aligning) of a diagonal. 36 // Returns a pair of integers {diag_len, content_offset}: 37 // - diag_len: The length of the diag_index-th diagonal. 38 // - content_offset: Each diagonal is stored as a row in the compact format. 39 // If the diagonal is shorter than max_diag_len, its content is aligned 40 // either to the left or right. content_offset is the index in the row 41 // where the first element of the diag-index-th diagonal is stored. It is 42 // always zero when the diagonal is left-aligned. 43 std::pair<int, int> ComputeDiagLenAndContentOffset( 44 int diag_index, int max_diag_len, int num_rows, int num_cols, 45 bool left_align_superdiagonal, bool left_align_subdiagonal); 46 47 template <typename Device, typename T> 48 struct MatrixDiagPart { 49 EIGEN_ALWAYS_INLINE static void Compute( 50 OpKernelContext* context, const Device& device, 51 typename TTypes<T, 3>::ConstTensor& input, 52 typename TTypes<T>::Tensor& output, const Eigen::Index lower_diag_index, 53 const Eigen::Index upper_diag_index, const Eigen::Index max_diag_len, 54 const T padding_value, const bool left_align_superdiagonal, 55 const bool left_align_subdiagonal); 56 }; 57 58 template <typename Device, typename T> 59 struct MatrixDiag { 60 EIGEN_ALWAYS_INLINE static void Compute( 61 OpKernelContext* context, const Device& device, 62 typename TTypes<T>::ConstTensor& diag, 63 typename TTypes<T, 3>::Tensor& output, 64 const Eigen::Index lower_diag_index, const Eigen::Index upper_diag_index, 65 const Eigen::Index max_diag_len, const T padding_value, 66 const bool left_align_superdiagonal, const bool left_align_subdiagonal); 67 }; 68 69 } // namespace functor 70 } // namespace tensorflow 71 72 #endif // TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_DIAG_OP_H_ 73