1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // XLA-specific reshape Op. 17 18 #include "tensorflow/compiler/tf2xla/type_util.h" 19 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 20 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 22 #include "tensorflow/compiler/xla/client/lib/constants.h" 23 #include "tensorflow/compiler/xla/client/xla_builder.h" 24 #include "tensorflow/compiler/xla/literal.h" 25 #include "tensorflow/compiler/xla/util.h" 26 #include "tensorflow/core/framework/op_kernel.h" 27 #include "tensorflow/core/framework/register_types.h" 28 #include "tensorflow/core/framework/tensor.h" 29 #include "tensorflow/core/framework/tensor_shape.h" 30 31 namespace tensorflow { 32 namespace { 33 34 class ReshapeOp : public XlaOpKernel { 35 public: ReshapeOp(OpKernelConstruction * ctx)36 explicit ReshapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 37 Compile(XlaOpKernelContext * ctx)38 void Compile(XlaOpKernelContext* ctx) override { 39 const TensorShape input_shape = ctx->InputShape(0); 40 const TensorShape sizes_shape = ctx->InputShape(1); 41 // Preliminary validation of sizes. 42 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(sizes_shape), 43 errors::InvalidArgument("sizes input must be 1-D, not shape ", 44 sizes_shape.DebugString())); 45 const int64 num_dims = sizes_shape.num_elements(); 46 47 std::vector<int64> shape_input; 48 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &shape_input)); 49 50 // Compute the output shape. Determine product of specified 51 // dimensions, and find the index of the unspecified one if there 52 // is one. 53 TensorShape shape; 54 int64 product = 1; 55 int unknown_index = -1; 56 bool shape_has_zero_dim = false; 57 for (int d = 0; d < num_dims; ++d) { 58 const int32 size = shape_input[d]; 59 if (size == -1) { 60 OP_REQUIRES( 61 ctx, unknown_index == -1, 62 errors::InvalidArgument("only one input size may be -1, not both ", 63 unknown_index, " and ", d)); 64 unknown_index = d; 65 shape.AddDim(1); 66 } else if (size == 0) { 67 // We don't include zero-sized dimension in product, so that we can 68 // still calculate number of elements for non-zero-sized dimensions and 69 // therefore infer their shapes. 70 shape.AddDim(size); 71 shape_has_zero_dim = true; 72 } else { 73 OP_REQUIRES(ctx, size >= 0, 74 errors::InvalidArgument( 75 "size ", d, " must be non-negative, not ", size)); 76 shape.AddDim(size); 77 product *= size; 78 } 79 } 80 if (unknown_index != -1) { 81 int64 input_num_elements = 1; 82 bool input_has_zero_dim = false; 83 for (int dim = 0; dim < input_shape.dims(); dim++) { 84 // For zero dimension, we don't count it into `input_num_elements` 85 // unless `sizes` has no zero dimension, so we are still able to 86 // infer shapes for other dimensions. 87 if (input_shape.dim_size(dim) > 0 || !shape_has_zero_dim) { 88 input_num_elements *= input_shape.dim_size(dim); 89 } else { 90 input_has_zero_dim = true; 91 } 92 } 93 94 const int64 missing = input_num_elements / product; 95 if (!input_has_zero_dim) { 96 OP_REQUIRES( 97 ctx, product * missing == input_num_elements, 98 errors::InvalidArgument( 99 "Input to reshape is a tensor with ", input_num_elements, 100 " values, but the requested shape requires a multiple of ", 101 product)); 102 } 103 shape.set_dim(unknown_index, missing); 104 } 105 OP_REQUIRES(ctx, shape.num_elements() == input_shape.num_elements(), 106 errors::InvalidArgument("Input to reshape is a tensor with ", 107 input_shape.num_elements(), 108 " values, but the requested shape has ", 109 shape.num_elements())); 110 111 VLOG(2) << "Reshape from " << input_shape.DebugString() << " to " 112 << shape.DebugString() << ", unknown_index=" << unknown_index; 113 auto input_xla_shape = ctx->InputXlaShape(0); 114 if (input_xla_shape->is_static()) { 115 ctx->SetOutput(0, xla::Reshape(ctx->Input(0), shape.dim_sizes())); 116 return; 117 } 118 // Handing dynamic reshapes if input contains a dynamic dimension. 119 std::vector<xla::XlaOp> output_dim_sizes; 120 std::vector<bool> dims_are_dynamic; 121 for (int64 i = 0; i < shape.dims(); ++i) { 122 output_dim_sizes.push_back( 123 xla::Reshape(xla::Slice(ctx->Input(1), {i}, {i + 1}, {1}), {})); 124 } 125 OP_REQUIRES_OK( 126 ctx, ctx->ResolveInputDynamismIntoPredVector(1, &dims_are_dynamic)); 127 if (unknown_index == -1) { 128 // No unknown index. 129 ctx->SetOutput(0, 130 xla::DynamicReshape(ctx->Input(0), output_dim_sizes, 131 shape.dim_sizes(), dims_are_dynamic)); 132 return; 133 } 134 auto common_factors = 135 xla::CommonFactors(input_shape.dim_sizes(), shape.dim_sizes()); 136 137 // Find common_factors that the input belongs to. 138 for (int64 i = 0; i < common_factors.size() - 1; ++i) { 139 auto start = common_factors[i]; 140 auto end = common_factors[i + 1]; 141 bool input_is_dynamic = false; 142 // product of all input dims in this group. E.g., in 143 // reshape(Tensor([2, 3, 3]), [3, -1, 3]) product of the group 144 // containing -1 will be 6. 145 xla::XlaOp product = xla::One(ctx->builder(), xla::S32); 146 for (int64 dim = start.first; dim < end.first; ++dim) { 147 if (input_xla_shape->is_dynamic_dimension(dim)) { 148 input_is_dynamic = true; 149 } 150 product = xla::Mul(product, xla::GetDimensionSize(ctx->Input(0), dim)); 151 } 152 bool unknown_dim_in_group = false; 153 // The real size for the -1 dimension in a reshape. E.g., in 154 // reshape(Tensor([2, 3, 3]), [3, -1, 3]) this will be 2. 155 xla::XlaOp unknown_dim_size = product; 156 for (int64 dim = start.second; dim < end.second; ++dim) { 157 if (dim == unknown_index) { 158 unknown_dim_in_group = true; 159 } else { 160 unknown_dim_size = xla::Div(unknown_dim_size, output_dim_sizes[dim]); 161 } 162 } 163 164 if (unknown_dim_in_group) { 165 // If input dim is dynamic, output dim at the -1 position must be 166 // dynamic. Similarly, if input dim is static, output dim has to be 167 // static at the -1 dimension. 168 dims_are_dynamic[unknown_index] = input_is_dynamic; 169 output_dim_sizes[unknown_index] = unknown_dim_size; 170 171 ctx->SetOutput( 172 0, xla::DynamicReshape(ctx->Input(0), output_dim_sizes, 173 shape.dim_sizes(), dims_are_dynamic)); 174 VLOG(2) << "Reshape from " << ctx->InputXlaShape(0)->ToString() 175 << " to " << xla::VectorString(shape.dim_sizes()) 176 << ", dynamic_dims=" << xla::VectorString(dims_are_dynamic); 177 return; 178 } 179 } 180 } 181 }; 182 183 REGISTER_XLA_OP(Name("Reshape").CompileTimeConstantInput("shape"), ReshapeOp); 184 185 } // namespace 186 } // namespace tensorflow 187