/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/lib/data_format.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {
namespace {

class SpaceToDepthOp : public XlaOpKernel {
 public:
  explicit SpaceToDepthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    string data_format_str;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
                errors::InvalidArgument("Invalid data format"));

    OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_));
    OP_REQUIRES(
        ctx, block_size_ > 1,
        errors::InvalidArgument("Block size should be > 1: ", block_size_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp input = ctx->Input(0);

    TensorFormat data_format = data_format_;
    // If the data is in a vectorized format, reformat it into a non-vectorized
    // version first. We'll undo the transformation later.
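    // (FORMAT_NCHW_VECT_C packs channels in groups of four, i.e. a
    // [N, C/4, H, W, 4] layout; NCHW_VECT_CToNCHW and NCHWToNCHW_VECT_C
    // convert to and from plain NCHW.)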
    if (data_format == FORMAT_NCHW_VECT_C) {
      data_format = FORMAT_NCHW;
      auto input_reshaped = NCHW_VECT_CToNCHW(input);
      OP_REQUIRES_OK(ctx, input_reshaped.status());
      input = input_reshaped.ValueOrDie();
    }

    OP_REQUIRES(ctx, data_format == FORMAT_NCHW || data_format == FORMAT_NHWC,
                errors::InvalidArgument("Unsupported data format ",
                                        ToString(data_format_)));

    xla::XlaBuilder* builder = input.builder();
    auto input_xla_shape = builder->GetShape(input);
    OP_REQUIRES_OK(ctx, input_xla_shape.status());
    const std::vector<int64>& input_shape =
        input_xla_shape.ValueOrDie().dimensions();
    int input_rank = input_shape.size();

    static const int kRequiredDims = 4;
    OP_REQUIRES(ctx, kRequiredDims == input_rank,
                errors::InvalidArgument("Input rank should be ", kRequiredDims,
                                        "; got ", input_rank));

    int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format);
    int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format);
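    // For a rank-4 tensor these are feature_dim = 3 (NHWC) or 1 (NCHW), and
    // num_spatial_dims = 2 in both cases.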

    std::vector<int64> reshaped_shape;
    std::vector<int64> transpose_order;
    std::vector<int64> output_shape;
    reshaped_shape.reserve(input_rank);
    transpose_order.reserve(input_rank);
    output_shape.reserve(input_rank);
    if (data_format == FORMAT_NHWC) {
      int64 block_elems = 1;
      for (int i = 0; i < num_spatial_dims; ++i) {
        OP_REQUIRES(ctx, input_shape[1 + i] % block_size_ == 0,
                    errors::InvalidArgument(
                        "input shape[", 1 + i, "]=", input_shape[1 + i],
                        " is not divisible by block_size=", block_size_));
        block_elems *= block_size_;
      }

      reshaped_shape.push_back(input_shape[0]);
      for (int i = 0; i < num_spatial_dims; ++i) {
        reshaped_shape.push_back(input_shape[1 + i] / block_size_);
        reshaped_shape.push_back(block_size_);
      }
      reshaped_shape.push_back(input_shape[feature_dim]);

      transpose_order.push_back(0);
      for (int i = 0; i < num_spatial_dims; ++i) {
        transpose_order.push_back(i * 2 + 1);
      }
      for (int i = 0; i < num_spatial_dims; ++i) {
        transpose_order.push_back(i * 2 + 2);
      }
      transpose_order.push_back(feature_dim + num_spatial_dims);

      output_shape.push_back(input_shape[0]);
      for (int i = 0; i < num_spatial_dims; ++i) {
        output_shape.push_back(input_shape[1 + i] / block_size_);
      }
      output_shape.push_back(input_shape[feature_dim] * block_elems);
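
      // For example, with block_size_ = 2 and an NHWC input of shape
      // [N, 4, 6, C]:
      //   reshaped_shape  = [N, 2, 2, 3, 2, C]
      //   transpose_order = [0, 1, 3, 2, 4, 5]
      //   output_shape    = [N, 2, 3, 4 * C]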
    } else {
      // FORMAT_NCHW
      int64 block_elems = 1;
      for (int i = 0; i < num_spatial_dims; ++i) {
        OP_REQUIRES(ctx, input_shape[2 + i] % block_size_ == 0,
                    errors::InvalidArgument(
                        "input shape[", 2 + i, "]=", input_shape[2 + i],
                        " is not divisible by block_size=", block_size_));
        block_elems *= block_size_;
      }

      reshaped_shape.push_back(input_shape[0]);
      reshaped_shape.push_back(input_shape[feature_dim]);
      for (int i = 0; i < num_spatial_dims; ++i) {
        reshaped_shape.push_back(input_shape[2 + i] / block_size_);
        reshaped_shape.push_back(block_size_);
      }

      transpose_order.push_back(0);
      for (int i = 0; i < num_spatial_dims; ++i) {
        transpose_order.push_back(i * 2 + 3);
      }
      transpose_order.push_back(feature_dim);
      for (int i = 0; i < num_spatial_dims; ++i) {
        transpose_order.push_back(i * 2 + 2);
      }

      output_shape.push_back(input_shape[0]);
      output_shape.push_back(input_shape[feature_dim] * block_elems);
      for (int i = 0; i < num_spatial_dims; ++i) {
        output_shape.push_back(input_shape[2 + i] / block_size_);
      }
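
      // For example, with block_size_ = 2 and an NCHW input of shape
      // [N, C, 4, 6]:
      //   reshaped_shape  = [N, C, 2, 2, 3, 2]
      //   transpose_order = [0, 3, 5, 1, 2, 4]
      //   output_shape    = [N, 4 * C, 2, 3]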
    }

    // Note: comments are given in NHWC format; NCHW is similar with a different
    // dimension order.
    // 1. Reshape `input` to `reshaped` of shape:
    //
    //      [batch,
    //       input_shape[1] / block_size_, block_size_,
    //       input_shape[2] / block_size_, block_size_,
    //       depth]
    xla::XlaOp reshaped = xla::Reshape(input, reshaped_shape);

    // 2. Permute dimensions of `reshaped` to produce
    //    `permuted_reshaped` of shape:
    //
    //      [batch,
    //       input_shape[1] / block_size_,
    //       input_shape[2] / block_size_,
    //       block_size_, block_size_,
    //       depth]
    xla::XlaOp permuted_reshaped = xla::Transpose(reshaped, transpose_order);

    // 3. Reshape `permuted_reshaped` to flatten the block dimensions into the
    //    depth dimension, producing an output tensor of shape:
    //
    //      [batch,
    //       input_shape[1] / block_size_,
    //       input_shape[2] / block_size_,
    //       block_size_ * block_size_ * depth]
    //
    xla::XlaOp output = xla::Reshape(permuted_reshaped, output_shape);

    // If this used to be a vectorized format, turn it back now.
    if (data_format != data_format_) {
      DCHECK(data_format == FORMAT_NCHW && data_format_ == FORMAT_NCHW_VECT_C);
      auto output_reshaped = NCHWToNCHW_VECT_C(output);
      OP_REQUIRES_OK(ctx, output_reshaped.status());
      output = output_reshaped.ValueOrDie();
    }

    ctx->SetOutput(0, output);
  }

 private:
  TensorFormat data_format_;
  int block_size_;
};
REGISTER_XLA_OP(Name("SpaceToDepth"), SpaceToDepthOp);

}  // namespace
}  // namespace tensorflow