• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #define EIGEN_USE_THREADS
16 
#include "tensorflow/core/kernels/reshape_util.h"

#include <algorithm>
#include <limits>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
31 
32 namespace tensorflow {
33 
34 using CPUDevice = Eigen::ThreadPoolDevice;
35 using GPUDevice = Eigen::GpuDevice;
36 
37 namespace functor {
38 
39 template <>
40 struct ReshapeSparseTensorFunctor<CPUDevice> {
operator ()tensorflow::functor::ReshapeSparseTensorFunctor41   Status operator()(OpKernelContext *context, const TensorShape &input_shape,
42                     const TensorShape &output_shape,
43                     typename TTypes<int64>::ConstMatrix input_indices,
44                     typename TTypes<int64>::Matrix output_indices) const {
45     (void)context;  // Unused (only used in GPU implementation)
46     const int64_t input_rank = input_shape.dims();
47     const int64_t output_rank = output_shape.dims();
48     const int64_t nnz = input_indices.dimension(0);
49     gtl::InlinedVector<int64, 8> input_strides(input_rank);
50     if (input_rank > 0) {
51       input_strides[input_rank - 1] = 1;
52       for (int d = input_rank - 2; d >= 0; --d) {
53         input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
54       }
55     }
56 
57     gtl::InlinedVector<int64, 8> output_strides(output_rank);
58     if (output_rank > 0) {
59       output_strides[output_rank - 1] = 1;
60       for (int d = output_rank - 2; d >= 0; --d) {
61         output_strides[d] =
62             output_strides[d + 1] * output_shape.dim_size(d + 1);
63       }
64     }
65 
66     for (int i = 0; i < nnz; ++i) {
67       int64_t id = 0;
68       for (int j = 0; j < input_rank; ++j) {
69         id += input_indices(i, j) * input_strides[j];
70       }
71       for (int j = 0; j < output_rank; ++j) {
72         output_indices(i, j) = id / output_strides[j];
73         id %= output_strides[j];
74       }
75     }
76     return Status::OK();
77   }
78 };
79 
80 }  // namespace functor
81 
82 template <typename Device>
ReshapeSparseTensor(OpKernelContext * context,const Tensor & input_indices_in,const Tensor & input_shape_in,const Tensor & target_shape_in,int output_indices_idx,int output_shape_idx)83 void ReshapeSparseTensor(OpKernelContext *context,
84                          const Tensor &input_indices_in,
85                          const Tensor &input_shape_in,
86                          const Tensor &target_shape_in, int output_indices_idx,
87                          int output_shape_idx) {
88   OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices_in.shape()),
89               errors::InvalidArgument(
90                   "Input indices should be a matrix but received shape ",
91                   input_indices_in.shape().DebugString()));
92   OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape_in.shape()),
93               errors::InvalidArgument(
94                   "Input shape should be a vector but received shape ",
95                   input_shape_in.shape().DebugString()));
96   OP_REQUIRES(context, TensorShapeUtils::IsVector(target_shape_in.shape()),
97               errors::InvalidArgument(
98                   "Target shape should be a vector but received shape ",
99                   target_shape_in.shape().DebugString()));
100 
101   const int64_t output_rank = target_shape_in.NumElements();
102   const TensorShape input_shape(input_shape_in.vec<int64>());
103   const int64_t dense_size = input_shape.num_elements();
104   const int64_t nnz = input_indices_in.shape().dim_size(0);
105 
106   // Compute the output shape. Determine product of specified dimensions, and
107   // find the index of the unspecified one.
108   TensorShape output_shape;
109   int64_t product = 1;
110   int unknown_index = -1;
111   auto target_shape = target_shape_in.vec<int64>();
112   for (int d = 0; d < output_rank; ++d) {
113     const int64_t size = target_shape(d);
114     if (size == -1) {
115       OP_REQUIRES(
116           context, unknown_index == -1,
117           errors::InvalidArgument("only one output dimension may be -1, "
118                                   "not both ",
119                                   unknown_index, " and ", d));
120       unknown_index = d;
121       output_shape.AddDim(1);
122     } else {
123       OP_REQUIRES(context, size >= 0,
124                   errors::InvalidArgument("size ", d,
125                                           " must be non-negative, not ", size));
126       product *= size;
127       output_shape.AddDim(size);
128     }
129   }
130   if (unknown_index != -1) {
131     OP_REQUIRES(
132         context, product > 0,
133         errors::InvalidArgument("reshape cannot infer the missing "
134                                 "input size for an empty tensor unless all "
135                                 "specified input sizes are non-zero"));
136     const int64_t missing = dense_size / product;
137     OP_REQUIRES(
138         context, product * missing == dense_size,
139         errors::InvalidArgument(
140             "Input to reshape is a SparseTensor with ", dense_size,
141             " dense values, but the requested shape requires a multiple of ",
142             product, ". input_shape=", input_shape.DebugString(),
143             " output_shape=", output_shape.DebugString()));
144     output_shape.set_dim(unknown_index, missing);
145   }
146 
147   OP_REQUIRES(
148       context, output_shape.num_elements() == dense_size,
149       errors::InvalidArgument("Input to reshape is a tensor with ", dense_size,
150                               " dense values, but the requested shape has ",
151                               output_shape.num_elements(),
152                               ". input_shape=", input_shape.DebugString(),
153                               " output_shape=", output_shape.DebugString()));
154 
155   // Optimize for reshaping to the same shape.
156   if (input_shape == output_shape) {
157     context->set_output(output_indices_idx, input_indices_in);
158     context->set_output(output_shape_idx, input_shape_in);
159     return;
160   }
161 
162   Tensor *result_shape = nullptr;
163   OP_REQUIRES_OK(context, context->allocate_output(output_shape_idx,
164                                                    TensorShape({output_rank}),
165                                                    &result_shape));
166   auto output_shape_vec = result_shape->vec<int64>();
167   for (int j = 0; j < output_shape.dims(); ++j) {
168     output_shape_vec(j) = output_shape.dim_size(j);
169   }
170 
171   Tensor *result_indices = nullptr;
172   OP_REQUIRES_OK(context,
173                  context->allocate_output(output_indices_idx,
174                                           TensorShape({nnz, output_rank}),
175                                           &result_indices));
176   if (nnz > 0) {
177     OP_REQUIRES(
178         context, dense_size > 0 && product > 0,
179         errors::InvalidArgument(
180             "Input tensor has ", nnz, " non zero elements but input shape (",
181             input_shape.DebugString(), ") or output shape (",
182             output_shape.DebugString(), ") is empty"));
183     OP_REQUIRES_OK(context, functor::ReshapeSparseTensorFunctor<Device>()(
184                                 context, input_shape, output_shape,
185                                 input_indices_in.matrix<int64>(),
186                                 result_indices->matrix<int64>()));
187   }
188 }
189 
190 #define EXPLICITLY_INSTANTIATE_FUNCTION(Device)                    \
191   template void ReshapeSparseTensor<Device>(                       \
192       OpKernelContext * context, const Tensor &input_indices_in,   \
193       const Tensor &input_shape_in, const Tensor &target_shape_in, \
194       int output_indices_idx, int output_shape_idx)
195 EXPLICITLY_INSTANTIATE_FUNCTION(CPUDevice);
196 
197 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
198 EXPLICITLY_INSTANTIATE_FUNCTION(GPUDevice);
199 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
200 #undef EXPLICITLY_INSTANTIATE_FUNCTION
201 
202 }  // namespace tensorflow
203