• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific reshape op.

#include <limits>
#include <vector>

#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"

31 namespace tensorflow {
32 namespace {
33 
34 class ReshapeOp : public XlaOpKernel {
35  public:
ReshapeOp(OpKernelConstruction * ctx)36   explicit ReshapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
37 
Compile(XlaOpKernelContext * ctx)38   void Compile(XlaOpKernelContext* ctx) override {
39     const TensorShape input_shape = ctx->InputShape(0);
40     const TensorShape sizes_shape = ctx->InputShape(1);
41     // Preliminary validation of sizes.
42     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(sizes_shape),
43                 errors::InvalidArgument("sizes input must be 1-D, not shape ",
44                                         sizes_shape.DebugString()));
45     const int64 num_dims = sizes_shape.num_elements();
46 
47     std::vector<int64> shape_input;
48     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &shape_input));
49 
50     // Compute the output shape.  Determine product of specified
51     // dimensions, and find the index of the unspecified one if there
52     // is one.
53     TensorShape shape;
54     int64 product = 1;
55     int unknown_index = -1;
56     bool shape_has_zero_dim = false;
57     for (int d = 0; d < num_dims; ++d) {
58       const int32 size = shape_input[d];
59       if (size == -1) {
60         OP_REQUIRES(
61             ctx, unknown_index == -1,
62             errors::InvalidArgument("only one input size may be -1, not both ",
63                                     unknown_index, " and ", d));
64         unknown_index = d;
65         shape.AddDim(1);
66       } else if (size == 0) {
67         // We don't include zero-sized dimension in product, so that we can
68         // still calculate number of elements for non-zero-sized dimensions and
69         // therefore infer their shapes.
70         shape.AddDim(size);
71         shape_has_zero_dim = true;
72       } else {
73         OP_REQUIRES(ctx, size >= 0,
74                     errors::InvalidArgument(
75                         "size ", d, " must be non-negative, not ", size));
76         shape.AddDim(size);
77         product *= size;
78       }
79     }
80     if (unknown_index != -1) {
81       int64 input_num_elements = 1;
82       bool input_has_zero_dim = false;
83       for (int dim = 0; dim < input_shape.dims(); dim++) {
84         // For zero dimension, we don't count it into `input_num_elements`
85         // unless `sizes` has no zero dimension, so we are still able to
86         // infer shapes for other dimensions.
87         if (input_shape.dim_size(dim) > 0 || !shape_has_zero_dim) {
88           input_num_elements *= input_shape.dim_size(dim);
89         } else {
90           input_has_zero_dim = true;
91         }
92       }
93 
94       const int64 missing = input_num_elements / product;
95       if (!input_has_zero_dim) {
96         OP_REQUIRES(
97             ctx, product * missing == input_num_elements,
98             errors::InvalidArgument(
99                 "Input to reshape is a tensor with ", input_num_elements,
100                 " values, but the requested shape requires a multiple of ",
101                 product));
102       }
103       shape.set_dim(unknown_index, missing);
104     }
105     OP_REQUIRES(ctx, shape.num_elements() == input_shape.num_elements(),
106                 errors::InvalidArgument("Input to reshape is a tensor with ",
107                                         input_shape.num_elements(),
108                                         " values, but the requested shape has ",
109                                         shape.num_elements()));
110 
111     VLOG(2) << "Reshape from " << input_shape.DebugString() << " to "
112             << shape.DebugString() << ", unknown_index=" << unknown_index;
113     auto input_xla_shape = ctx->InputXlaShape(0);
114     if (input_xla_shape->is_static()) {
115       ctx->SetOutput(0, xla::Reshape(ctx->Input(0), shape.dim_sizes()));
116       return;
117     }
118     // Handing dynamic reshapes if input contains a dynamic dimension.
119     std::vector<xla::XlaOp> output_dim_sizes;
120     std::vector<bool> dims_are_dynamic;
121     for (int64 i = 0; i < shape.dims(); ++i) {
122       output_dim_sizes.push_back(
123           xla::Reshape(xla::Slice(ctx->Input(1), {i}, {i + 1}, {1}), {}));
124     }
125     OP_REQUIRES_OK(
126         ctx, ctx->ResolveInputDynamismIntoPredVector(1, &dims_are_dynamic));
127     if (unknown_index == -1) {
128       // No unknown index.
129       ctx->SetOutput(0,
130                      xla::DynamicReshape(ctx->Input(0), output_dim_sizes,
131                                          shape.dim_sizes(), dims_are_dynamic));
132       return;
133     }
134     auto common_factors =
135         xla::CommonFactors(input_shape.dim_sizes(), shape.dim_sizes());
136 
137     // Find common_factors that the input belongs to.
138     for (int64 i = 0; i < common_factors.size() - 1; ++i) {
139       auto start = common_factors[i];
140       auto end = common_factors[i + 1];
141       bool input_is_dynamic = false;
142       // product of all input dims in this group. E.g., in
143       // reshape(Tensor([2, 3, 3]), [3, -1, 3]) product of the group
144       // containing -1 will be 6.
145       xla::XlaOp product = xla::One(ctx->builder(), xla::S32);
146       for (int64 dim = start.first; dim < end.first; ++dim) {
147         if (input_xla_shape->is_dynamic_dimension(dim)) {
148           input_is_dynamic = true;
149         }
150         product = xla::Mul(product, xla::GetDimensionSize(ctx->Input(0), dim));
151       }
152       bool unknown_dim_in_group = false;
153       // The real size for the -1 dimension in a reshape. E.g., in
154       // reshape(Tensor([2, 3, 3]), [3, -1, 3]) this will be 2.
155       xla::XlaOp unknown_dim_size = product;
156       for (int64 dim = start.second; dim < end.second; ++dim) {
157         if (dim == unknown_index) {
158           unknown_dim_in_group = true;
159         } else {
160           unknown_dim_size = xla::Div(unknown_dim_size, output_dim_sizes[dim]);
161         }
162       }
163 
164       if (unknown_dim_in_group) {
165         // If input dim is dynamic, output dim at the -1 position must be
166         // dynamic. Similarly, if input dim is static, output dim has to be
167         // static at the -1 dimension.
168         dims_are_dynamic[unknown_index] = input_is_dynamic;
169         output_dim_sizes[unknown_index] = unknown_dim_size;
170 
171         ctx->SetOutput(
172             0, xla::DynamicReshape(ctx->Input(0), output_dim_sizes,
173                                    shape.dim_sizes(), dims_are_dynamic));
174         VLOG(2) << "Reshape from " << ctx->InputXlaShape(0)->ToString()
175                 << " to " << xla::VectorString(shape.dim_sizes())
176                 << ", dynamic_dims=" << xla::VectorString(dims_are_dynamic);
177         return;
178       }
179     }
180   }
181 };
182 
183 REGISTER_XLA_OP(Name("Reshape").CompileTimeConstantInput("shape"), ReshapeOp);
184 
185 }  // namespace
186 }  // namespace tensorflow
187