/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

namespace tensorflow {
namespace {

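// Shared lowering for the BatchToSpace and BatchToSpaceND kernels: folds
// factors of the batch dimension of `input` back into the leading spatial
// dimensions via a reshape -> transpose -> reshape sequence, then crops each
// of those dimensions according to `crops` (see steps 1-4 below).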
void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input,
                  DataType input_dtype, const TensorShape& input_tensor_shape,
                  absl::Span<const int64> block_shape,
                  const xla::Literal& crops) {
  const int input_rank = input_tensor_shape.dims();
  const absl::InlinedVector<int64, 4> input_shape =
      input_tensor_shape.dim_sizes();
  const int block_rank = block_shape.size();

  OP_REQUIRES(
      ctx, input_rank >= 1 + block_rank,
      errors::InvalidArgument("input rank should be >= ", 1 + block_rank,
                              " instead of ", input_rank));
  absl::Span<const int64> remainder_shape(input_shape);
  remainder_shape.remove_prefix(1 + block_rank);

  OP_REQUIRES(
      ctx,
      crops.shape().rank() == 2 &&
          block_rank == xla::ShapeUtil::GetDimension(crops.shape(), 0) &&
          2 == xla::ShapeUtil::GetDimension(crops.shape(), 1),
      errors::InvalidArgument("crops should have shape [", block_rank,
                              ", 2] instead of ",
                              xla::ShapeUtil::HumanString(crops.shape())));

  const int64 batch_size = input_shape[0];

  // Compute the product of the block_shape values.
  int64 block_num_elems = 1;
  for (int i = 0; i < block_rank; ++i) {
    block_num_elems *= block_shape[i];
  }
  OP_REQUIRES(ctx, block_num_elems > 0,
              errors::InvalidArgument(
                  "The product of the block dimensions must be positive"));

  // 1. Reshape `input` to `reshaped` of shape:
  //      [block_shape[0], ..., block_shape[M-1],
  //       batch / prod(block_shape),
  //       input_shape[1], ..., input_shape[N-1]]
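  //
  //    For example (assuming a 4-D input of shape [4, 1, 1, 3] with
  //    block_shape = [2, 2]), `reshaped` has shape [2, 2, 1, 1, 1, 3].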

  OP_REQUIRES(
      ctx, batch_size % block_num_elems == 0,
      errors::InvalidArgument("Input batch dimension (", batch_size,
                              ") is not divisible by product of block sizes (",
                              block_num_elems, ")"));
  std::vector<int64> reshaped_shape(input_rank + block_rank);
  std::copy(block_shape.begin(), block_shape.end(), reshaped_shape.begin());
  reshaped_shape[block_rank] = batch_size / block_num_elems;
  std::copy(input_shape.begin() + 1, input_shape.end(),
            reshaped_shape.begin() + block_rank + 1);
  xla::XlaOp reshaped = xla::Reshape(input, reshaped_shape);

  // 2. Permute dimensions of `reshaped` to produce `permuted` of shape
  //      [batch / prod(block_shape),
  //
  //       input_shape[1], block_shape[0],
  //       ...,
  //       input_shape[M], block_shape[M-1],
  //
  //       input_shape[M+1], ..., input_shape[N-1]]
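  //
  //    In the running example, the permutation is {2, 3, 0, 4, 1, 5} and
  //    `permuted` has shape [1, 1, 2, 1, 2, 3].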
  std::vector<int64> permutation(reshaped_shape.size());
  permutation[0] = block_rank;
  for (int i = 0; i < block_rank; ++i) {
    permutation[1 + 2 * i] = block_rank + 1 + i;
    permutation[1 + 2 * i + 1] = i;
  }
  std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
            1 + block_rank * 2);
  xla::XlaOp permuted = xla::Transpose(reshaped, permutation);

  // 3. Reshape `permuted` to produce `reshaped_permuted` of shape
  //      [batch / prod(block_shape),
  //
  //       input_shape[1] * block_shape[0],
  //       ...,
  //       input_shape[M] * block_shape[M-1],
  //
  //       input_shape[M+1],
  //       ...,
  //       input_shape[N-1]]
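  //
  //    In the running example, `reshaped_permuted` has shape [1, 2, 2, 3].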
  std::vector<int64> reshaped_permuted_shape(input_rank);
  reshaped_permuted_shape[0] = batch_size / block_num_elems;
  for (int i = 0; i < block_rank; ++i) {
    reshaped_permuted_shape[1 + i] = block_shape[i] * input_shape[1 + i];
  }
  std::copy(remainder_shape.begin(), remainder_shape.end(),
            reshaped_permuted_shape.begin() + 1 + block_rank);

  xla::XlaOp reshaped_permuted =
      xla::Reshape(permuted, reshaped_permuted_shape);

  // 4. Crop the start and end of dimensions `[1, ..., M]` of
  //    `reshaped_permuted` according to `crops` to produce the output of shape:
  //      [batch / prod(block_shape),
  //
  //       input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
  //       ...,
  //       input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
  //
  //       input_shape[M+1], ..., input_shape[N-1]]
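  //
  //    In the running example, crops = [[0, 0], [0, 0]] leaves the output at
  //    [1, 2, 2, 3]; crops = [[0, 1], [0, 1]] would instead yield [1, 1, 1, 3].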
  std::vector<int64> start_indices(input_rank, 0);
  std::vector<int64> end_indices = reshaped_permuted_shape;
  std::vector<int64> strides(input_rank, 1);
  for (int i = 0; i < block_rank; ++i) {
    int64 crop_start = crops.Get<int64>({i, 0});
    int64 crop_end = crops.Get<int64>({i, 1});
    OP_REQUIRES(ctx, crop_start >= 0 && crop_end >= 0,
                errors::InvalidArgument("Crops must be non-negative"));
    start_indices[1 + i] = crop_start;
    end_indices[1 + i] -= crop_end;
    OP_REQUIRES(
        ctx, start_indices[1 + i] <= end_indices[1 + i],
        errors::InvalidArgument(
            "Cropped size must be non-negative: start: ", crop_start,
            " end: ", crop_end, " size ", reshaped_permuted_shape[1 + i]));
  }
  xla::XlaOp output =
      xla::Slice(reshaped_permuted, start_indices, end_indices, strides);
  ctx->SetOutput(0, output);
}

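// N-D variant: `block_shape` (input 1) and `crops` (input 2) must be
// compile-time constants; they are read here and forwarded to the shared
// BatchToSpace lowering above.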
class BatchToSpaceNDOp : public XlaOpKernel {
 public:
  explicit BatchToSpaceNDOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    std::vector<int64> block_shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &block_shape));

    xla::Literal crops;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(2, &crops));

    BatchToSpace(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
                 block_shape, crops);
  }
};
REGISTER_XLA_OP(Name("BatchToSpaceND")
                    .CompileTimeConstantInput("block_shape")
                    .CompileTimeConstantInput("crops"),
                BatchToSpaceNDOp);

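// Non-ND variant: a single `block_size` attribute applies to both spatial
// dimensions, so this lowers exactly like BatchToSpaceND with
// block_shape = [block_size, block_size].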
class BatchToSpaceOp : public XlaOpKernel {
 public:
  explicit BatchToSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_));
    OP_REQUIRES(
        ctx, block_size_ > 1,
        errors::InvalidArgument("Block size should be > 1: ", block_size_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::Literal crops;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(1, &crops));

    BatchToSpace(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
                 {block_size_, block_size_}, crops);
  }

 private:
  int block_size_;
};
REGISTER_XLA_OP(Name("BatchToSpace").CompileTimeConstantInput("crops"),
                BatchToSpaceOp);

}  // namespace
}  // namespace tensorflow