/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/lib/pooling.h"

#include <numeric>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h"

namespace xla {

namespace {

// Common computation shared between AvgPool and AvgPoolGrad. Divide each
// element of an image by the count of elements that contributed to that
// element during pooling.
XlaOp AvgPoolDivideByCountWithGeneralPadding(
    XlaOp sums, PrimitiveType dtype, absl::Span<const int64> input_shape,
    absl::Span<const std::pair<int64, int64>> spatial_padding,
    absl::Span<const int64> ksize, absl::Span<const int64> stride,
    const TensorFormat& data_format) {
  // The padding shouldn't be included in the counts. We use another
  // ReduceWindow to find the right counts.
  const int num_spatial_dims = spatial_padding.size();

  std::vector<int64> input_dim_sizes(num_spatial_dims);
  std::vector<int64> window_dims(num_spatial_dims);
  std::vector<int64> window_ksize(num_spatial_dims);
  std::vector<int64> window_stride(num_spatial_dims);
  CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims)
      << "Invalid number of spatial dimensions in data format specification";
  for (int i = 0; i < num_spatial_dims; ++i) {
    int dim = data_format.spatial_dimension(i);
    input_dim_sizes[i] = input_shape[dim];
    window_dims[i] = dim;
    window_ksize[i] = ksize[dim];
    window_stride[i] = stride[dim];
  }

  XlaBuilder* b = sums.builder();
  // Build a matrix of all 1s, with the same width/height as the input.
  auto ones = Broadcast(One(b, dtype), input_dim_sizes);
  PaddingConfig padding_config;
  for (int i = 0; i < num_spatial_dims; ++i) {
    auto dims = padding_config.add_dimensions();
    dims->set_edge_padding_low(spatial_padding[i].first);
    dims->set_edge_padding_high(spatial_padding[i].second);
  }
  auto zero = Zero(b, dtype);
  auto padded_ones = Pad(ones, zero, padding_config);

  // Perform a ReduceWindow with the same window size, strides, and padding
  // to count the number of contributions to each result element.
  auto counts =
      ReduceWindow(padded_ones, zero, CreateScalarAddComputation(dtype, b),
                   window_ksize, window_stride, Padding::kValid);

  return Div(sums, counts, window_dims);
}
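
// A minimal worked example (illustrative only, not part of the library): for a
// 1-D spatial input of size 4 with ksize 2, stride 2, and spatial padding
// (1, 1), the padded ones vector is [0, 1, 1, 1, 1, 0], and the counting
// ReduceWindow yields counts [1, 2, 1]. Windows that overlap the padding
// therefore divide by fewer elements than fully interior windows.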

// Sums all elements in the window specified by 'kernel_size' and 'stride'.
XlaOp ComputeSums(XlaOp operand, XlaOp init_value,
                  absl::Span<const int64> kernel_size,
                  absl::Span<const int64> stride,
                  const TensorFormat& data_format) {
  XlaBuilder* b = operand.builder();
  return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand));
    TF_ASSIGN_OR_RETURN(Shape init_shape, b->GetShape(init_value));
    PrimitiveType accumulation_type = init_shape.element_type();
    auto add_computation = CreateScalarAddComputation(accumulation_type, b);
    return ReduceWindow(operand, init_value, add_computation, kernel_size,
                        stride, Padding::kValid);
  });
}
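
// For instance (illustrative): with a {1, 2, 2, 1} kernel and stride 2 over an
// NHWC operand, each output element is the sum of the four spatial elements
// its window covers; the division that turns these sums into averages happens
// separately in AvgPoolDivideByCount below.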

// Creates a padding configuration out of spatial padding values.
PaddingConfig MakeSpatialPaddingConfig(
    absl::Span<const std::pair<int64, int64>> spatial_padding,
    int num_spatial_dims, absl::Span<const int64> stride,
    const TensorFormat& data_format) {
  PaddingConfig padding_config;
  padding_config.mutable_dimensions()->Reserve(2 + num_spatial_dims);
  for (int i = 0; i < 2 + num_spatial_dims; ++i) {
    padding_config.add_dimensions();
  }
  CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims)
      << "Invalid number of spatial dimensions in data format specification";
  for (int i = 0; i < num_spatial_dims; ++i) {
    int dim = data_format.spatial_dimension(i);
    auto padding_dimension = padding_config.mutable_dimensions(dim);
    padding_dimension->set_edge_padding_low(spatial_padding[i].first);
    padding_dimension->set_edge_padding_high(spatial_padding[i].second);
  }
  return padding_config;
}
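
// As an example (illustrative): for an NHWC data format (batch dimension 0,
// feature dimension 3, spatial dimensions {1, 2}) and spatial_padding
// {{2, 2}, {3, 3}}, the resulting 4-dimension PaddingConfig applies edge
// padding (2, 2) to dimension 1 and (3, 3) to dimension 2, while the batch
// and feature dimensions keep zero padding.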

XlaOp AvgPoolDivideByCount(XlaOp pooled, absl::Span<const int64> input_size,
                           absl::Span<const int64> window_dimensions,
                           absl::Span<const int64> window_strides,
                           absl::Span<const std::pair<int64, int64>> padding,
                           PrimitiveType dtype, const TensorFormat& data_format,
                           bool counts_include_padding) {
  if (counts_include_padding) {
    // If counts include padding, all windows have the same number of elements
    // contributing to each average. Divide by the window size everywhere to
    // get the average.
    int64_t window_size =
        std::accumulate(window_dimensions.begin(), window_dimensions.end(), 1,
                        [](int64_t a, int64_t b) { return a * b; });
    auto divisor = ConstantR0WithType(pooled.builder(), dtype, window_size);

    return pooled / divisor;
  } else {
    return AvgPoolDivideByCountWithGeneralPadding(pooled, dtype, input_size,
                                                  padding, window_dimensions,
                                                  window_strides, data_format);
  }
}
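
// Concretely: with counts_include_padding=true and a {1, 3, 3, 1} window,
// every pooled sum is divided by 9 regardless of how much padding the window
// overlaps; with counts_include_padding=false the per-window counts computed
// by AvgPoolDivideByCountWithGeneralPadding are used instead.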

}  // namespace

XlaOp MaxPool(XlaOp operand, absl::Span<const int64> kernel_size,
              absl::Span<const int64> stride, Padding padding,
              const TensorFormat& data_format) {
  XlaBuilder* b = operand.builder();
  return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand));
    PrimitiveType dtype = operand_shape.element_type();
    auto max_computation = CreateScalarMaxComputation(dtype, b);
    auto init_value = MinValue(b, dtype);
    return ReduceWindow(operand, init_value, max_computation, kernel_size,
                        stride, padding);
  });
}
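
// Example usage (a sketch under assumed setup; `builder`, the parameter shape,
// and the NHWC TensorFormat below are illustrative, not part of this file):
//
//   XlaBuilder builder("max_pool_example");
//   auto input = Parameter(&builder, 0,
//                          ShapeUtil::MakeShape(F32, {1, 8, 8, 3}), "input");
//   // NHWC: batch dimension 0, feature dimension 3, spatial dimensions 1, 2.
//   TensorFormat data_format(/*batch_dimension=*/0, /*feature_dimension=*/3,
//                            /*spatial_dimensions=*/{1, 2});
//   // 2x2 window with stride 2 over the spatial dimensions, VALID padding.
//   auto pooled = MaxPool(input, /*kernel_size=*/{1, 2, 2, 1},
//                         /*stride=*/{1, 2, 2, 1}, Padding::kValid,
//                         data_format);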

XlaOp AvgPool(XlaOp operand, absl::Span<const int64> kernel_size,
              absl::Span<const int64> stride,
              absl::Span<const std::pair<int64, int64>> padding,
              const TensorFormat& data_format,
              const bool counts_include_padding) {
  XlaBuilder* b = operand.builder();
  return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand));
    PrimitiveType dtype = operand_shape.element_type();
    auto init_value = Zero(b, dtype);
    std::vector<int64> input_size(operand_shape.dimensions().begin(),
                                  operand_shape.dimensions().end());
    const int num_dims = kernel_size.size();
    const int num_spatial_dims = num_dims - 2;
    auto padding_config = MakeSpatialPaddingConfig(padding, num_spatial_dims,
                                                   stride, data_format);
    auto padded_operand = Pad(operand, Zero(b, dtype), padding_config);
    auto pooled = ComputeSums(padded_operand, init_value, kernel_size, stride,
                              data_format);
    return AvgPoolDivideByCount(pooled, input_size, kernel_size, stride,
                                padding, dtype, data_format,
                                counts_include_padding);
  });
}
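
// Example usage (a sketch; `input` and `data_format` are assumed to be set up
// as in the MaxPool example above):
//
//   // 3x3 window, stride 1, one element of explicit padding on each side of
//   // both spatial dimensions, with padded elements excluded from the counts.
//   auto averaged = AvgPool(input, /*kernel_size=*/{1, 3, 3, 1},
//                           /*stride=*/{1, 1, 1, 1},
//                           /*padding=*/{{1, 1}, {1, 1}}, data_format,
//                           /*counts_include_padding=*/false);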

std::vector<std::pair<int64, int64>> MakeSpatialPadding(
    absl::Span<const int64> input_size, absl::Span<const int64> kernel_size,
    absl::Span<const int64> stride, Padding padding,
    const TensorFormat& data_format) {
  const int num_spatial_dims = kernel_size.size() - 2;
  std::vector<int64> input_spatial_dimensions;
  std::vector<int64> kernel_size_spatial_dimensions;
  std::vector<int64> stride_spatial_dimensions;
  CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims)
      << "Invalid number of spatial dimensions in data format specification";
  for (int i = 0; i < num_spatial_dims; ++i) {
    int dim = data_format.spatial_dimension(i);
    input_spatial_dimensions.push_back(input_size[dim]);
    kernel_size_spatial_dimensions.push_back(kernel_size[dim]);
    stride_spatial_dimensions.push_back(stride[dim]);
  }
  return MakePadding(input_spatial_dimensions, kernel_size_spatial_dimensions,
                     stride_spatial_dimensions, padding);
}
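
// For instance (following the usual SAME-padding rule; the numbers are
// illustrative): with a spatial size of 5, kernel 3, and stride 2,
// Padding::kSame gives an output size of ceil(5 / 2) = 3, so the total
// padding is (3 - 1) * 2 + 3 - 5 = 2, split as {1, 1} for that dimension;
// Padding::kValid gives {0, 0}.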

XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span<const int64> gradients_size,
                  absl::Span<const int64> kernel_size,
                  absl::Span<const int64> stride,
                  absl::Span<const std::pair<int64, int64>> spatial_padding,
                  const TensorFormat& data_format,
                  const bool counts_include_padding) {
  XlaBuilder* b = out_backprop.builder();
  return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    const int num_dims = kernel_size.size();
    const int num_gradients = gradients_size.size();
    if (num_gradients != num_dims) {
      return tensorflow::errors::InvalidArgument("gradients must be ", num_dims,
                                                 "-dimensional");
    }

    TF_ASSIGN_OR_RETURN(Shape out_backprop_xla_shape,
                        b->GetShape(out_backprop));
    const int backprop_xla_num_dims =
        out_backprop_xla_shape.dimensions().size();
    if (backprop_xla_num_dims != num_dims) {
      return tensorflow::errors::InvalidArgument("out_backprop must be ",
                                                 num_dims, "-dimensional");
    }

    // We can think of average-pooling as:
    // * a convolution with a kernel consisting entirely of 1s, where the
    //   input feature and output feature are equal, and 0s everywhere else.
    // * followed by dividing by the counts.
    //
    // This then gives us an algorithm to build the gradient:
    // * divide out_backprop by the counts, followed by
    // * Conv2DBackpropInput specialized for that kernel, which simplifies to
    //   a Pad and a ReduceWindow.
    //
    // For an explanation of backpropagation for convolution, see the comments
    // in third_party/tensorflow/core/kernels/conv_grad_ops.h

    // TF filter shape is [ H, W, ..., inC, outC ]

    // The input gradients are computed by a convolution of the output
    // gradients and the filter, with some appropriate padding. See the comment
    // at the top of conv_grad_ops.h for details.
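
    // A worked 1-D example (illustrative): with a spatial size of 4, kernel 2,
    // stride 2, no spatial padding, and counts_include_padding=true, the
    // incoming gradient [g0, g1] is first divided by the window size, giving
    // [g0/2, g1/2]. The interior padding of stride - 1 and the Conv-style edge
    // padding computed below expand this to [0, g0/2, 0, g1/2, 0], and the
    // final ReduceWindow with a window of 2 and stride 1 produces
    // [g0/2, g0/2, g1/2, g1/2]: each input element receives an equal share of
    // the gradient of the one window it contributed to.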
    PrimitiveType dtype = out_backprop_xla_shape.element_type();
    auto out_backprop_div = AvgPoolDivideByCount(
        out_backprop, gradients_size, kernel_size, stride, spatial_padding,
        dtype, data_format, counts_include_padding);

    // Pad the gradients in the spatial dimensions. We use the same padding
    // as Conv2DBackpropInput.
    PaddingConfig padding_config = MakeNoPaddingConfig(num_dims);
    std::vector<int64> padded_gradients_size(gradients_size.begin(),
                                             gradients_size.end());
    // First, pad the output gradients the same way as the input. The
    // additional padding will be removed as a last step before returning the
    // input gradients.
    const int num_spatial_dims = num_dims - 2;
    for (int i = 0; i < num_spatial_dims; ++i) {
      int dim = data_format.spatial_dimension(i);
      padded_gradients_size[dim] +=
          (spatial_padding[i].first + spatial_padding[i].second);
    }
    for (int i = 0; i < num_spatial_dims; ++i) {
      int dim = data_format.spatial_dimension(i);
      TF_ASSIGN_OR_RETURN(
          SpatialDimensionOutputSizeAndPadding conv_backprop_spatial_dim,
          ConvGradExtractAndVerifyDimension(
              /*input_size=*/padded_gradients_size[dim],
              /*filter_size=*/kernel_size[dim],
              /*output_size=*/out_backprop_xla_shape.dimensions(dim),
              /*dilation=*/1,
              /*stride=*/stride[dim], /*padding=*/Padding::kValid));
      auto* padding = padding_config.mutable_dimensions(dim);
      padding->set_edge_padding_low(conv_backprop_spatial_dim.pad_before);
      padding->set_edge_padding_high(conv_backprop_spatial_dim.pad_after);
      padding->set_interior_padding(stride[dim] - 1);
    }

    auto zero = Zero(b, dtype);
    auto padded_gradients = Pad(out_backprop_div, zero, padding_config);

    // in_backprop = padded_gradients <conv> ones
    std::vector<int64> ones(num_dims, 1LL);
    auto in_backprop =
        ReduceWindow(padded_gradients, Zero(b, dtype),
                     CreateScalarAddComputation(dtype, b), kernel_size,
                     /*window_strides=*/ones, Padding::kValid);
    // The input padding doesn't contribute to the gradient, so remove it.
    std::vector<std::pair<int64, int64>> neg_spatial_padding;
    neg_spatial_padding.reserve(spatial_padding.size());
    for (const std::pair<int64, int64>& spatial_padding_dim : spatial_padding) {
      neg_spatial_padding.emplace_back(-spatial_padding_dim.first,
                                       -spatial_padding_dim.second);
    }
    auto remove_padding_config = MakeSpatialPaddingConfig(
        neg_spatial_padding, num_spatial_dims, stride, data_format);
    return Pad(in_backprop, zero, remove_padding_config);
  });
}
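
// Example usage (a sketch; `out_backprop` is assumed to be the gradient
// flowing into the forward AvgPool's output, and `data_format` the NHWC
// TensorFormat from the examples above):
//
//   auto in_backprop = AvgPoolGrad(out_backprop,
//                                  /*gradients_size=*/{1, 8, 8, 3},
//                                  /*kernel_size=*/{1, 3, 3, 1},
//                                  /*stride=*/{1, 1, 1, 1},
//                                  /*spatial_padding=*/{{1, 1}, {1, 1}},
//                                  data_format,
//                                  /*counts_include_padding=*/false);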

}  // namespace xla