/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/lib/data_format.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {
namespace {

class SpaceToDepthOp : public XlaOpKernel {
 public:
  explicit SpaceToDepthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    string data_format_str;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
                errors::InvalidArgument("Invalid data format"));

    OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_));
    OP_REQUIRES(
        ctx, block_size_ > 1,
        errors::InvalidArgument("Block size should be > 1: ", block_size_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp input = ctx->Input(0);

    TensorFormat data_format = data_format_;
    // If the data is in a vectorized format, reformat it into a non-vectorized
    // version first. We'll undo the transformation later.
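    // Illustrative note (not from the original comment): FORMAT_NCHW_VECT_C
    // stores the channel dimension vectorized in groups of four, e.g. a
    // tensor of shape [N, C / 4, H, W, 4]. NCHW_VECT_CToNCHW folds that
    // trailing vector dimension back into the channels, yielding
    // [N, C, H, W], so the generic NCHW path below can be reused.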
    if (data_format == FORMAT_NCHW_VECT_C) {
      data_format = FORMAT_NCHW;
      auto input_reshaped = NCHW_VECT_CToNCHW(input);
      OP_REQUIRES_OK(ctx, input_reshaped.status());
      input = input_reshaped.ValueOrDie();
    }

    OP_REQUIRES(ctx, data_format == FORMAT_NCHW || data_format == FORMAT_NHWC,
                errors::InvalidArgument("Unsupported data format ",
                                        ToString(data_format_)));

    xla::XlaBuilder* builder = input.builder();
    auto input_xla_shape = builder->GetShape(input);
    OP_REQUIRES_OK(ctx, input_xla_shape.status());
    const std::vector<int64>& input_shape =
        input_xla_shape.ValueOrDie().dimensions();
    int input_rank = input_shape.size();

    static const int kRequiredDims = 4;
    OP_REQUIRES(ctx, kRequiredDims == input_rank,
                errors::InvalidArgument("Input rank should be ", kRequiredDims,
                                        "; got ", input_rank));

    int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format);
    int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format);

    std::vector<int64> reshaped_shape;
    std::vector<int64> transpose_order;
    std::vector<int64> output_shape;
    reshaped_shape.reserve(input_rank);
    transpose_order.reserve(input_rank);
    output_shape.reserve(input_rank);
    if (data_format == FORMAT_NHWC) {
      int64 block_elems = 1;
      for (int i = 0; i < num_spatial_dims; ++i) {
        OP_REQUIRES(ctx, input_shape[1 + i] % block_size_ == 0,
                    errors::InvalidArgument(
                        "input shape[", 1 + i, "]=", input_shape[1 + i],
                        " is not divisible by block_size=", block_size_));
        block_elems *= block_size_;
      }

      reshaped_shape.push_back(input_shape[0]);
      for (int i = 0; i < num_spatial_dims; ++i) {
        reshaped_shape.push_back(input_shape[1 + i] / block_size_);
        reshaped_shape.push_back(block_size_);
      }
      reshaped_shape.push_back(input_shape[feature_dim]);

      transpose_order.push_back(0);
      for (int i = 0; i < num_spatial_dims; ++i) {
        transpose_order.push_back(i * 2 + 1);
      }
      for (int i = 0; i < num_spatial_dims; ++i) {
        transpose_order.push_back(i * 2 + 2);
      }
      transpose_order.push_back(feature_dim + num_spatial_dims);

      output_shape.push_back(input_shape[0]);
      for (int i = 0; i < num_spatial_dims; ++i) {
        output_shape.push_back(input_shape[1 + i] / block_size_);
      }
      output_shape.push_back(input_shape[feature_dim] * block_elems);
    } else {
      // FORMAT_NCHW
      int64 block_elems = 1;
      for (int i = 0; i < num_spatial_dims; ++i) {
        OP_REQUIRES(ctx, input_shape[2 + i] % block_size_ == 0,
                    errors::InvalidArgument(
                        "input shape[", 2 + i, "]=", input_shape[2 + i],
                        " is not divisible by block_size=", block_size_));
        block_elems *= block_size_;
      }

      reshaped_shape.push_back(input_shape[0]);
      reshaped_shape.push_back(input_shape[feature_dim]);
      for (int i = 0; i < num_spatial_dims; ++i) {
        reshaped_shape.push_back(input_shape[2 + i] / block_size_);
        reshaped_shape.push_back(block_size_);
      }

      transpose_order.push_back(0);
      for (int i = 0; i < num_spatial_dims; ++i) {
        transpose_order.push_back(i * 2 + 3);
      }
      transpose_order.push_back(feature_dim);
      for (int i = 0; i < num_spatial_dims; ++i) {
        transpose_order.push_back(i * 2 + 2);
      }

      output_shape.push_back(input_shape[0]);
      output_shape.push_back(input_shape[feature_dim] * block_elems);
      for (int i = 0; i < num_spatial_dims; ++i) {
        output_shape.push_back(input_shape[2 + i] / block_size_);
      }
    }
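    // Worked example (illustrative, not from the original source): for an
    // NHWC input of shape [1, 4, 4, 1] with block_size_ == 2, the vectors
    // built above are
    //   reshaped_shape  = [1, 2, 2, 2, 2, 1]
    //   transpose_order = [0, 1, 3, 2, 4, 5]
    //   output_shape    = [1, 2, 2, 4]
    // so the three XLA ops below move each 2x2 spatial block into the depth
    // dimension.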
    // Note: comments are given in NHWC format; NCHW is similar with a
    // different dimension order.
    // 1. Reshape `input` to `reshaped` of shape:
    //
    //      [batch,
    //       input_shape[1] / block_size_, block_size_,
    //       input_shape[2] / block_size_, block_size_,
    //       depth]
    xla::XlaOp reshaped = xla::Reshape(input, reshaped_shape);

    // 2. Permute dimensions of `reshaped` to produce
    //    `permuted_reshaped` of shape:
    //
    //      [batch,
    //       input_shape[1] / block_size_,
    //       input_shape[2] / block_size_,
    //       block_size_, block_size_,
    //       depth]
    xla::XlaOp permuted_reshaped = xla::Transpose(reshaped, transpose_order);

    // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the
    //    depth dimension, producing an output tensor of shape:
    //
    //      [batch,
    //       input_shape[1] / block_size_,
    //       input_shape[2] / block_size_,
    //       block_size_ * block_size_ * depth]
    //
    xla::XlaOp output = xla::Reshape(permuted_reshaped, output_shape);

    // If this used to be a vectorized format, turn it back now.
    if (data_format != data_format_) {
      DCHECK(data_format == FORMAT_NCHW && data_format_ == FORMAT_NCHW_VECT_C);
      auto output_reshaped = NCHWToNCHW_VECT_C(output);
      OP_REQUIRES_OK(ctx, output_reshaped.status());
      output = output_reshaped.ValueOrDie();
    }

    ctx->SetOutput(0, output);
  }

 private:
  TensorFormat data_format_;
  int block_size_;
};

REGISTER_XLA_OP(Name("SpaceToDepth"), SpaceToDepthOp);

}  // namespace
}  // namespace tensorflow