/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/reshape_util.h"

#include <algorithm>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

namespace functor {

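// CPU specialization of ReshapeSparseTensorFunctor: rewrites each sparse
// index from coordinates in input_shape to coordinates in output_shape by
// flattening it to a linear offset and expanding that offset again.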
template <>
struct ReshapeSparseTensorFunctor<CPUDevice> {
  Status operator()(OpKernelContext *context, const TensorShape &input_shape,
                    const TensorShape &output_shape,
                    typename TTypes<int64>::ConstMatrix input_indices,
                    typename TTypes<int64>::Matrix output_indices) const {
    (void)context;  // Unused (only used in GPU implementation)
    const int64_t input_rank = input_shape.dims();
    const int64_t output_rank = output_shape.dims();
    const int64_t nnz = input_indices.dimension(0);
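    // Row-major strides of the input shape: input_strides[d] is the number of
    // dense elements spanned by one step along dimension d.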
    gtl::InlinedVector<int64, 8> input_strides(input_rank);
    if (input_rank > 0) {
      input_strides[input_rank - 1] = 1;
      for (int d = input_rank - 2; d >= 0; --d) {
        input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
      }
    }

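    // Row-major strides of the output shape, used to expand a linear offset
    // back into multi-dimensional output coordinates.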
    gtl::InlinedVector<int64, 8> output_strides(output_rank);
    if (output_rank > 0) {
      output_strides[output_rank - 1] = 1;
      for (int d = output_rank - 2; d >= 0; --d) {
        output_strides[d] =
            output_strides[d + 1] * output_shape.dim_size(d + 1);
      }
    }

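    // For every nonzero, flatten its input index to a linear offset with the
    // input strides, then unflatten that offset with the output strides.
    // E.g. index (1, 2) in shape [3, 4] flattens to 1 * 4 + 2 = 6, which for
    // an output shape of [2, 6] expands to index (1, 0).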
    for (int i = 0; i < nnz; ++i) {
      int64_t id = 0;
      for (int j = 0; j < input_rank; ++j) {
        id += input_indices(i, j) * input_strides[j];
      }
      for (int j = 0; j < output_rank; ++j) {
        output_indices(i, j) = id / output_strides[j];
        id %= output_strides[j];
      }
    }
    return Status::OK();
  }
};

}  // namespace functor

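// Reshapes the sparse tensor described by (input_indices_in, input_shape_in)
// to target_shape_in, writing the remapped indices and the resolved shape to
// the outputs at output_indices_idx and output_shape_idx.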
template <typename Device>
void ReshapeSparseTensor(OpKernelContext *context,
                         const Tensor &input_indices_in,
                         const Tensor &input_shape_in,
                         const Tensor &target_shape_in, int output_indices_idx,
                         int output_shape_idx) {
  OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices_in.shape()),
              errors::InvalidArgument(
                  "Input indices should be a matrix but received shape ",
                  input_indices_in.shape().DebugString()));
  OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape_in.shape()),
              errors::InvalidArgument(
                  "Input shape should be a vector but received shape ",
                  input_shape_in.shape().DebugString()));
  OP_REQUIRES(context, TensorShapeUtils::IsVector(target_shape_in.shape()),
              errors::InvalidArgument(
                  "Target shape should be a vector but received shape ",
                  target_shape_in.shape().DebugString()));

  const int64_t output_rank = target_shape_in.NumElements();
  const TensorShape input_shape(input_shape_in.vec<int64>());
  const int64_t dense_size = input_shape.num_elements();
  const int64_t nnz = input_indices_in.shape().dim_size(0);

  // Compute the output shape. Determine product of specified dimensions, and
  // find the index of the unspecified one.
  TensorShape output_shape;
  int64_t product = 1;
  int unknown_index = -1;
  auto target_shape = target_shape_in.vec<int64>();
  for (int d = 0; d < output_rank; ++d) {
    const int64_t size = target_shape(d);
    if (size == -1) {
      OP_REQUIRES(
          context, unknown_index == -1,
          errors::InvalidArgument("only one output dimension may be -1, "
                                  "not both ",
                                  unknown_index, " and ", d));
      unknown_index = d;
      output_shape.AddDim(1);
    } else {
      OP_REQUIRES(context, size >= 0,
                  errors::InvalidArgument("size ", d,
                                          " must be non-negative, not ", size));
      product *= size;
      output_shape.AddDim(size);
    }
  }
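  // If one dimension was left as -1, infer it so that the output accounts for
  // exactly dense_size elements.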
  if (unknown_index != -1) {
    OP_REQUIRES(
        context, product > 0,
        errors::InvalidArgument("reshape cannot infer the missing "
                                "input size for an empty tensor unless all "
                                "specified input sizes are non-zero"));
    const int64_t missing = dense_size / product;
    OP_REQUIRES(
        context, product * missing == dense_size,
        errors::InvalidArgument(
            "Input to reshape is a SparseTensor with ", dense_size,
            " dense values, but the requested shape requires a multiple of ",
            product, ". input_shape=", input_shape.DebugString(),
            " output_shape=", output_shape.DebugString()));
    output_shape.set_dim(unknown_index, missing);
  }

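  // The requested shape must cover exactly the same number of dense elements
  // as the input shape.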
  OP_REQUIRES(
      context, output_shape.num_elements() == dense_size,
      errors::InvalidArgument("Input to reshape is a tensor with ", dense_size,
                              " dense values, but the requested shape has ",
                              output_shape.num_elements(),
                              ". input_shape=", input_shape.DebugString(),
                              " output_shape=", output_shape.DebugString()));

  // Optimize for reshaping to the same shape.
  if (input_shape == output_shape) {
    context->set_output(output_indices_idx, input_indices_in);
    context->set_output(output_shape_idx, input_shape_in);
    return;
  }

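  // Allocate and populate the output shape tensor.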
  Tensor *result_shape = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output(output_shape_idx,
                                                   TensorShape({output_rank}),
                                                   &result_shape));
  auto output_shape_vec = result_shape->vec<int64>();
  for (int j = 0; j < output_shape.dims(); ++j) {
    output_shape_vec(j) = output_shape.dim_size(j);
  }

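  // Allocate the output indices and, if there are any nonzeros, remap them on
  // the selected device.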
  Tensor *result_indices = nullptr;
  OP_REQUIRES_OK(context,
                 context->allocate_output(output_indices_idx,
                                          TensorShape({nnz, output_rank}),
                                          &result_indices));
  if (nnz > 0) {
    OP_REQUIRES(
        context, dense_size > 0 && product > 0,
        errors::InvalidArgument(
            "Input tensor has ", nnz, " non zero elements but input shape (",
            input_shape.DebugString(), ") or output shape (",
            output_shape.DebugString(), ") is empty"));
    OP_REQUIRES_OK(context, functor::ReshapeSparseTensorFunctor<Device>()(
                                context, input_shape, output_shape,
                                input_indices_in.matrix<int64>(),
                                result_indices->matrix<int64>()));
  }
}

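// Explicit instantiations for the device types that have a
// ReshapeSparseTensorFunctor specialization.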
#define EXPLICITLY_INSTANTIATE_FUNCTION(Device)                    \
  template void ReshapeSparseTensor<Device>(                       \
      OpKernelContext *context, const Tensor &input_indices_in,    \
      const Tensor &input_shape_in, const Tensor &target_shape_in, \
      int output_indices_idx, int output_shape_idx)
EXPLICITLY_INSTANTIATE_FUNCTION(CPUDevice);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
EXPLICITLY_INSTANTIATE_FUNCTION(GPUDevice);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef EXPLICITLY_INSTANTIATE_FUNCTION

}  // namespace tensorflow