/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_LINALG_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_LINALG_OPS_COMMON_H_

// Classes to support linear algebra functionality, similar to the numpy.linalg
// module. Supports batch computation on several matrices at once, sharding the
// computations across different threads if necessary.
#include <algorithm>

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

// Base class for linear algebra operators.
template <class InputScalar, class OutputScalar = InputScalar>
class LinearAlgebraOp : public OpKernel {
 public:
  explicit LinearAlgebraOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override;

 protected:
  using TensorShapes = gtl::InlinedVector<TensorShape, 4>;
  // Returns the number of leading inputs that are to be treated as matrix
  // inputs. By default this is all the inputs. Derived classes can override
  // this to tell the base class to ignore one or more trailing inputs.
  virtual int NumMatrixInputs(const OpKernelContext* context) const {
    return context->num_inputs();
  }
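
  // Illustrative sketch (hypothetical op, not part of this header): a solver
  // that takes a trailing non-matrix input, such as a scalar tolerance, could
  // ignore it by overriding:
  //
  //   int NumMatrixInputs(const OpKernelContext* context) const override {
  //     return context->num_inputs() - 1;  // Trailing input is not a matrix.
  //   }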

  // Checks that the number of inputs and their shapes are as expected, and
  // records an error on the context otherwise. Many ops take a single square
  // input matrix, so we provide that as a default implementation for
  // convenience.
  virtual void ValidateInputMatrixShapes(
      OpKernelContext* context, const TensorShapes& input_matrix_shapes) const {
    ValidateSingleSquareMatrix(context, input_matrix_shapes);
  }

  // Convenience validators for common cases:
  //
  // Validate op taking a single matrix A.
  static void ValidateSingleMatrix(OpKernelContext* context,
                                   const TensorShapes& input_matrix_shapes);
  // Validate op taking a single square matrix A.
  static void ValidateSingleSquareMatrix(
      OpKernelContext* context, const TensorShapes& input_matrix_shapes);
  // Validate op taking two matrices A and B that have the same number of rows.
  static void ValidateSolver(OpKernelContext* context,
                             const TensorShapes& input_matrix_shapes);
  // Validate op taking two matrices A and B that have the same number of rows
  // and where A is square.
  static void ValidateSquareSolver(OpKernelContext* context,
                                   const TensorShapes& input_matrix_shapes);

  // Returns the output shapes of each individual matrix operation. Output
  // matrix shapes must be rank 0, 1, or 2. Scalar outputs are rank 0.
  //
  // The derived class may return a number of shapes (N) less than
  // context->num_outputs() (M) to indicate that only a leading subset of
  // the outputs will be populated. In this case, a dummy scalar tensor with
  // value zero will be returned for the last M-N outputs.
  //
  // For many ops, the output dimensions are the same as the input dimensions,
  // so we provide that as a default implementation for convenience.
  virtual TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const {
    return input_matrix_shapes;
  }
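
  // Hypothetical example (not an op in this file): a kernel registered with
  // two outputs that only populates the first could return a single shape
  // here; the base class then fills the second output with a scalar zero:
  //
  //   TensorShapes GetOutputMatrixShapes(
  //       const TensorShapes& input_matrix_shapes) const override {
  //     return TensorShapes({input_matrix_shapes[0]});  // N = 1 < M = 2.
  //   }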

  // Returns the cost per matrix operation. This is used to determine the
  // number of threads to use for parallelizing calls to ComputeMatrix in
  // batch mode. Cost per unit is assumed to be roughly 1ns, based on comments
  // in core/util/work_sharder.cc. Many linear algebra ops take roughly
  // max(m,n) * min(m,n)^2 operations, where the first input matrix is m-by-n.
  // We provide that as a default implementation for convenience.
  virtual int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const {
    double m = static_cast<double>(input_matrix_shapes[0].dim_size(0));
    double n = static_cast<double>(input_matrix_shapes[0].dim_size(1));
    double cost = std::max(m, n) * std::min(m, n) * std::min(m, n);
    return cost >= static_cast<double>(kint64max) ? kint64max
                                                  : static_cast<int64>(cost);
  }
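
  // Worked example: for a 1000 x 100 input matrix the default estimate is
  // max(1000, 100) * min(1000, 100)^2 = 1000 * 100 * 100 = 1e7 cost units,
  // i.e. roughly 10ms at the assumed ~1ns per unit, so a batch of such
  // matrices is a good candidate for sharding across threads.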

  // Returns true if it is safe to forward (alias) the input to the output
  // buffer and expect the kernel to perform the computation in place.
  virtual bool EnableInputForwarding() const { return true; }

  using InputMatrix = Eigen::Matrix<InputScalar, Eigen::Dynamic, Eigen::Dynamic,
                                    Eigen::RowMajor>;
  using InputConstMatrixMap = Eigen::Map<const InputMatrix>;
  using InputMatrixMap = Eigen::Map<InputMatrix>;
  using InputConstVectorMap =
      Eigen::Map<const Eigen::Matrix<InputScalar, 1, Eigen::Dynamic>>;
  using InputConstMatrixMaps = gtl::InlinedVector<InputConstMatrixMap, 4>;
  using InputMatrixMaps = gtl::InlinedVector<InputMatrixMap, 4>;
  using InputRealScalar = typename Eigen::NumTraits<InputScalar>::Real;

  using OutputMatrix = Eigen::Matrix<OutputScalar, Eigen::Dynamic,
                                     Eigen::Dynamic, Eigen::RowMajor>;
  using OutputConstMatrixMap = Eigen::Map<const OutputMatrix>;
  using OutputMatrixMap = Eigen::Map<OutputMatrix>;
  using OutputConstVectorMap =
      Eigen::Map<const Eigen::Matrix<OutputScalar, 1, Eigen::Dynamic>>;
  using OutputConstMatrixMaps = gtl::InlinedVector<OutputConstMatrixMap, 4>;
  using OutputMatrixMaps = gtl::InlinedVector<OutputMatrixMap, 4>;
  using OutputRealScalar = typename Eigen::NumTraits<OutputScalar>::Real;

  // Backward-compatible aliases; Scalar refers to the output scalar type.
  using Scalar = OutputScalar;
  using Matrix =
      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;
  using ConstVectorMap =
      Eigen::Map<const Eigen::Matrix<Scalar, 1, Eigen::Dynamic>>;
  using ConstMatrixMaps = gtl::InlinedVector<ConstMatrixMap, 4>;
  using MatrixMaps = gtl::InlinedVector<MatrixMap, 4>;
  using RealScalar = typename Eigen::NumTraits<Scalar>::Real;

  // Performs a single matrix computation given input matrices, and stores the
  // result in outputs. For batch operations, this will be called repeatedly
  // for a single call to Compute() when multiple matrices exist in input
  // Tensors with rank > 2. In this case the calls to ComputeMatrix are
  // parallelized; the number of threads used is determined by a cost model
  // from the value returned by GetCostPerUnit(). See the example sketch below
  // the class definition.
  virtual void ComputeMatrix(OpKernelContext* context,
                             const InputConstMatrixMaps& inputs,
                             OutputMatrixMaps* outputs) = 0;

 private:
  using TensorInputs = gtl::InlinedVector<const Tensor*, 4>;
  using TensorOutputs = gtl::InlinedVector<Tensor*, 4>;
  // This function maps 2-d slices (matrices) of the input and output tensors
  // using Eigen::Map and calls ComputeMatrix, which the derived class
  // implements in terms of the Eigen::MatrixBase API.
  //
  // The 'matrix_index' parameter specifies the index of the matrix to be used
  // from each input tensor, and the index of the matrix to be written to each
  // output tensor. The input matrices are in row-major order, located at the
  // memory addresses
  //   inputs[i].flat<Scalar>().data() +
  //       matrix_index * input_matrix_shapes[i].num_elements()
  // for i in 0...inputs.size()-1. The output matrices are in row-major order,
  // located at the memory addresses
  //   outputs[i]->flat<Scalar>().data() +
  //       matrix_index * output_matrix_shapes[i].num_elements()
  // for i in 0...outputs.size()-1.
  void ComputeTensorSlice(OpKernelContext* context, int64 matrix_index,
                          const TensorInputs& inputs,
                          const TensorShapes& input_matrix_shapes,
                          const TensorOutputs& outputs,
                          const TensorShapes& output_matrix_shapes);
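
  // For example, with a single input tensor of shape [3, 2, 2] (a batch of
  // three 2x2 matrices), matrix_index 1 maps the slice starting at
  // inputs[0].flat<Scalar>().data() + 1 * 4, i.e. the second group of four
  // contiguous row-major elements.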

  void AnalyzeInputs(OpKernelContext* context, TensorInputs* inputs,
                     TensorShapes* input_matrix_shapes,
                     TensorShape* batch_shape);

  void PrepareOutputs(OpKernelContext* context,
                      const TensorShapes& input_matrix_shapes,
                      const TensorShape& batch_shape, TensorOutputs* outputs,
                      TensorShapes* output_matrix_shapes);
};
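
// Illustrative sketch of a derived op; "MatrixSquareOp" is hypothetical and
// not part of TensorFlow. ComputeMatrix receives one Eigen-mapped matrix per
// matrix input and writes its result through the mapped outputs:
//
//   template <class Scalar>
//   class MatrixSquareOp : public LinearAlgebraOp<Scalar> {
//    public:
//     INHERIT_LINALG_TYPEDEFS(Scalar);
//     explicit MatrixSquareOp(OpKernelConstruction* context) : Base(context) {}
//     // A matrix product must not write into memory aliased with its input.
//     bool EnableInputForwarding() const final { return false; }
//     void ComputeMatrix(OpKernelContext* context,
//                        const ConstMatrixMaps& inputs,
//                        MatrixMaps* outputs) final {
//       outputs->at(0).noalias() = inputs[0] * inputs[0];  // Computes A * A.
//     }
//   };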

// Declare LinearAlgebraOp, which is explicitly instantiated in
// linalg_ops_common.cc for float, double, complex64, and complex128.
extern template class LinearAlgebraOp<float>;
extern template class LinearAlgebraOp<double>;
extern template class LinearAlgebraOp<complex64>;
extern template class LinearAlgebraOp<complex128>;

}  // namespace tensorflow

#define INHERIT_LINALG_TYPEDEFS(Scalar)                       \
  typedef LinearAlgebraOp<Scalar> Base;                       \
  using RealScalar = typename Eigen::NumTraits<Scalar>::Real; \
  using Matrix = typename Base::Matrix;                       \
  using MatrixMap = typename Base::MatrixMap;                 \
  using MatrixMaps = typename Base::MatrixMaps;               \
  using ConstMatrixMap = typename Base::ConstMatrixMap;       \
  using ConstMatrixMaps = typename Base::ConstMatrixMaps;     \
  using ConstVectorMap = typename Base::ConstVectorMap;       \
  using TensorShapes = typename Base::TensorShapes;

#define REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar) \
  REGISTER_KERNEL_BUILDER(                              \
      Name(OpName).Device(DEVICE_CPU).TypeConstraint<Scalar>("T"), OpClass)

#define REGISTER_LINALG_OP_GPU(OpName, OpClass, Scalar) \
  REGISTER_KERNEL_BUILDER(                              \
      Name(OpName).Device(DEVICE_GPU).TypeConstraint<Scalar>("T"), OpClass)

// Deprecated, use one of the device-specific macros above.
#define REGISTER_LINALG_OP(OpName, OpClass, Scalar) \
  REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar)
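
// Example registration for the hypothetical MatrixSquareOp sketched above,
// assuming an op named "MatrixSquare" whose type attribute is "T":
//
//   REGISTER_LINALG_OP_CPU("MatrixSquare", (MatrixSquareOp<float>), float);
//   REGISTER_LINALG_OP_CPU("MatrixSquare", (MatrixSquareOp<double>), double);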

#endif  // TENSORFLOW_CORE_KERNELS_LINALG_LINALG_OPS_COMMON_H_