1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
17
18 #include <utility>
19
20 #include "third_party/eigen3/Eigen/Core"
21 #include "tensorflow/core/framework/device_base.h"
22 #include "tensorflow/core/framework/kernel_def_builder.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/tensor_shape.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/types.h"
28
29 namespace tensorflow {
30
31 // static
32 template <class InputScalar, class OutputScalar>
ValidateSingleMatrix(OpKernelContext * context,const TensorShapes & input_matrix_shapes)33 void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSingleMatrix(
34 OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
35 OP_REQUIRES(context, input_matrix_shapes.size() == 1,
36 errors::InvalidArgument("Expected a single input matrix, got %d.",
37 input_matrix_shapes.size()));
38 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[0]),
39 errors::InvalidArgument("Input must be a matrix."));
40 }
41
42 // static
43 template <class InputScalar, class OutputScalar>
ValidateSingleSquareMatrix(OpKernelContext * context,const TensorShapes & input_matrix_shapes)44 void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSingleSquareMatrix(
45 OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
46 OP_REQUIRES(context, input_matrix_shapes.size() == 1,
47 errors::InvalidArgument("Expected a single input matrix, got %d.",
48 input_matrix_shapes.size()));
49 OP_REQUIRES(context, TensorShapeUtils::IsSquareMatrix(input_matrix_shapes[0]),
50 errors::InvalidArgument("Input matrix must be square."));
51 }
52
53 // static
54 template <class InputScalar, class OutputScalar>
ValidateSolver(OpKernelContext * context,const TensorShapes & input_matrix_shapes)55 void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSolver(
56 OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
57 OP_REQUIRES(context, input_matrix_shapes.size() == 2,
58 errors::InvalidArgument("Expected two input matrices, got %d.",
59 input_matrix_shapes.size()));
60 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[0]),
61 errors::InvalidArgument("First input (lhs) must be a matrix."));
62 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[1]),
63 errors::InvalidArgument("Second input (rhs) must be a matrix."));
64 OP_REQUIRES(
65 context,
66 input_matrix_shapes[0].dim_size(0) == input_matrix_shapes[1].dim_size(0),
67 errors::InvalidArgument("Input matrix and rhs are incompatible."));
68 }
69
70 // static
71 template <class InputScalar, class OutputScalar>
ValidateSquareSolver(OpKernelContext * context,const TensorShapes & input_matrix_shapes)72 void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSquareSolver(
73 OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
74 OP_REQUIRES(context, input_matrix_shapes.size() == 2,
75 errors::InvalidArgument("Expected two input matrices, got %d.",
76 input_matrix_shapes.size()));
77 OP_REQUIRES(
78 context, TensorShapeUtils::IsSquareMatrix(input_matrix_shapes[0]),
79 errors::InvalidArgument("First input (lhs) must be a square matrix."));
80 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_matrix_shapes[1]),
81 errors::InvalidArgument("Second input (rhs) must be a matrix."));
82 OP_REQUIRES(
83 context,
84 input_matrix_shapes[0].dim_size(0) == input_matrix_shapes[1].dim_size(0),
85 errors::InvalidArgument("Input matrix and rhs are incompatible."));
86 }
87
88 template <class InputScalar, class OutputScalar>
Compute(OpKernelContext * context)89 void LinearAlgebraOp<InputScalar, OutputScalar>::Compute(
90 OpKernelContext* context) {
91 TensorInputs inputs;
92 TensorShapes input_matrix_shapes;
93 TensorShape batch_shape;
94 AnalyzeInputs(context, &inputs, &input_matrix_shapes, &batch_shape);
95 if (!context->status().ok()) return;
96
97 TensorShapes output_matrix_shapes;
98 TensorOutputs outputs;
99 PrepareOutputs(context, input_matrix_shapes, batch_shape, &outputs,
100 &output_matrix_shapes);
101 if (!context->status().ok()) return;
102
103 // Process the individual matrix problems in parallel using a threadpool.
104 auto shard = [this, &inputs, &input_matrix_shapes, &outputs,
105 &output_matrix_shapes, context](int64 begin, int64 end) {
106 for (int64 i = begin; i < end; ++i) {
107 ComputeTensorSlice(context, i, inputs, input_matrix_shapes, outputs,
108 output_matrix_shapes);
109 }
110 };
111 auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
112 Shard(worker_threads.num_threads, worker_threads.workers,
113 batch_shape.num_elements(), GetCostPerUnit(input_matrix_shapes), shard);
114 }
115
116 template <class InputScalar, class OutputScalar>
AnalyzeInputs(OpKernelContext * context,TensorInputs * inputs,TensorShapes * input_matrix_shapes,TensorShape * batch_shape)117 void LinearAlgebraOp<InputScalar, OutputScalar>::AnalyzeInputs(
118 OpKernelContext* context, TensorInputs* inputs,
119 TensorShapes* input_matrix_shapes, TensorShape* batch_shape) {
120 int input_rank = -1;
121 for (int i = 0; i < NumMatrixInputs(context); ++i) {
122 const Tensor& in = context->input(i);
123 if (i == 0) {
124 input_rank = in.dims();
125 OP_REQUIRES(
126 context, input_rank >= 2,
127 errors::InvalidArgument("Input tensor ", i,
128 " must have rank >= 2, got ", input_rank));
129 // If the tensor rank is greater than 2, we consider the inner-most
130 // dimensions as matrices, and loop over all the other outer ("batch")
131 // dimensions to compute the results.
132 for (int dim = 0; dim < input_rank - 2; ++dim) {
133 batch_shape->AddDim(in.dim_size(dim));
134 }
135 } else {
136 // Make sure that all inputs have the same rank and outer dimensions.
137 OP_REQUIRES(context, input_rank == in.dims(),
138 errors::InvalidArgument(
139 "All input tensors must have the same rank."));
140 for (int dim = 0; dim < input_rank - 2; ++dim) {
141 OP_REQUIRES(
142 context, in.dim_size(dim) == batch_shape->dim_size(dim),
143 errors::InvalidArgument(
144 "All input tensors must have the same outer dimensions."));
145 }
146 }
147
148 const int row_dimension = input_rank - 2;
149 const int col_dimension = input_rank - 1;
150 const int64 num_rows = in.dim_size(row_dimension);
151 const int64 num_cols = in.dim_size(col_dimension);
152 input_matrix_shapes->emplace_back(
153 std::initializer_list<int64>({num_rows, num_cols}));
154 inputs->emplace_back(&in);
155 }
156 // Have the derived class validate that the inputs are as expected.
157 ValidateInputMatrixShapes(context, *input_matrix_shapes);
158 }
159
160 template <class InputScalar, class OutputScalar>
PrepareOutputs(OpKernelContext * context,const TensorShapes & input_matrix_shapes,const TensorShape & batch_shape,TensorOutputs * outputs,TensorShapes * output_matrix_shapes)161 void LinearAlgebraOp<InputScalar, OutputScalar>::PrepareOutputs(
162 OpKernelContext* context, const TensorShapes& input_matrix_shapes,
163 const TensorShape& batch_shape, TensorOutputs* outputs,
164 TensorShapes* output_matrix_shapes) {
165 // Get shape for each of the matrix outputs produced by the derived class.
166 *output_matrix_shapes = GetOutputMatrixShapes(input_matrix_shapes);
167 const int num_outputs = output_matrix_shapes->size();
168
169 // Make sure the number of op outputs is what the derived class expects.
170 OP_REQUIRES(
171 context, num_outputs <= context->num_outputs(),
172 errors::Internal(
173 "Derived class expected more outputs (%d) that the op has (%d).",
174 num_outputs, context->num_outputs()));
175
176 // Allocate outputs.
177 std::set<int> unused_inputs;
178 for (int input_idx = 0; input_idx < context->num_inputs(); ++input_idx) {
179 unused_inputs.insert(input_idx);
180 }
181 for (int output_idx = 0; output_idx < context->num_outputs(); ++output_idx) {
182 TensorShape output_tensor_shape({});
183 if (output_idx < num_outputs) {
184 // This output is used, set up output shape and allocate it.
185 const TensorShape& output_matrix_shape =
186 output_matrix_shapes->at(output_idx);
187 OP_REQUIRES(context, output_matrix_shape.dims() <= 2,
188 errors::InvalidArgument(
189 "Rank of matrix output no. %d must be 0, 1 or 2, got %d.",
190 output_idx, output_matrix_shape.dims()));
191
192 // The final output has the shape of the outer batch dimensions
193 // concatenated with the output_matrix_shape (if the output is not
194 // scalar).
195 output_tensor_shape = batch_shape;
196 output_tensor_shape.AppendShape(output_matrix_shape);
197 }
198 Tensor* out = nullptr;
199 // See if there is an input buffer we can reuse for this output.
200 bool reused_input = false;
201 if (EnableInputForwarding()) {
202 for (int input_idx : unused_inputs) {
203 if (context->forward_input_to_output_with_shape(
204 input_idx, output_idx, output_tensor_shape, &out)) {
205 reused_input = true;
206 unused_inputs.erase(input_idx);
207 break;
208 }
209 }
210 }
211 if (!reused_input) {
212 OP_REQUIRES_OK(context, context->allocate_output(
213 output_idx, output_tensor_shape, &out));
214 }
215 outputs->emplace_back(out);
216 }
217 }
218
219 template <class InputScalar, class OutputScalar>
ComputeTensorSlice(OpKernelContext * context,int64 matrix_index,const TensorInputs & inputs,const TensorShapes & input_matrix_shapes,const TensorOutputs & outputs,const TensorShapes & output_matrix_shapes)220 void LinearAlgebraOp<InputScalar, OutputScalar>::ComputeTensorSlice(
221 OpKernelContext* context, int64 matrix_index, const TensorInputs& inputs,
222 const TensorShapes& input_matrix_shapes, const TensorOutputs& outputs,
223 const TensorShapes& output_matrix_shapes) {
224 InputConstMatrixMaps matrix_inputs;
225 for (size_t i = 0; i < inputs.size(); ++i) {
226 // TODO(kalakris): Handle alignment if possible. Eigen::Map is
227 // unaligned by default.
228 matrix_inputs.emplace_back(
229 inputs[i]->flat<InputScalar>().data() +
230 matrix_index * input_matrix_shapes[i].num_elements(),
231 input_matrix_shapes[i].dim_size(0), input_matrix_shapes[i].dim_size(1));
232 }
233
234 OutputMatrixMaps matrix_outputs;
235 for (size_t i = 0; i < output_matrix_shapes.size(); ++i) {
236 // The output matrix shape may not be a matrix.
237 int num_output_rows = output_matrix_shapes[i].dims() >= 1
238 ? output_matrix_shapes[i].dim_size(0)
239 : 1;
240 int num_output_cols = output_matrix_shapes[i].dims() == 2
241 ? output_matrix_shapes[i].dim_size(1)
242 : 1;
243 matrix_outputs.emplace_back(
244 outputs[i]->flat<OutputScalar>().data() +
245 matrix_index * output_matrix_shapes[i].num_elements(),
246 num_output_rows, num_output_cols);
247 }
248 ComputeMatrix(context, matrix_inputs, &matrix_outputs);
249 }
250
251 // Explicitly instantiate LinearAlgebraOp for the scalar types we expect to use.
252 template class LinearAlgebraOp<float>;
253 template class LinearAlgebraOp<double>;
254 template class LinearAlgebraOp<complex64>;
255 template class LinearAlgebraOp<complex128>;
256 template class LinearAlgebraOp<float, complex64>;
257 template class LinearAlgebraOp<double, complex128>;
258
259 } // namespace tensorflow
260