• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_
17 #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_
18 
19 #include "absl/container/inlined_vector.h"
20 #include "tensorflow/compiler/xla/client/xla_builder.h"
21 
22 namespace xla {
23 
24 // Tensor format for reduce window operations.
25 class TensorFormat {
26  public:
TensorFormat(int batch_dimension,int feature_dimension,absl::Span<const int64> spatial_dimensions)27   TensorFormat(int batch_dimension, int feature_dimension,
28                absl::Span<const int64> spatial_dimensions)
29       : batch_dimension_(batch_dimension),
30         feature_dimension_(feature_dimension),
31         spatial_dimensions_(spatial_dimensions.begin(),
32                             spatial_dimensions.end()) {}
33 
batch_dimension()34   int batch_dimension() const { return batch_dimension_; }
35 
feature_dimension()36   int feature_dimension() const { return feature_dimension_; }
37 
spatial_dimension(int dim)38   int spatial_dimension(int dim) const { return spatial_dimensions_[dim]; }
39 
num_spatial_dims()40   int num_spatial_dims() const { return spatial_dimensions_.size(); }
41 
42  private:
43   // The number of the dimension that represents the batch.
44   int batch_dimension_;
45   // The number of the dimension that represents the features.
46   int feature_dimension_;
47   // The dimension numbers for the spatial dimensions.
48   absl::InlinedVector<int, 4> spatial_dimensions_;
49 };
50 
51 // Computes the max pool of 'operand'.
52 XlaOp MaxPool(XlaOp operand, absl::Span<const int64> kernel_size,
53               absl::Span<const int64> stride, Padding padding,
54               const TensorFormat& data_format);
55 
56 // Computes the average pool of 'operand'.
57 XlaOp AvgPool(XlaOp operand, absl::Span<const int64> kernel_size,
58               absl::Span<const int64> stride,
59               absl::Span<const std::pair<int64, int64>> padding,
60               const TensorFormat& data_format,
61               const bool counts_include_padding);
62 
63 // Returns the list of low and high padding elements in each spatial dimension
64 // for the given 'padding' specification.
65 std::vector<std::pair<int64, int64>> MakeSpatialPadding(
66     absl::Span<const int64> input_size, absl::Span<const int64> kernel_size,
67     absl::Span<const int64> stride, Padding padding,
68     const TensorFormat& data_format);
69 
70 // Computes the average pool gradient.
71 XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span<const int64> gradients_size,
72                   absl::Span<const int64> kernel_size,
73                   absl::Span<const int64> stride,
74                   absl::Span<const std::pair<int64, int64>> spatial_padding,
75                   const TensorFormat& data_format,
76                   const bool counts_include_padding);
77 
78 }  // namespace xla
79 
80 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_
81