/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/sparse_concat_op.h"

#include <algorithm>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/overflow.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

namespace functor {

template <typename T>
struct SparseConcatFunctor<CPUDevice, T> {
  void operator()(OpKernelContext* context, const OpInputList& inds,
                  const OpInputList& vals, const OpInputList& shapes,
                  int concat_dim) {
    const int N = inds.size();
    const TensorShape input_shape(shapes[0].vec<int64>());
    const int input_rank = input_shape.dims();

    // The input and output sparse tensors are assumed to be ordered along
    // increasing dimension number. But in order for concat to work properly,
    // order[0] must be concat_dim. So we will reorder the inputs to the
    // concat ordering, concatenate, then reorder back to the standard order.
    // We make a deep copy of the input tensors to ensure that the in-place
    // reorder doesn't create race conditions for other ops that may be
    // concurrently reading the indices and values tensors.
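    //
    // As a concrete example of the reordering: with input_rank = 3 and
    // concat_dim = 1, std_order below is {0, 1, 2} and concat_order is
    // {1, 0, 2}, so the concat dimension becomes the most significant sort
    // key during the concatenation.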

    gtl::InlinedVector<int64, 8> std_order(input_rank);
    std::iota(std_order.begin(), std_order.end(), 0);

    std::vector<int64> concat_order;
    concat_order.reserve(input_rank);
    concat_order.push_back(concat_dim);
    for (int j = 0; j < input_rank; ++j) {
      if (j != concat_dim) {
        concat_order.push_back(j);
      }
    }

    std::vector<sparse::SparseTensor> sp_inputs;
    for (int i = 0; i < N; ++i) {
      const TensorShape current_shape(shapes[i].vec<int64>());
      sparse::SparseTensor tensor;
      OP_REQUIRES_OK(context,
                     sparse::SparseTensor::Create(
                         tensor::DeepCopy(inds[i]), tensor::DeepCopy(vals[i]),
                         current_shape, std_order, &tensor));
      sp_inputs.push_back(std::move(tensor));
      sp_inputs[i].Reorder<T>(concat_order);
    }

    sparse::SparseTensor concat = sparse::SparseTensor::Concat<T>(sp_inputs);
    concat.Reorder<T>(std_order);

    context->set_output(0, concat.indices());
    context->set_output(1, concat.values());
  }
};

}  // namespace functor
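// SparseConcatOp validates the indices/values/shapes inputs, computes the
// concatenated dense shape, and delegates the index/value concatenation to
// SparseConcatFunctor. For example, concatenating sparse tensors with dense
// shapes [3, 3] and [3, 4] along concat_dim = 1 yields a sparse tensor with
// dense shape [3, 7], with the second input's indices offset by 3 along
// dimension 1.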
template <typename Device, typename T>
class SparseConcatOp : public OpKernel {
 public:
  explicit SparseConcatOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("concat_dim", &concat_dim_attr_));
  }

  void Compute(OpKernelContext* context) override {
    OpInputList inds;
    OP_REQUIRES_OK(context, context->input_list("indices", &inds));
    const int N = inds.size();
    for (int i = 0; i < N; i++) {
      OP_REQUIRES(context, TensorShapeUtils::IsMatrix(inds[i].shape()),
                  errors::InvalidArgument(
                      "Input indices should be a matrix but received shape ",
                      inds[i].shape().DebugString(), " at position ", i));
    }

    OpInputList vals;
    OP_REQUIRES_OK(context, context->input_list("values", &vals));
    OP_REQUIRES(context, vals.size() == N,
                errors::InvalidArgument("Expected ", N, " input values, got ",
                                        vals.size()));
    for (int i = 0; i < N; i++) {
      OP_REQUIRES(context, TensorShapeUtils::IsVector(vals[i].shape()),
                  errors::InvalidArgument(
                      "Input values should be a vector but received shape ",
                      vals[i].shape().DebugString(), " at position ", i));
    }

    OpInputList shapes;
    OP_REQUIRES_OK(context, context->input_list("shapes", &shapes));
    OP_REQUIRES(context, shapes.size() == N,
                errors::InvalidArgument("Expected ", N, " input shapes, got ",
                                        shapes.size()));
    bool overflow_occurred = false;
    for (int i = 0; i < N; i++) {
      int64_t new_num_elements = 1;
      OP_REQUIRES(context, TensorShapeUtils::IsVector(shapes[i].shape()),
                  errors::InvalidArgument(
                      "Input shapes should be a vector but received shape ",
                      shapes[i].shape().DebugString(), " at position ", i));
      auto input_shape_vector = shapes[i].vec<int64>();
      for (int j = 0; j < input_shape_vector.size(); j++) {
        new_num_elements =
            MultiplyWithoutOverflow(new_num_elements, input_shape_vector(j));
        if (new_num_elements < 0) {
          overflow_occurred = true;
          break;
        }
      }

      if (overflow_occurred) {
        break;
      }
    }

    OP_REQUIRES(
        context, !overflow_occurred,
        errors::Internal("Encountered overflow from large input shape."));

    const TensorShape input_shape(shapes[0].vec<int64>());
    const int input_rank = input_shape.dims();
    const int concat_dim = (concat_dim_attr_ < 0)
                               ? input_rank + concat_dim_attr_
                               : concat_dim_attr_;
    OP_REQUIRES(context, concat_dim >= 0 && concat_dim < input_rank,
                errors::InvalidArgument("Concat dimension must be in range [",
                                        -input_rank, ", ", input_rank,
                                        "), got ", concat_dim_attr_));
    TensorShape output_shape = input_shape;
    for (int i = 1; i < N; ++i) {
      const TensorShape current_shape(shapes[i].vec<int64>());
      OP_REQUIRES(
          context, current_shape.dims() == input_rank,
          errors::InvalidArgument(
              "Ranks of all input tensors must match: expected ", input_rank,
              " but got ", current_shape.dims(), " at position ", i));
      for (int j = 0; j < input_rank; ++j) {
        if (j != concat_dim) {
          OP_REQUIRES(
              context, input_shape.dim_size(j) == current_shape.dim_size(j),
              errors::InvalidArgument(
                  "Input shapes must match: expected ",
                  input_shape.dim_size(j), " for dimension ", j, " but got ",
                  current_shape.dim_size(j), " at position ", i));
        } else {
          output_shape.set_dim(
              j, output_shape.dim_size(j) + current_shape.dim_size(j));
        }
      }
    }

    Tensor* output_shape_out = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(
                                2, TensorShape({output_shape.dims()}),
                                &output_shape_out));
    auto output_shape_t = output_shape_out->vec<int64>();
    for (int j = 0; j < output_shape.dims(); ++j) {
      output_shape_t(j) = output_shape.dim_size(j);
    }

    int64 output_nnz = 0;
    for (int i = 0; i < N; ++i) {
      output_nnz += inds[i].dim_size(0);
    }
    if (output_nnz == 0) {
      Tensor* output_inds = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, TensorShape({0, input_rank}),
                                              &output_inds));
      Tensor* output_vals = nullptr;
      OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({0}),
                                                       &output_vals));
      return;  // No work to do
    }

    functor::SparseConcatFunctor<Device, T>()(context, inds, vals, shapes,
                                              concat_dim);
  }

 private:
  int concat_dim_attr_;
};

#define REGISTER_KERNELS(type)                                           \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("SparseConcat").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      SparseConcatOp<CPUDevice, type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
}  // namespace tensorflow