1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/pooling.h"
17 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
18 #include "tensorflow/compiler/xla/client/lib/constants.h"
19 #include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h"
20
21 namespace xla {
22
23 namespace {
24
25 // Common computation shared between AvgPool and AvgPoolGrad. Divide each
26 // element of an image by the count of elements that contributed to that
27 // element during pooling.
AvgPoolDivideByCountWithGeneralPadding(XlaOp sums,PrimitiveType dtype,absl::Span<const int64> input_shape,absl::Span<const std::pair<int64,int64>> spatial_padding,absl::Span<const int64> ksize,absl::Span<const int64> stride,const TensorFormat & data_format)28 XlaOp AvgPoolDivideByCountWithGeneralPadding(
29 XlaOp sums, PrimitiveType dtype, absl::Span<const int64> input_shape,
30 absl::Span<const std::pair<int64, int64>> spatial_padding,
31 absl::Span<const int64> ksize, absl::Span<const int64> stride,
32 const TensorFormat& data_format) {
33 // The padding shouldn't be included in the counts. We use another
34 // ReduceWindow to find the right counts.
35 const int num_spatial_dims = spatial_padding.size();
36
37 std::vector<int64> input_dim_sizes(num_spatial_dims);
38 std::vector<int64> window_dims(num_spatial_dims);
39 std::vector<int64> window_ksize(num_spatial_dims);
40 std::vector<int64> window_stride(num_spatial_dims);
41 CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims)
42 << "Invalid number of spatial dimensions in data format specification";
43 for (int i = 0; i < num_spatial_dims; ++i) {
44 int dim = data_format.spatial_dimension(i);
45 input_dim_sizes[i] = input_shape[dim];
46 window_dims[i] = dim;
47 window_ksize[i] = ksize[dim];
48 window_stride[i] = stride[dim];
49 }
50
51 XlaBuilder* b = sums.builder();
52 // Build a matrix of all 1s, with the same width/height as the input.
53 auto ones = Broadcast(One(b, dtype), input_dim_sizes);
54 PaddingConfig padding_config;
55 for (int i = 0; i < num_spatial_dims; ++i) {
56 auto dims = padding_config.add_dimensions();
57 dims->set_edge_padding_low(spatial_padding[i].first);
58 dims->set_edge_padding_high(spatial_padding[i].second);
59 }
60 auto zero = Zero(b, dtype);
61 auto padded_ones = Pad(ones, zero, padding_config);
62
63 // Perform a ReduceWindow with the same window size, strides, and padding
64 // to count the number of contributions to each result element.
65 auto counts =
66 ReduceWindow(padded_ones, zero, CreateScalarAddComputation(dtype, b),
67 window_ksize, window_stride, Padding::kValid);
68
69 return Div(sums, counts, window_dims);
70 }
71
72 // Sums all elements in the window specified by 'kernel_size' and 'stride'.
ComputeSums(XlaOp operand,XlaOp init_value,absl::Span<const int64> kernel_size,absl::Span<const int64> stride,const TensorFormat & data_format)73 XlaOp ComputeSums(XlaOp operand, XlaOp init_value,
74 absl::Span<const int64> kernel_size,
75 absl::Span<const int64> stride,
76 const TensorFormat& data_format) {
77 XlaBuilder* b = operand.builder();
78 return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
79 TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand));
80 TF_ASSIGN_OR_RETURN(Shape init_shape, b->GetShape(init_value));
81 PrimitiveType accumulation_type = init_shape.element_type();
82 auto add_computation = CreateScalarAddComputation(accumulation_type, b);
83 return ReduceWindow(operand, init_value, add_computation, kernel_size,
84 stride, Padding::kValid);
85 });
86 }
87
88 // Creates a padding configuration out of spatial padding values.
MakeSpatialPaddingConfig(absl::Span<const std::pair<int64,int64>> spatial_padding,int num_spatial_dims,absl::Span<const int64> stride,const TensorFormat & data_format)89 PaddingConfig MakeSpatialPaddingConfig(
90 absl::Span<const std::pair<int64, int64>> spatial_padding,
91 int num_spatial_dims, absl::Span<const int64> stride,
92 const TensorFormat& data_format) {
93 PaddingConfig padding_config;
94 padding_config.mutable_dimensions()->Reserve(2 + num_spatial_dims);
95 for (int i = 0; i < 2 + num_spatial_dims; ++i) {
96 padding_config.add_dimensions();
97 }
98 CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims)
99 << "Invalid number of spatial dimensions in data format specification";
100 for (int i = 0; i < num_spatial_dims; ++i) {
101 int dim = data_format.spatial_dimension(i);
102 auto padding_dimension = padding_config.mutable_dimensions(dim);
103 padding_dimension->set_edge_padding_low(spatial_padding[i].first);
104 padding_dimension->set_edge_padding_high(spatial_padding[i].second);
105 }
106 return padding_config;
107 }
108
AvgPoolDivideByCount(XlaOp pooled,absl::Span<const int64> input_size,absl::Span<const int64> window_dimensions,absl::Span<const int64> window_strides,absl::Span<const std::pair<int64,int64>> padding,PrimitiveType dtype,const TensorFormat & data_format,bool counts_include_padding)109 XlaOp AvgPoolDivideByCount(XlaOp pooled, absl::Span<const int64> input_size,
110 absl::Span<const int64> window_dimensions,
111 absl::Span<const int64> window_strides,
112 absl::Span<const std::pair<int64, int64>> padding,
113 PrimitiveType dtype, const TensorFormat& data_format,
114 bool counts_include_padding) {
115 if (counts_include_padding) {
116 // If counts include padding, all windows have the same number of elements
117 // contributing to each average. Divide by the window size everywhere to get
118 // the average.
119 int64_t window_size =
120 std::accumulate(window_dimensions.begin(), window_dimensions.end(), 1,
121 [](int64_t a, int64_t b) { return a * b; });
122 auto divisor = ConstantR0WithType(pooled.builder(), dtype, window_size);
123
124 return pooled / divisor;
125 } else {
126 return AvgPoolDivideByCountWithGeneralPadding(pooled, dtype, input_size,
127 padding, window_dimensions,
128 window_strides, data_format);
129 }
130 }
131
132 } // namespace
133
MaxPool(XlaOp operand,absl::Span<const int64> kernel_size,absl::Span<const int64> stride,Padding padding,const TensorFormat & data_format)134 XlaOp MaxPool(XlaOp operand, absl::Span<const int64> kernel_size,
135 absl::Span<const int64> stride, Padding padding,
136 const TensorFormat& data_format) {
137 XlaBuilder* b = operand.builder();
138 return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
139 TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand));
140 PrimitiveType dtype = operand_shape.element_type();
141 auto max_computation = CreateScalarMaxComputation(dtype, b);
142 auto init_value = MinValue(b, dtype);
143 return ReduceWindow(operand, init_value, max_computation, kernel_size,
144 stride, padding);
145 });
146 }
147
AvgPool(XlaOp operand,absl::Span<const int64> kernel_size,absl::Span<const int64> stride,absl::Span<const std::pair<int64,int64>> padding,const TensorFormat & data_format,const bool counts_include_padding)148 XlaOp AvgPool(XlaOp operand, absl::Span<const int64> kernel_size,
149 absl::Span<const int64> stride,
150 absl::Span<const std::pair<int64, int64>> padding,
151 const TensorFormat& data_format,
152 const bool counts_include_padding) {
153 XlaBuilder* b = operand.builder();
154 return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
155 TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand));
156 PrimitiveType dtype = operand_shape.element_type();
157 auto init_value = Zero(b, dtype);
158 std::vector<int64> input_size(operand_shape.dimensions().begin(),
159 operand_shape.dimensions().end());
160 const int num_dims = kernel_size.size();
161 const int num_spatial_dims = num_dims - 2;
162 auto padding_config = MakeSpatialPaddingConfig(padding, num_spatial_dims,
163 stride, data_format);
164 auto padded_operand = Pad(operand, Zero(b, dtype), padding_config);
165 auto pooled = ComputeSums(padded_operand, init_value, kernel_size, stride,
166 data_format);
167 return AvgPoolDivideByCount(pooled, input_size, kernel_size, stride,
168 padding, dtype, data_format,
169 counts_include_padding);
170 });
171 }
172
MakeSpatialPadding(absl::Span<const int64> input_size,absl::Span<const int64> kernel_size,absl::Span<const int64> stride,Padding padding,const TensorFormat & data_format)173 std::vector<std::pair<int64, int64>> MakeSpatialPadding(
174 absl::Span<const int64> input_size, absl::Span<const int64> kernel_size,
175 absl::Span<const int64> stride, Padding padding,
176 const TensorFormat& data_format) {
177 const int num_spatial_dims = kernel_size.size() - 2;
178 std::vector<int64> input_spatial_dimensions;
179 std::vector<int64> kernel_size_spatial_dimensions;
180 std::vector<int64> stride_spatial_dimensions;
181 CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims)
182 << "Invalid number of spatial dimensions in data format specification";
183 for (int i = 0; i < num_spatial_dims; ++i) {
184 int dim = data_format.spatial_dimension(i);
185 input_spatial_dimensions.push_back(input_size[dim]);
186 kernel_size_spatial_dimensions.push_back(kernel_size[dim]);
187 stride_spatial_dimensions.push_back(stride[dim]);
188 }
189 return MakePadding(input_spatial_dimensions, kernel_size_spatial_dimensions,
190 stride_spatial_dimensions, padding);
191 }
192
// Computes the gradient of AvgPool with respect to its input: each output
// gradient is divided by its window's contribution count and scattered back
// over the window it was averaged from. `gradients_size` is the shape of the
// input gradients to produce; `spatial_padding` is the padding that was
// applied in the forward pass, one (low, high) pair per spatial dimension.
XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span<const int64> gradients_size,
                  absl::Span<const int64> kernel_size,
                  absl::Span<const int64> stride,
                  absl::Span<const std::pair<int64, int64>> spatial_padding,
                  const TensorFormat& data_format,
                  const bool counts_include_padding) {
  XlaBuilder* b = out_backprop.builder();
  return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Validate that the requested gradient rank matches the kernel rank.
    const int num_dims = kernel_size.size();
    const int num_gradients = gradients_size.size();
    if (num_gradients != num_dims) {
      return tensorflow::errors::InvalidArgument("gradients must be ", num_dims,
                                                 "-dimensional");
    }

    // Validate that the incoming output gradients have the same rank too.
    TF_ASSIGN_OR_RETURN(Shape out_backprop_xla_shape,
                        b->GetShape(out_backprop));
    const int backprop_xla_num_dims =
        out_backprop_xla_shape.dimensions().size();
    if (backprop_xla_num_dims != num_dims) {
      return tensorflow::errors::InvalidArgument("out_backprop must be ",
                                                 num_dims, "-dimensional");
    }

    // We can think of average-pooling as:
    // * a convolution with a kernel consisting entirely of 1s, where the
    // input feature and output feature are equal, and 0s everywhere else.
    // * followed by dividing by the counts.
    //
    // This then gives us an algorithm to build the gradient:
    // * divide out_backprop by the counts, followed by
    // * Conv2DBackpropInput specialized for that kernel, which simplifies to
    // a Pad and a ReduceWindow.
    //
    // For an explanation of backpropagation for convolution, see the comments
    // in third_party/tensorflow/core/kernels/conv_grad_ops.h

    // TF filter shape is [ H, W, ..., inC, outC ]

    // The input gradients are computed by a convolution of the output gradients
    // and the filter, with some appropriate padding. See the comment at the top
    // of conv_grad_ops.h for details.
    PrimitiveType dtype = out_backprop_xla_shape.element_type();
    // Step 1: undo the forward pass's division so each output gradient carries
    // the per-element weight 1/count.
    auto out_backprop_div = AvgPoolDivideByCount(
        out_backprop, gradients_size, kernel_size, stride, spatial_padding,
        dtype, data_format, counts_include_padding);

    // Pad the gradients in the spatial dimensions. We use the same padding
    // as Conv2DBackpropInput.
    PaddingConfig padding_config = MakeNoPaddingConfig(num_dims);
    std::vector<int64> padded_gradients_size(gradients_size.begin(),
                                             gradients_size.end());
    // First, pad the output gradients the same way as the input. The additional
    // padding will be removed as a last step before returning the input
    // gradients.
    const int num_spatial_dims = num_dims - 2;
    for (int i = 0; i < num_spatial_dims; ++i) {
      int dim = data_format.spatial_dimension(i);
      padded_gradients_size[dim] +=
          (spatial_padding[i].first + spatial_padding[i].second);
    }
    // Step 2: compute the conv-backprop-input padding per spatial dimension.
    // Interior padding of (stride - 1) expands the output gradients so each
    // window position in the forward pass lines up with one element here.
    for (int i = 0; i < num_spatial_dims; ++i) {
      int dim = data_format.spatial_dimension(i);
      TF_ASSIGN_OR_RETURN(
          SpatialDimensionOutputSizeAndPadding conv_backprop_spatial_dim,
          ConvGradExtractAndVerifyDimension(
              /*input_size=*/padded_gradients_size[dim],
              /*filter_size=*/kernel_size[dim],
              /*output_size=*/out_backprop_xla_shape.dimensions(dim),
              /*dilation=*/1,
              /*stride=*/stride[dim], /*padding=*/Padding::kValid));
      auto* padding = padding_config.mutable_dimensions(dim);
      padding->set_edge_padding_low(conv_backprop_spatial_dim.pad_before);
      padding->set_edge_padding_high(conv_backprop_spatial_dim.pad_after);
      padding->set_interior_padding(stride[dim] - 1);
    }

    auto zero = Zero(b, dtype);
    auto padded_gradients = Pad(out_backprop_div, zero, padding_config);

    // Step 3: in_backprop = padded_gradients <conv> ones — realized as a
    // stride-1 sum over each kernel-sized window.
    std::vector<int64> ones(num_dims, 1LL);
    auto in_backprop =
        ReduceWindow(padded_gradients, Zero(b, dtype),
                     CreateScalarAddComputation(dtype, b), kernel_size,
                     /*window_strides=*/ones, Padding::kValid);
    // Step 4: the input padding doesn't contribute to the gradient, remove it
    // by padding with negative amounts (a slice expressed as Pad).
    std::vector<std::pair<int64, int64>> neg_spatial_padding;
    neg_spatial_padding.reserve(spatial_padding.size());
    for (const std::pair<int64, int64>& spatial_padding_dim : spatial_padding) {
      neg_spatial_padding.emplace_back(-spatial_padding_dim.first,
                                       -spatial_padding_dim.second);
    }
    auto remove_padding_config = MakeSpatialPaddingConfig(
        neg_spatial_padding, num_spatial_dims, stride, data_format);
    return Pad(in_backprop, zero, remove_padding_config);
  });
}
291
292 } // namespace xla
293