• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #define EIGEN_USE_THREADS
16 
17 #include "tensorflow/core/kernels/reshape_util.h"
18 
19 #include <algorithm>
20 #include <numeric>
21 #include <unordered_map>
22 #include <utility>
23 #include <vector>
24 
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/framework/register_types.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/framework/tensor_util.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/lib/gtl/inlined_vector.h"
31 
32 namespace tensorflow {
33 
Reshape(OpKernelContext * context,const Tensor & input_indices_in,const Tensor & input_shape_in,const Tensor & target_shape_in,int output_indices_idx,int output_shape_idx)34 void Reshape(OpKernelContext *context, const Tensor &input_indices_in,
35              const Tensor &input_shape_in, const Tensor &target_shape_in,
36              int output_indices_idx, int output_shape_idx) {
37   OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices_in.shape()),
38               errors::InvalidArgument(
39                   "Input indices should be a matrix but received shape ",
40                   input_indices_in.shape().DebugString()));
41   OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape_in.shape()),
42               errors::InvalidArgument(
43                   "Input shape should be a vector but received shape ",
44                   input_shape_in.shape().DebugString()));
45   OP_REQUIRES(context, TensorShapeUtils::IsVector(target_shape_in.shape()),
46               errors::InvalidArgument(
47                   "Target shape should be a vector but received shape ",
48                   target_shape_in.shape().DebugString()));
49 
50   const int64 input_rank = input_shape_in.NumElements();
51   const int64 output_rank = target_shape_in.NumElements();
52   const TensorShape input_shape(input_shape_in.vec<int64>());
53   const int64 dense_size = input_shape.num_elements();
54   const int64 nnz = input_indices_in.shape().dim_size(0);
55 
56   // Compute the output shape. Determine product of specified dimensions, and
57   // find the index of the unspecified one.
58   TensorShape output_shape;
59   int64 product = 1;
60   int unknown_index = -1;
61   auto target_shape = target_shape_in.vec<int64>();
62   for (int d = 0; d < output_rank; ++d) {
63     const int64 size = target_shape(d);
64     if (size == -1) {
65       OP_REQUIRES(
66           context, unknown_index == -1,
67           errors::InvalidArgument("only one output dimension may be -1, "
68                                   "not both ",
69                                   unknown_index, " and ", d));
70       unknown_index = d;
71       output_shape.AddDim(1);
72     } else {
73       OP_REQUIRES(context, size >= 0,
74                   errors::InvalidArgument("size ", d,
75                                           " must be non-negative, not ", size));
76       product *= size;
77       output_shape.AddDim(size);
78     }
79   }
80   if (unknown_index != -1) {
81     OP_REQUIRES(
82         context, product > 0,
83         errors::InvalidArgument("reshape cannot infer the missing "
84                                 "input size for an empty tensor unless all "
85                                 "specified input sizes are non-zero"));
86     const int64 missing = dense_size / product;
87     OP_REQUIRES(
88         context, product * missing == dense_size,
89         errors::InvalidArgument(
90             "Input to reshape is a SparseTensor with ", dense_size,
91             " dense values, but the requested shape requires a multiple of ",
92             product));
93     output_shape.set_dim(unknown_index, missing);
94   }
95 
96   OP_REQUIRES(
97       context, output_shape.num_elements() == dense_size,
98       errors::InvalidArgument("Input to reshape is a tensor with ", dense_size,
99                               " dense values, but the requested shape has ",
100                               output_shape.num_elements()));
101 
102   // Optimize for reshaping to the same shape.
103   if (input_shape == output_shape) {
104     context->set_output(output_indices_idx, input_indices_in);
105     context->set_output(output_shape_idx, input_shape_in);
106     return;
107   }
108 
109   gtl::InlinedVector<int64, 8> input_strides(input_rank);
110   if (input_rank > 0) {
111     input_strides[input_rank - 1] = 1;
112     for (int d = input_rank - 2; d >= 0; --d) {
113       input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
114     }
115   }
116 
117   gtl::InlinedVector<int64, 8> output_strides(output_rank);
118   if (output_rank > 0) {
119     output_strides[output_rank - 1] = 1;
120     for (int d = output_rank - 2; d >= 0; --d) {
121       output_strides[d] = output_strides[d + 1] * output_shape.dim_size(d + 1);
122     }
123   }
124 
125   Tensor *result_indices = nullptr;
126   OP_REQUIRES_OK(context,
127                  context->allocate_output(output_indices_idx,
128                                           TensorShape({nnz, output_rank}),
129                                           &result_indices));
130   auto input_ind = input_indices_in.matrix<int64>();
131   auto output_ind = result_indices->matrix<int64>();
132   for (int i = 0; i < nnz; ++i) {
133     int64 id = 0;
134     for (int j = 0; j < input_rank; ++j) {
135       id += input_ind(i, j) * input_strides[j];
136     }
137     for (int j = 0; j < output_rank; ++j) {
138       output_ind(i, j) = id / output_strides[j];
139       id %= output_strides[j];
140     }
141   }
142 
143   Tensor *result_shape = nullptr;
144   OP_REQUIRES_OK(context, context->allocate_output(output_shape_idx,
145                                                    TensorShape({output_rank}),
146                                                    &result_shape));
147   auto output_shape_vec = result_shape->vec<int64>();
148   for (int j = 0; j < output_shape.dims(); ++j) {
149     output_shape_vec(j) = output_shape.dim_size(j);
150   }
151 }
152 
153 }  // namespace tensorflow
154