/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/sparse_concat_op.h"

#include <algorithm>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/overflow.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

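// Illustrative example (not part of the original source): concatenating two
// sparse 2x3 tensors along concat_dim = 0 yields a sparse 4x3 tensor. The
// second input's indices are offset by 2 along dimension 0, the values are
// carried over unchanged, and the result is re-sorted back into standard
// (row-major) order.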
namespace functor {

template <typename T>
struct SparseConcatFunctor<CPUDevice, T> {
  void operator()(OpKernelContext* context, const OpInputList& inds,
                  const OpInputList& vals, const OpInputList& shapes,
                  int concat_dim) {
    const int N = inds.size();
    const TensorShape input_shape(shapes[0].vec<int64>());
    const int input_rank = input_shape.dims();

    // The input and output sparse tensors are assumed to be ordered along
    // increasing dimension number. But in order for concat to work properly,
    // order[0] must be concat_dim. So we will reorder the inputs to the
    // concat ordering, concatenate, then reorder back to the standard order.
    // We make a deep copy of the input tensors to ensure that the in-place
    // reorder doesn't create race conditions for other ops that may be
    // concurrently reading the indices and values tensors.

    gtl::InlinedVector<int64, 8> std_order(input_rank);
    std::iota(std_order.begin(), std_order.end(), 0);

    std::vector<int64> concat_order;
    concat_order.reserve(input_rank);
    concat_order.push_back(concat_dim);
    for (int j = 0; j < input_rank; ++j) {
      if (j != concat_dim) {
        concat_order.push_back(j);
      }
    }

    std::vector<sparse::SparseTensor> sp_inputs;
    for (int i = 0; i < N; ++i) {
      const TensorShape current_shape(shapes[i].vec<int64>());
      sparse::SparseTensor tensor;
      OP_REQUIRES_OK(context,
                     sparse::SparseTensor::Create(
                         tensor::DeepCopy(inds[i]), tensor::DeepCopy(vals[i]),
                         current_shape, std_order, &tensor));
      sp_inputs.push_back(std::move(tensor));
      sp_inputs[i].Reorder<T>(concat_order);
    }

    sparse::SparseTensor concat = sparse::SparseTensor::Concat<T>(sp_inputs);
    concat.Reorder<T>(std_order);

    context->set_output(0, concat.indices());
    context->set_output(1, concat.values());
  }
};

}  // namespace functor

template <typename Device, typename T>
class SparseConcatOp : public OpKernel {
 public:
  explicit SparseConcatOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("concat_dim", &concat_dim_attr_));
  }

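  // Validates the indices/values/shapes input lists, computes the dense
  // output shape, and dispatches the actual concatenation to the
  // device-specific SparseConcatFunctor.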
  void Compute(OpKernelContext* context) override {
    OpInputList inds;
    OP_REQUIRES_OK(context, context->input_list("indices", &inds));
    const int N = inds.size();
    for (int i = 0; i < N; i++) {
      OP_REQUIRES(context, TensorShapeUtils::IsMatrix(inds[i].shape()),
                  errors::InvalidArgument(
                      "Input indices should be a matrix but received shape ",
                      inds[i].shape().DebugString(), " at position ", i));
    }

    OpInputList vals;
    OP_REQUIRES_OK(context, context->input_list("values", &vals));
    OP_REQUIRES(context, vals.size() == N,
                errors::InvalidArgument("Expected ", N, " input values, got ",
                                        vals.size()));
    for (int i = 0; i < N; i++) {
      OP_REQUIRES(context, TensorShapeUtils::IsVector(vals[i].shape()),
                  errors::InvalidArgument(
                      "Input values should be a vector but received shape ",
                      vals[i].shape().DebugString(), " at position ", i));
    }

    OpInputList shapes;
    OP_REQUIRES_OK(context, context->input_list("shapes", &shapes));
    OP_REQUIRES(context, shapes.size() == N,
                errors::InvalidArgument("Expected ", N, " input shapes, got ",
                                        shapes.size()));
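    // Guard against int64 overflow when computing the number of elements
    // implied by each input shape. MultiplyWithoutOverflow returns a negative
    // value when the product overflows, so a negative running product signals
    // overflow.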
    bool overflow_occurred = false;
    for (int i = 0; i < N; i++) {
      int64_t new_num_elements = 1;
      OP_REQUIRES(context, TensorShapeUtils::IsVector(shapes[i].shape()),
                  errors::InvalidArgument(
                      "Input shapes should be a vector but received shape ",
                      shapes[i].shape().DebugString(), " at position ", i));
      auto input_shape_vector = shapes[i].vec<int64>();
      for (int j = 0; j < input_shape_vector.size(); j++) {
        new_num_elements =
            MultiplyWithoutOverflow(new_num_elements, input_shape_vector(j));
        if (new_num_elements < 0) {
          overflow_occurred = true;
          break;
        }
      }

      if (overflow_occurred) {
        break;
      }
    }

    OP_REQUIRES(
        context, !overflow_occurred,
        errors::Internal("Encountered overflow from large input shape."));

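    // A negative concat_dim counts from the last dimension (as in Python
    // indexing); normalize it into the range [0, input_rank).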
    const TensorShape input_shape(shapes[0].vec<int64>());
    const int input_rank = input_shape.dims();
    const int concat_dim = (concat_dim_attr_ < 0)
                               ? input_rank + concat_dim_attr_
                               : concat_dim_attr_;
    OP_REQUIRES(context, concat_dim >= 0 && concat_dim < input_rank,
                errors::InvalidArgument("Concat dimension must be in range [",
                                        -input_rank, ", ", input_rank,
                                        "), got ", concat_dim_attr_));
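    // The output dense shape matches the inputs in every dimension except
    // concat_dim, which is the sum of the inputs' sizes along that dimension.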
    TensorShape output_shape = input_shape;
    for (int i = 1; i < N; ++i) {
      const TensorShape current_shape(shapes[i].vec<int64>());
      OP_REQUIRES(
          context, current_shape.dims() == input_rank,
          errors::InvalidArgument(
              "Ranks of all input tensors must match: expected ", input_rank,
              " but got ", current_shape.dims(), " at position ", i));
      for (int j = 0; j < input_rank; ++j) {
        if (j != concat_dim) {
          OP_REQUIRES(
              context, input_shape.dim_size(j) == current_shape.dim_size(j),
              errors::InvalidArgument(
                  "Input shapes must match: expected ", input_shape.dim_size(j),
                  " for dimension ", j, " but got ", current_shape.dim_size(j),
                  " at position ", i));
        } else {
          output_shape.set_dim(
              j, output_shape.dim_size(j) + current_shape.dim_size(j));
        }
      }
    }

    Tensor* output_shape_out = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(2, TensorShape({output_shape.dims()}),
                                          &output_shape_out));
    auto output_shape_t = output_shape_out->vec<int64>();
    for (int j = 0; j < output_shape.dims(); ++j) {
      output_shape_t(j) = output_shape.dim_size(j);
    }

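    // Total number of non-zero entries across all inputs. If there are none,
    // emit empty indices and values tensors and skip the functor entirely.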
    int64 output_nnz = 0;
    for (int i = 0; i < N; ++i) {
      output_nnz += inds[i].dim_size(0);
    }
    if (output_nnz == 0) {
      Tensor* output_inds = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, TensorShape({0, input_rank}),
                                              &output_inds));
      Tensor* output_vals = nullptr;
      OP_REQUIRES_OK(
          context, context->allocate_output(1, TensorShape({0}), &output_vals));
      return;  // No work to do
    }

    functor::SparseConcatFunctor<Device, T>()(context, inds, vals, shapes,
                                              concat_dim);
  }

 private:
  int concat_dim_attr_;
};

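// Registers the CPU kernel of SparseConcat for all standard TensorFlow types.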
#define REGISTER_KERNELS(type)                                           \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("SparseConcat").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      SparseConcatOp<CPUDevice, type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
}  // namespace tensorflow