1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/pooling.h"
17 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
18 #include "tensorflow/compiler/xla/client/lib/constants.h"
19 #include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h"
20
21 namespace xla {
22
23 namespace {
24
25 // Common computation shared between AvgPool and AvgPoolGrad. Divide each
26 // element of an image by the count of elements that contributed to that
27 // element during pooling.
AvgPoolDivideByCountWithGeneralPadding(XlaOp sums,PrimitiveType dtype,absl::Span<const int64> input_shape,absl::Span<const std::pair<int64,int64>> spatial_padding,absl::Span<const int64> ksize,absl::Span<const int64> stride,const TensorFormat & data_format)28 XlaOp AvgPoolDivideByCountWithGeneralPadding(
29 XlaOp sums, PrimitiveType dtype, absl::Span<const int64> input_shape,
30 absl::Span<const std::pair<int64, int64>> spatial_padding,
31 absl::Span<const int64> ksize, absl::Span<const int64> stride,
32 const TensorFormat& data_format) {
33 // The padding shouldn't be included in the counts. We use another
34 // ReduceWindow to find the right counts.
35 const int num_spatial_dims = spatial_padding.size();
36
37 std::vector<int64> input_dim_sizes(num_spatial_dims);
38 std::vector<int64> window_dims(num_spatial_dims);
39 std::vector<int64> window_ksize(num_spatial_dims);
40 std::vector<int64> window_stride(num_spatial_dims);
41 CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims)
42 << "Invalid number of spatial dimentions in data format specification";
43 for (int i = 0; i < num_spatial_dims; ++i) {
44 int dim = data_format.spatial_dimension(i);
45 input_dim_sizes[i] = input_shape[dim];
46 window_dims[i] = dim;
47 window_ksize[i] = ksize[dim];
48 window_stride[i] = stride[dim];
49 }
50
51 XlaBuilder* b = sums.builder();
52 // Build a matrix of all 1s, with the same width/height as the input.
53 auto ones = Broadcast(One(b, dtype), input_dim_sizes);
54 PaddingConfig padding_config;
55 for (int i = 0; i < num_spatial_dims; ++i) {
56 auto dims = padding_config.add_dimensions();
57 dims->set_edge_padding_low(spatial_padding[i].first);
58 dims->set_edge_padding_high(spatial_padding[i].second);
59 }
60 auto zero = Zero(b, dtype);
61 auto padded_ones = Pad(ones, zero, padding_config);
62
63 // Perform a ReduceWindow with the same window size, strides, and padding
64 // to count the number of contributions to each result element.
65 auto counts =
66 ReduceWindow(padded_ones, zero, CreateScalarAddComputation(dtype, b),
67 window_ksize, window_stride, Padding::kValid);
68
69 return Div(sums, counts, window_dims);
70 }
71
72 // Sums all elements in the window specified by 'kernel_size' and 'stride'.
ComputeSums(XlaOp operand,XlaOp init_value,absl::Span<const int64> kernel_size,absl::Span<const int64> stride,const TensorFormat & data_format)73 XlaOp ComputeSums(XlaOp operand, XlaOp init_value,
74 absl::Span<const int64> kernel_size,
75 absl::Span<const int64> stride,
76 const TensorFormat& data_format) {
77 XlaBuilder* b = operand.builder();
78 return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
79 TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand));
80 TF_ASSIGN_OR_RETURN(Shape init_shape, b->GetShape(init_value));
81 PrimitiveType accumulation_type = init_shape.element_type();
82 auto add_computation = CreateScalarAddComputation(accumulation_type, b);
83 return ReduceWindow(operand, init_value, add_computation, kernel_size,
84 stride, Padding::kValid);
85 });
86 }
87
88 // Creates a padding configuration out of spatial padding values.
MakeSpatialPaddingConfig(absl::Span<const std::pair<int64,int64>> spatial_padding,int num_spatial_dims,absl::Span<const int64> stride,const TensorFormat & data_format)89 PaddingConfig MakeSpatialPaddingConfig(
90 absl::Span<const std::pair<int64, int64>> spatial_padding,
91 int num_spatial_dims, absl::Span<const int64> stride,
92 const TensorFormat& data_format) {
93 PaddingConfig padding_config;
94 for (int i = 0; i < 2 + num_spatial_dims; ++i) {
95 padding_config.add_dimensions();
96 }
97 CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims)
98 << "Invalid number of spatial dimentions in data format specification";
99 for (int i = 0; i < num_spatial_dims; ++i) {
100 int dim = data_format.spatial_dimension(i);
101 auto padding_dimension = padding_config.mutable_dimensions(dim);
102 padding_dimension->set_edge_padding_low(spatial_padding[i].first);
103 padding_dimension->set_edge_padding_high(spatial_padding[i].second);
104 }
105 return padding_config;
106 }
107
AvgPoolDivideByCount(XlaOp pooled,absl::Span<const int64> input_size,absl::Span<const int64> window_dimensions,absl::Span<const int64> window_strides,absl::Span<const std::pair<int64,int64>> padding,PrimitiveType dtype,const TensorFormat & data_format,bool counts_include_padding)108 XlaOp AvgPoolDivideByCount(XlaOp pooled, absl::Span<const int64> input_size,
109 absl::Span<const int64> window_dimensions,
110 absl::Span<const int64> window_strides,
111 absl::Span<const std::pair<int64, int64>> padding,
112 PrimitiveType dtype, const TensorFormat& data_format,
113 bool counts_include_padding) {
114 if (counts_include_padding) {
115 // If counts include padding, all windows have the same number of elements
116 // contributing to each average. Divide by the window size everywhere to get
117 // the average.
118 int64 window_size =
119 std::accumulate(window_dimensions.begin(), window_dimensions.end(), 1,
120 [](int64 a, int64 b) { return a * b; });
121 auto divisor = ConstantR0WithType(pooled.builder(), dtype, window_size);
122
123 return pooled / divisor;
124 } else {
125 return AvgPoolDivideByCountWithGeneralPadding(pooled, dtype, input_size,
126 padding, window_dimensions,
127 window_strides, data_format);
128 }
129 }
130
131 } // namespace
132
MaxPool(XlaOp operand,absl::Span<const int64> kernel_size,absl::Span<const int64> stride,Padding padding,const TensorFormat & data_format)133 XlaOp MaxPool(XlaOp operand, absl::Span<const int64> kernel_size,
134 absl::Span<const int64> stride, Padding padding,
135 const TensorFormat& data_format) {
136 XlaBuilder* b = operand.builder();
137 return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
138 TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand));
139 PrimitiveType dtype = operand_shape.element_type();
140 auto max_computation = CreateScalarMaxComputation(dtype, b);
141 auto init_value = MinValue(b, dtype);
142 return ReduceWindow(operand, init_value, max_computation, kernel_size,
143 stride, padding);
144 });
145 }
146
AvgPool(XlaOp operand,absl::Span<const int64> kernel_size,absl::Span<const int64> stride,absl::Span<const std::pair<int64,int64>> padding,const TensorFormat & data_format,const bool counts_include_padding)147 XlaOp AvgPool(XlaOp operand, absl::Span<const int64> kernel_size,
148 absl::Span<const int64> stride,
149 absl::Span<const std::pair<int64, int64>> padding,
150 const TensorFormat& data_format,
151 const bool counts_include_padding) {
152 XlaBuilder* b = operand.builder();
153 return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
154 TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand));
155 PrimitiveType dtype = operand_shape.element_type();
156 auto init_value = Zero(b, dtype);
157 std::vector<int64> input_size(operand_shape.dimensions().begin(),
158 operand_shape.dimensions().end());
159 const int num_dims = kernel_size.size();
160 const int num_spatial_dims = num_dims - 2;
161 auto padding_config = MakeSpatialPaddingConfig(padding, num_spatial_dims,
162 stride, data_format);
163 auto padded_operand = Pad(operand, Zero(b, dtype), padding_config);
164 auto pooled = ComputeSums(padded_operand, init_value, kernel_size, stride,
165 data_format);
166 return AvgPoolDivideByCount(pooled, input_size, kernel_size, stride,
167 padding, dtype, data_format,
168 counts_include_padding);
169 });
170 }
171
MakeSpatialPadding(absl::Span<const int64> input_size,absl::Span<const int64> kernel_size,absl::Span<const int64> stride,Padding padding,const TensorFormat & data_format)172 std::vector<std::pair<int64, int64>> MakeSpatialPadding(
173 absl::Span<const int64> input_size, absl::Span<const int64> kernel_size,
174 absl::Span<const int64> stride, Padding padding,
175 const TensorFormat& data_format) {
176 const int num_spatial_dims = kernel_size.size() - 2;
177 std::vector<int64> input_spatial_dimensions;
178 std::vector<int64> kernel_size_spatial_dimensions;
179 std::vector<int64> stride_spatial_dimensions;
180 CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims)
181 << "Invalid number of spatial dimentions in data format specification";
182 for (int i = 0; i < num_spatial_dims; ++i) {
183 int dim = data_format.spatial_dimension(i);
184 input_spatial_dimensions.push_back(input_size[dim]);
185 kernel_size_spatial_dimensions.push_back(kernel_size[dim]);
186 stride_spatial_dimensions.push_back(stride[dim]);
187 }
188 return MakePadding(input_spatial_dimensions, kernel_size_spatial_dimensions,
189 stride_spatial_dimensions, padding);
190 }
191
// Computes the gradient of AvgPool with respect to its input.
//
// Args:
//   out_backprop: gradient flowing in from the pooled output.
//   gradients_size: full dimension sizes of the input gradient to produce.
//   kernel_size/stride: the forward pool's window sizes and strides
//     (full-rank, including batch/feature dimensions).
//   spatial_padding: the forward pool's explicit (low, high) spatial padding.
//   counts_include_padding: whether the forward averages divided by the full
//     kernel volume (true) or only by non-padding contributions (false).
//
// Returns an XlaOp of shape `gradients_size`, or a builder error if the
// argument ranks are inconsistent.
XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span<const int64> gradients_size,
                  absl::Span<const int64> kernel_size,
                  absl::Span<const int64> stride,
                  absl::Span<const std::pair<int64, int64>> spatial_padding,
                  const TensorFormat& data_format,
                  const bool counts_include_padding) {
  XlaBuilder* b = out_backprop.builder();
  return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    const int num_dims = kernel_size.size();

    // Validate argument ranks before building any ops.
    if (gradients_size.size() != num_dims) {
      return tensorflow::errors::InvalidArgument("gradients must be ", num_dims,
                                                 "-dimensional");
    }

    TF_ASSIGN_OR_RETURN(Shape out_backprop_xla_shape,
                        b->GetShape(out_backprop));
    if (out_backprop_xla_shape.dimensions().size() != num_dims) {
      return tensorflow::errors::InvalidArgument("out_backprop must be ",
                                                 num_dims, "-dimensional");
    }

    // We can think of average-pooling as:
    // * a convolution with a kernel consisting entirely of 1s, where the
    //   input feature and output feature are equal, and 0s everywhere else.
    // * followed by dividing by the counts.
    //
    // This then gives us an algorithm to build the gradient:
    // * divide out_backprop by the counts, followed by
    // * Conv2DBackpropInput specialized for that kernel, which simplifies to
    //   a Pad and a ReduceWindow.
    //
    // For an explanation of backpropagation for convolution, see the comments
    // in third_party/tensorflow/core/kernels/conv_grad_ops.h

    // TF filter shape is [ H, W, ..., inC, outC ]

    // The input gradients are computed by a convolution of the output gradients
    // and the filter, with some appropriate padding. See the comment at the top
    // of conv_grad_ops.h for details.
    PrimitiveType dtype = out_backprop_xla_shape.element_type();
    // Undo the forward pass's division: scale each output gradient by
    // 1/count before scattering it back over its window.
    auto out_backprop_div = AvgPoolDivideByCount(
        out_backprop, gradients_size, kernel_size, stride, spatial_padding,
        dtype, data_format, counts_include_padding);

    // Pad the gradients in the spatial dimensions. We use the same padding
    // as Conv2DBackpropInput.
    PaddingConfig padding_config = MakeNoPaddingConfig(num_dims);
    std::vector<int64> padded_gradients_size(gradients_size.begin(),
                                             gradients_size.end());
    // First, pad the output gradients the same way as the input. The additional
    // padding will be removed as a last step before returning the input
    // gradients.
    const int num_spatial_dims = num_dims - 2;
    for (int i = 0; i < num_spatial_dims; ++i) {
      int dim = data_format.spatial_dimension(i);
      padded_gradients_size[dim] +=
          (spatial_padding[i].first + spatial_padding[i].second);
    }
    // Compute the edge padding for the backprop "convolution" in each
    // spatial dimension.
    for (int i = 0; i < num_spatial_dims; ++i) {
      int dim = data_format.spatial_dimension(i);
      TF_ASSIGN_OR_RETURN(
          SpatialDimensionOutputSizeAndPadding conv_backprop_spatial_dim,
          ConvGradExtractAndVerifyDimension(
              /*input_size=*/padded_gradients_size[dim],
              /*filter_size=*/kernel_size[dim],
              /*output_size=*/out_backprop_xla_shape.dimensions(dim),
              /*dilation=*/1,
              /*stride=*/stride[dim], /*padding=*/Padding::kValid));
      auto* padding = padding_config.mutable_dimensions(dim);
      padding->set_edge_padding_low(conv_backprop_spatial_dim.pad_before);
      padding->set_edge_padding_high(conv_backprop_spatial_dim.pad_after);
      // Interior padding of stride-1 inserts zeros between adjacent output
      // gradients, undoing the forward striding.
      padding->set_interior_padding(stride[dim] - 1);
    }

    auto zero = Zero(b, dtype);
    auto padded_gradients = Pad(out_backprop_div, zero, padding_config);

    // in_backprop = padded_gradients <conv> ones
    // A sum-ReduceWindow with unit strides over the padded gradients is the
    // all-ones-kernel convolution described above.
    std::vector<int64> ones(num_dims, 1LL);
    auto in_backprop =
        ReduceWindow(padded_gradients, Zero(b, dtype),
                     CreateScalarAddComputation(dtype, b), kernel_size,
                     /*window_strides=*/ones, Padding::kValid);
    // The input padding doesn't contribute to the gradient, remove it. A Pad
    // with negative edge padding slices those regions off.
    std::vector<std::pair<int64, int64>> neg_spatial_padding;
    neg_spatial_padding.reserve(spatial_padding.size());
    for (const std::pair<int64, int64>& spatial_padding_dim : spatial_padding) {
      neg_spatial_padding.emplace_back(-spatial_padding_dim.first,
                                       -spatial_padding_dim.second);
    }
    auto remove_padding_config = MakeSpatialPaddingConfig(
        neg_spatial_padding, num_spatial_dims, stride, data_format);
    return Pad(in_backprop, zero, remove_padding_config);
  });
}
288
289 } // namespace xla
290