1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #define EIGEN_USE_THREADS
16
17 #include "tensorflow/core/kernels/reshape_util.h"
18
19 #include <algorithm>
20 #include <numeric>
21 #include <unordered_map>
22 #include <utility>
23 #include <vector>
24
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/framework/register_types.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/framework/tensor_util.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/lib/gtl/inlined_vector.h"
31
32 namespace tensorflow {
33
Reshape(OpKernelContext * context,const Tensor & input_indices_in,const Tensor & input_shape_in,const Tensor & target_shape_in,int output_indices_idx,int output_shape_idx)34 void Reshape(OpKernelContext *context, const Tensor &input_indices_in,
35 const Tensor &input_shape_in, const Tensor &target_shape_in,
36 int output_indices_idx, int output_shape_idx) {
37 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices_in.shape()),
38 errors::InvalidArgument(
39 "Input indices should be a matrix but received shape ",
40 input_indices_in.shape().DebugString()));
41 OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape_in.shape()),
42 errors::InvalidArgument(
43 "Input shape should be a vector but received shape ",
44 input_shape_in.shape().DebugString()));
45 OP_REQUIRES(context, TensorShapeUtils::IsVector(target_shape_in.shape()),
46 errors::InvalidArgument(
47 "Target shape should be a vector but received shape ",
48 target_shape_in.shape().DebugString()));
49
50 const int64 input_rank = input_shape_in.NumElements();
51 const int64 output_rank = target_shape_in.NumElements();
52 const TensorShape input_shape(input_shape_in.vec<int64>());
53 const int64 dense_size = input_shape.num_elements();
54 const int64 nnz = input_indices_in.shape().dim_size(0);
55
56 // Compute the output shape. Determine product of specified dimensions, and
57 // find the index of the unspecified one.
58 TensorShape output_shape;
59 int64 product = 1;
60 int unknown_index = -1;
61 auto target_shape = target_shape_in.vec<int64>();
62 for (int d = 0; d < output_rank; ++d) {
63 const int64 size = target_shape(d);
64 if (size == -1) {
65 OP_REQUIRES(
66 context, unknown_index == -1,
67 errors::InvalidArgument("only one output dimension may be -1, "
68 "not both ",
69 unknown_index, " and ", d));
70 unknown_index = d;
71 output_shape.AddDim(1);
72 } else {
73 OP_REQUIRES(context, size >= 0,
74 errors::InvalidArgument("size ", d,
75 " must be non-negative, not ", size));
76 product *= size;
77 output_shape.AddDim(size);
78 }
79 }
80 if (unknown_index != -1) {
81 OP_REQUIRES(
82 context, product > 0,
83 errors::InvalidArgument("reshape cannot infer the missing "
84 "input size for an empty tensor unless all "
85 "specified input sizes are non-zero"));
86 const int64 missing = dense_size / product;
87 OP_REQUIRES(
88 context, product * missing == dense_size,
89 errors::InvalidArgument(
90 "Input to reshape is a SparseTensor with ", dense_size,
91 " dense values, but the requested shape requires a multiple of ",
92 product));
93 output_shape.set_dim(unknown_index, missing);
94 }
95
96 OP_REQUIRES(
97 context, output_shape.num_elements() == dense_size,
98 errors::InvalidArgument("Input to reshape is a tensor with ", dense_size,
99 " dense values, but the requested shape has ",
100 output_shape.num_elements()));
101
102 // Optimize for reshaping to the same shape.
103 if (input_shape == output_shape) {
104 context->set_output(output_indices_idx, input_indices_in);
105 context->set_output(output_shape_idx, input_shape_in);
106 return;
107 }
108
109 gtl::InlinedVector<int64, 8> input_strides(input_rank);
110 if (input_rank > 0) {
111 input_strides[input_rank - 1] = 1;
112 for (int d = input_rank - 2; d >= 0; --d) {
113 input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
114 }
115 }
116
117 gtl::InlinedVector<int64, 8> output_strides(output_rank);
118 if (output_rank > 0) {
119 output_strides[output_rank - 1] = 1;
120 for (int d = output_rank - 2; d >= 0; --d) {
121 output_strides[d] = output_strides[d + 1] * output_shape.dim_size(d + 1);
122 }
123 }
124
125 Tensor *result_indices = nullptr;
126 OP_REQUIRES_OK(context,
127 context->allocate_output(output_indices_idx,
128 TensorShape({nnz, output_rank}),
129 &result_indices));
130 auto input_ind = input_indices_in.matrix<int64>();
131 auto output_ind = result_indices->matrix<int64>();
132 for (int i = 0; i < nnz; ++i) {
133 int64 id = 0;
134 for (int j = 0; j < input_rank; ++j) {
135 id += input_ind(i, j) * input_strides[j];
136 }
137 for (int j = 0; j < output_rank; ++j) {
138 output_ind(i, j) = id / output_strides[j];
139 id %= output_strides[j];
140 }
141 }
142
143 Tensor *result_shape = nullptr;
144 OP_REQUIRES_OK(context, context->allocate_output(output_shape_idx,
145 TensorShape({output_rank}),
146 &result_shape));
147 auto output_shape_vec = result_shape->vec<int64>();
148 for (int j = 0; j < output_shape.dims(); ++j) {
149 output_shape_vec(j) = output_shape.dim_size(j);
150 }
151 }
152
153 } // namespace tensorflow
154