1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_KERNELS_LINALG_LINALG_OPS_COMMON_H_ 16 #define TENSORFLOW_CORE_KERNELS_LINALG_LINALG_OPS_COMMON_H_ 17 18 // Classes to support linear algebra functionality, similar to the numpy.linalg 19 // module. Supports batch computation on several matrices at once, sharding the 20 // computations across different threads if necessary. 21 #include <algorithm> 22 23 #include "third_party/eigen3/Eigen/Core" 24 #include "tensorflow/core/framework/kernel_def_builder.h" 25 #include "tensorflow/core/framework/op_kernel.h" 26 #include "tensorflow/core/framework/tensor.h" 27 #include "tensorflow/core/framework/tensor_shape.h" 28 #include "tensorflow/core/framework/tensor_types.h" 29 #include "tensorflow/core/framework/types.h" 30 #include "tensorflow/core/lib/core/errors.h" 31 #include "tensorflow/core/lib/gtl/inlined_vector.h" 32 #include "tensorflow/core/platform/types.h" 33 #include "tensorflow/core/util/work_sharder.h" 34 35 namespace tensorflow { 36 37 // Base class for linear algebra operators. 38 template <class InputScalar, class OutputScalar = InputScalar> 39 class LinearAlgebraOp : public OpKernel { 40 public: LinearAlgebraOp(OpKernelConstruction * context)41 explicit LinearAlgebraOp(OpKernelConstruction* context) : OpKernel(context) {} 42 43 void Compute(OpKernelContext* context) override; 44 45 protected: 46 using TensorShapes = gtl::InlinedVector<TensorShape, 4>; 47 // Returns the number of leading inputs that are to be treated as matrix 48 // inputs. By default this is all the inputs. Derived classes can override 49 // this to tell the base class to ignore one or more trailing inputs. NumMatrixInputs(const OpKernelContext * context)50 virtual int NumMatrixInputs(const OpKernelContext* context) const { 51 return context->num_inputs(); 52 } 53 54 // Returns true if the number of inputs and their shapes are as expected. 55 // Many ops take a single square input matrix, so we provide that as a default 56 // implementation for convenience. ValidateInputMatrixShapes(OpKernelContext * context,const TensorShapes & input_matrix_shapes)57 virtual void ValidateInputMatrixShapes( 58 OpKernelContext* context, const TensorShapes& input_matrix_shapes) const { 59 ValidateSingleSquareMatrix(context, input_matrix_shapes); 60 } 61 62 // Convenience validators for common cases: 63 // 64 // Validate op taking a single matrix A. 65 static void ValidateSingleMatrix(OpKernelContext* context, 66 const TensorShapes& input_matrix_shapes); 67 // Validate op taking a single square matrix A. 68 static void ValidateSingleSquareMatrix( 69 OpKernelContext* context, const TensorShapes& input_matrix_shapes); 70 // Validate op taking two matrices A and B that have the same number of rows. 71 static void ValidateSolver(OpKernelContext* context, 72 const TensorShapes& input_matrix_shapes); 73 // Validate op taking two matrices A and B that have the same number of rows 74 // and A is square. 75 static void ValidateSquareSolver(OpKernelContext* context, 76 const TensorShapes& input_matrix_shapes); 77 78 // Returns the output shapes of each individual matrix operation. Output 79 // matrices shapes must be rank 0, 1, or 2. Scalar outputs are rank 0. 80 // 81 // The derived class may return a number of shapes (N) less than 82 // context->num_outputs() (M) to indicate that a only leading subset of 83 // the outputs will be populated. In this case, a dummy scalar tensor with 84 // value zero will be return for the last M-N outputs. 85 // 86 // For many ops, the output dimensions are the same as the input dimensions, 87 // so we provide that as a default implementation for convenience. GetOutputMatrixShapes(const TensorShapes & input_matrix_shapes)88 virtual TensorShapes GetOutputMatrixShapes( 89 const TensorShapes& input_matrix_shapes) const { 90 return input_matrix_shapes; 91 } 92 93 // Returns the cost per matrix operation. This is used to determine the 94 // number of threads to use for parallelizing calls to ComputeMatrix in 95 // batch mode. Cost per unit is assumed to be roughly 1ns, based on comments 96 // in core/util/work_sharder.cc. Many linear algebra ops take roughly max(m,n) 97 // * min(m,n)^2, where the first input matrix is m-by-n. We provide that as a 98 // default implementation for convenience. GetCostPerUnit(const TensorShapes & input_matrix_shapes)99 virtual int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const { 100 double m = static_cast<double>(input_matrix_shapes[0].dim_size(0)); 101 double n = static_cast<double>(input_matrix_shapes[0].dim_size(1)); 102 double cost = std::max(m, n) * std::min(m, n) * std::min(m, n); 103 return cost >= static_cast<double>(kint64max) ? kint64max 104 : static_cast<int64>(cost); 105 } 106 107 // Returns true if it is safe to forward (alias) input to output buffer 108 // and expect the kernel to perform the computation inplace. EnableInputForwarding()109 virtual bool EnableInputForwarding() const { return true; } 110 111 using InputMatrix = Eigen::Matrix<InputScalar, Eigen::Dynamic, Eigen::Dynamic, 112 Eigen::RowMajor>; 113 using InputConstMatrixMap = Eigen::Map<const InputMatrix>; 114 using InputMatrixMap = Eigen::Map<InputMatrix>; 115 using InputConstVectorMap = 116 Eigen::Map<const Eigen::Matrix<InputScalar, 1, Eigen::Dynamic>>; 117 using InputConstMatrixMaps = gtl::InlinedVector<InputConstMatrixMap, 4>; 118 using InputMatrixMaps = gtl::InlinedVector<InputMatrixMap, 4>; 119 using InputRealScalar = typename Eigen::NumTraits<InputScalar>::Real; 120 121 using OutputMatrix = Eigen::Matrix<OutputScalar, Eigen::Dynamic, 122 Eigen::Dynamic, Eigen::RowMajor>; 123 using OutputConstMatrixMap = Eigen::Map<const OutputMatrix>; 124 using OutputMatrixMap = Eigen::Map<OutputMatrix>; 125 using OutputConstVectorMap = 126 Eigen::Map<const Eigen::Matrix<OutputScalar, 1, Eigen::Dynamic>>; 127 using OutputConstMatrixMaps = gtl::InlinedVector<OutputConstMatrixMap, 4>; 128 using OutputMatrixMaps = gtl::InlinedVector<OutputMatrixMap, 4>; 129 using OutputRealScalar = typename Eigen::NumTraits<OutputScalar>::Real; 130 131 // backward compatibility 132 using Scalar = OutputScalar; 133 using Matrix = 134 Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>; 135 using ConstMatrixMap = Eigen::Map<const Matrix>; 136 using MatrixMap = Eigen::Map<Matrix>; 137 using ConstVectorMap = 138 Eigen::Map<const Eigen::Matrix<Scalar, 1, Eigen::Dynamic>>; 139 using ConstMatrixMaps = gtl::InlinedVector<ConstMatrixMap, 4>; 140 using MatrixMaps = gtl::InlinedVector<MatrixMap, 4>; 141 using RealScalar = typename Eigen::NumTraits<Scalar>::Real; 142 143 // Performs a single matrix computation given input matrices, and 144 // stores the result in outputs. For batch operations, this will be called 145 // repeatedly for a single call to Compute() when multiple matrices exist in 146 // input Tensors with rank > 2. In this case the calls to ComputeMatrix are 147 // parallelized. The number of threads used is determined by a cost model from 148 // the value returned by GetCostPerUnit(). 149 virtual void ComputeMatrix(OpKernelContext* context, 150 const InputConstMatrixMaps& inputs, 151 OutputMatrixMaps* outputs) = 0; 152 153 private: 154 using TensorInputs = gtl::InlinedVector<const Tensor*, 4>; 155 using TensorOutputs = gtl::InlinedVector<Tensor*, 4>; 156 // This function maps 2-d slices (matrices) of the input and output tensors 157 // using Eigen::Map and calls ComputeMatrix implemented in terms of the 158 // Eigen::MatrixBase API by the derived class. 159 // 160 // The 'matrix_index' parameter specifies the index of the matrix to be used 161 // from each input tensor, and the index of the matrix to be written to each 162 // output tensor. The input matrices are in row major order, and located at 163 // the memory addresses 164 // inputs[i].flat<Scalar>().data() + 165 // matrix_index * input_matrix_shapes[i].num_elements() 166 // for i in 0...inputs.size()-1. 167 // The output matrices are in row major order, and located at the memory 168 // address 169 // outputs[i]->flat<Scalar>().data() + 170 // matrix_index * output_matrix_shapes[i].num_elements(). 171 // for i in 0...outputs.size()-1. 172 // 173 void ComputeTensorSlice(OpKernelContext* context, int64 matrix_index, 174 const TensorInputs& inputs, 175 const TensorShapes& input_matrix_shapes, 176 const TensorOutputs& outputs, 177 const TensorShapes& output_matrix_shapes); 178 179 void AnalyzeInputs(OpKernelContext* context, TensorInputs* inputs, 180 TensorShapes* input_matrix_shapes, 181 TensorShape* batch_shape); 182 183 void PrepareOutputs(OpKernelContext* context, 184 const TensorShapes& input_matrix_shapes, 185 const TensorShape& batch_shape, TensorOutputs* outputs, 186 TensorShapes* output_matrix_shapes); 187 }; 188 189 // Declare LinearAlgebraOp, which is explicitly instantiated in 190 // linalg_ops_common.cc for float, double, complex64, and complex128. 191 extern template class LinearAlgebraOp<float>; 192 extern template class LinearAlgebraOp<double>; 193 extern template class LinearAlgebraOp<complex64>; 194 extern template class LinearAlgebraOp<complex128>; 195 196 } // namespace tensorflow 197 198 #define INHERIT_LINALG_TYPEDEFS(Scalar) \ 199 typedef LinearAlgebraOp<Scalar> Base; \ 200 using RealScalar = typename Eigen::NumTraits<Scalar>::Real; \ 201 using Matrix = typename Base::Matrix; \ 202 using MatrixMap = typename Base::MatrixMap; \ 203 using MatrixMaps = typename Base::MatrixMaps; \ 204 using ConstMatrixMap = typename Base::ConstMatrixMap; \ 205 using ConstMatrixMaps = typename Base::ConstMatrixMaps; \ 206 using ConstVectorMap = typename Base::ConstVectorMap; \ 207 using TensorShapes = typename Base::TensorShapes; 208 209 #define REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar) \ 210 REGISTER_KERNEL_BUILDER( \ 211 Name(OpName).Device(DEVICE_CPU).TypeConstraint<Scalar>("T"), OpClass) 212 213 #define REGISTER_LINALG_OP_GPU(OpName, OpClass, Scalar) \ 214 REGISTER_KERNEL_BUILDER( \ 215 Name(OpName).Device(DEVICE_GPU).TypeConstraint<Scalar>("T"), OpClass) 216 217 // Deprecated, use one of the device-specific macros above. 218 #define REGISTER_LINALG_OP(OpName, OpClass, Scalar) \ 219 REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar) 220 221 #endif // TENSORFLOW_CORE_KERNELS_LINALG_LINALG_OPS_COMMON_H_ 222