/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"

#include <set>
#include <utility>

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

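// Static shape-validation helpers for the common cases (single matrix, single
// square matrix, and linear-system solvers). Derived kernels can call these
// from their ValidateInputMatrixShapes() overrides.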
// static
template <class InputScalar, class OutputScalar>
void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSingleMatrix(
    OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
  OP_REQUIRES(context, input_matrix_shapes.size() == 1,
              errors::InvalidArgument("Expected a single input matrix, got ",
                                      input_matrix_shapes.size(), "."));
  OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[0]),
              errors::InvalidArgument("Input must be a matrix."));
}

// static
template <class InputScalar, class OutputScalar>
void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSingleSquareMatrix(
    OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
  OP_REQUIRES(context, input_matrix_shapes.size() == 1,
              errors::InvalidArgument("Expected a single input matrix, got ",
                                      input_matrix_shapes.size(), "."));
  OP_REQUIRES(context, TensorShapeUtils::IsSquareMatrix(input_matrix_shapes[0]),
              errors::InvalidArgument("Input matrix must be square."));
}

// static
template <class InputScalar, class OutputScalar>
void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSolver(
    OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
  OP_REQUIRES(context, input_matrix_shapes.size() == 2,
              errors::InvalidArgument("Expected two input matrices, got ",
                                      input_matrix_shapes.size(), "."));
  OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[0]),
              errors::InvalidArgument("First input (lhs) must be a matrix."));
  OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[1]),
              errors::InvalidArgument("Second input (rhs) must be a matrix."));
  OP_REQUIRES(
      context,
      input_matrix_shapes[0].dim_size(0) == input_matrix_shapes[1].dim_size(0),
      errors::InvalidArgument("Input matrix and rhs are incompatible."));
}

// static
template <class InputScalar, class OutputScalar>
void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSquareSolver(
    OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
  OP_REQUIRES(context, input_matrix_shapes.size() == 2,
              errors::InvalidArgument("Expected two input matrices, got ",
                                      input_matrix_shapes.size(), "."));
  OP_REQUIRES(
      context, TensorShapeUtils::IsSquareMatrix(input_matrix_shapes[0]),
      errors::InvalidArgument("First input (lhs) must be a square matrix."));
  OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[1]),
              errors::InvalidArgument("Second input (rhs) must be a matrix."));
  OP_REQUIRES(
      context,
      input_matrix_shapes[0].dim_size(0) == input_matrix_shapes[1].dim_size(0),
      errors::InvalidArgument("Input matrix and rhs are incompatible."));
}

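// Computes the op in three steps: analyze and validate the inputs, allocate
// (or forward) the outputs, then shard the per-matrix work across the CPU
// worker threadpool.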
template <class InputScalar, class OutputScalar>
void LinearAlgebraOp<InputScalar, OutputScalar>::Compute(
    OpKernelContext* context) {
  TensorInputs inputs;
  TensorShapes input_matrix_shapes;
  TensorShape batch_shape;
  AnalyzeInputs(context, &inputs, &input_matrix_shapes, &batch_shape);
  if (!context->status().ok()) return;

  TensorShapes output_matrix_shapes;
  TensorOutputs outputs;
  PrepareOutputs(context, input_matrix_shapes, batch_shape, &outputs,
                 &output_matrix_shapes);
  if (!context->status().ok()) return;

  // Process the individual matrix problems in parallel using a threadpool.
  auto shard = [this, &inputs, &input_matrix_shapes, &outputs,
                &output_matrix_shapes, context](int64 begin, int64 end) {
    for (int64 i = begin; i < end; ++i) {
      ComputeTensorSlice(context, i, inputs, input_matrix_shapes, outputs,
                         output_matrix_shapes);
    }
  };
  auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
  Shard(worker_threads.num_threads, worker_threads.workers,
        batch_shape.num_elements(), GetCostPerUnit(input_matrix_shapes), shard);
}

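// Extracts the per-matrix shapes and the outer "batch" shape from the op
// inputs, checks that all inputs share the same rank and batch dimensions,
// and delegates matrix-level validation to the derived class.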
template <class InputScalar, class OutputScalar>
void LinearAlgebraOp<InputScalar, OutputScalar>::AnalyzeInputs(
    OpKernelContext* context, TensorInputs* inputs,
    TensorShapes* input_matrix_shapes, TensorShape* batch_shape) {
  int input_rank = -1;
  for (int i = 0; i < NumMatrixInputs(context); ++i) {
    const Tensor& in = context->input(i);
    if (i == 0) {
      input_rank = in.dims();
      OP_REQUIRES(
          context, input_rank >= 2,
          errors::InvalidArgument("Input tensor ", i,
                                  " must have rank >= 2, got ", input_rank));
      // If the tensor rank is greater than 2, we consider the inner-most
      // dimensions as matrices, and loop over all the other outer ("batch")
      // dimensions to compute the results.
      for (int dim = 0; dim < input_rank - 2; ++dim) {
        batch_shape->AddDim(in.dim_size(dim));
      }
    } else {
      // Make sure that all inputs have the same rank and outer dimensions.
      OP_REQUIRES(context, input_rank == in.dims(),
                  errors::InvalidArgument(
                      "All input tensors must have the same rank."));
      for (int dim = 0; dim < input_rank - 2; ++dim) {
        OP_REQUIRES(
            context, in.dim_size(dim) == batch_shape->dim_size(dim),
            errors::InvalidArgument(
                "All input tensors must have the same outer dimensions."));
      }
    }

    const int row_dimension = input_rank - 2;
    const int col_dimension = input_rank - 1;
    const int64 num_rows = in.dim_size(row_dimension);
    const int64 num_cols = in.dim_size(col_dimension);
    input_matrix_shapes->emplace_back(
        std::initializer_list<int64>({num_rows, num_cols}));
    inputs->emplace_back(&in);
  }
  // Have the derived class validate that the inputs are as expected.
  ValidateInputMatrixShapes(context, *input_matrix_shapes);
}

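// Allocates the op outputs. Each output gets the batch shape concatenated
// with the matrix shape reported by the derived class; an input buffer is
// reused when input forwarding is enabled and the shapes match.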
template <class InputScalar, class OutputScalar>
void LinearAlgebraOp<InputScalar, OutputScalar>::PrepareOutputs(
    OpKernelContext* context, const TensorShapes& input_matrix_shapes,
    const TensorShape& batch_shape, TensorOutputs* outputs,
    TensorShapes* output_matrix_shapes) {
  // Get shape for each of the matrix outputs produced by the derived class.
  *output_matrix_shapes = GetOutputMatrixShapes(input_matrix_shapes);
  const int num_outputs = output_matrix_shapes->size();

  // Make sure the number of op outputs is what the derived class expects.
  OP_REQUIRES(
      context, num_outputs <= context->num_outputs(),
      errors::Internal("Derived class expected more outputs (", num_outputs,
                       ") than the op has (", context->num_outputs(), ")."));

  // Allocate outputs.
  std::set<int> unused_inputs;
  for (int input_idx = 0; input_idx < context->num_inputs(); ++input_idx) {
    unused_inputs.insert(input_idx);
  }
  for (int output_idx = 0; output_idx < context->num_outputs(); ++output_idx) {
    TensorShape output_tensor_shape({});
    if (output_idx < num_outputs) {
      // This output is used; set up its shape and allocate it.
      const TensorShape& output_matrix_shape =
          output_matrix_shapes->at(output_idx);
      OP_REQUIRES(context, output_matrix_shape.dims() <= 2,
                  errors::InvalidArgument(
                      "Rank of matrix output no. ", output_idx,
                      " must be 0, 1 or 2, got ", output_matrix_shape.dims(),
                      "."));

      // The final output has the shape of the outer batch dimensions
      // concatenated with the output_matrix_shape (if the output is not
      // scalar).
      output_tensor_shape = batch_shape;
      output_tensor_shape.AppendShape(output_matrix_shape);
    }
    Tensor* out = nullptr;
    // See if there is an input buffer we can reuse for this output.
    bool reused_input = false;
    if (EnableInputForwarding()) {
      for (int input_idx : unused_inputs) {
        if (context->forward_input_to_output_with_shape(
                input_idx, output_idx, output_tensor_shape, &out)) {
          reused_input = true;
          unused_inputs.erase(input_idx);
          break;
        }
      }
    }
    if (!reused_input) {
      OP_REQUIRES_OK(context, context->allocate_output(
                                  output_idx, output_tensor_shape, &out));
    }
    outputs->emplace_back(out);
  }
}

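// Maps the matrix_index-th slice of each input and output tensor as an
// (unaligned) Eigen matrix and invokes the derived class's ComputeMatrix().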
template <class InputScalar, class OutputScalar>
void LinearAlgebraOp<InputScalar, OutputScalar>::ComputeTensorSlice(
    OpKernelContext* context, int64 matrix_index, const TensorInputs& inputs,
    const TensorShapes& input_matrix_shapes, const TensorOutputs& outputs,
    const TensorShapes& output_matrix_shapes) {
  InputConstMatrixMaps matrix_inputs;
  for (size_t i = 0; i < inputs.size(); ++i) {
    // TODO(kalakris): Handle alignment if possible. Eigen::Map is
    // unaligned by default.
    matrix_inputs.emplace_back(
        inputs[i]->flat<InputScalar>().data() +
            matrix_index * input_matrix_shapes[i].num_elements(),
        input_matrix_shapes[i].dim_size(0), input_matrix_shapes[i].dim_size(1));
  }

  OutputMatrixMaps matrix_outputs;
  for (size_t i = 0; i < output_matrix_shapes.size(); ++i) {
    // The output shape reported by the derived class need not be a full
    // matrix; it may be a vector or a scalar.
    int num_output_rows = output_matrix_shapes[i].dims() >= 1
                              ? output_matrix_shapes[i].dim_size(0)
                              : 1;
    int num_output_cols = output_matrix_shapes[i].dims() == 2
                              ? output_matrix_shapes[i].dim_size(1)
                              : 1;
    matrix_outputs.emplace_back(
        outputs[i]->flat<OutputScalar>().data() +
            matrix_index * output_matrix_shapes[i].num_elements(),
        num_output_rows, num_output_cols);
  }
  ComputeMatrix(context, matrix_inputs, &matrix_outputs);
}

// Explicitly instantiate LinearAlgebraOp for the scalar types we expect to use.
template class LinearAlgebraOp<float>;
template class LinearAlgebraOp<double>;
template class LinearAlgebraOp<complex64>;
template class LinearAlgebraOp<complex128>;
template class LinearAlgebraOp<float, complex64>;
template class LinearAlgebraOp<double, complex128>;

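// Usage sketch (illustrative only; the exact virtual method signatures are
// declared in linalg_ops_common.h). A concrete kernel derives from
// LinearAlgebraOp, overrides the validation/shape/cost/compute hooks, and is
// registered like any other OpKernel. The class and op below are hypothetical:
//
//   template <class Scalar>
//   class HypotheticalMatrixSquareOp : public LinearAlgebraOp<Scalar> {
//    public:
//     using Base = LinearAlgebraOp<Scalar>;
//     explicit HypotheticalMatrixSquareOp(OpKernelConstruction* context)
//         : Base(context) {}
//
//     void ValidateInputMatrixShapes(
//         OpKernelContext* context,
//         const typename Base::TensorShapes& shapes) const final {
//       Base::ValidateSingleSquareMatrix(context, shapes);
//     }
//
//     typename Base::TensorShapes GetOutputMatrixShapes(
//         const typename Base::TensorShapes& shapes) const final {
//       return typename Base::TensorShapes({shapes[0]});
//     }
//
//     int64 GetCostPerUnit(
//         const typename Base::TensorShapes& shapes) const final {
//       const int64 n = shapes[0].dim_size(0);
//       return n * n * n;  // Rough O(n^3) cost of a dense matrix product.
//     }
//
//     void ComputeMatrix(OpKernelContext* context,
//                        const typename Base::InputConstMatrixMaps& inputs,
//                        typename Base::OutputMatrixMaps* outputs) final {
//       (*outputs)[0] = inputs[0] * inputs[0];  // Dense matrix square.
//     }
//   };
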
}  // namespace tensorflow