• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // XLA specific pooling ops.
17 
18 #include <string>
19 
20 #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
21 #include "tensorflow/compiler/tf2xla/shape_util.h"
22 #include "tensorflow/compiler/tf2xla/type_util.h"
23 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
24 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
25 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
26 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
27 #include "tensorflow/compiler/xla/client/lib/constants.h"
28 #include "tensorflow/compiler/xla/client/lib/pooling.h"
29 #include "tensorflow/compiler/xla/client/xla_builder.h"
30 #include "tensorflow/compiler/xla/client/xla_computation.h"
31 #include "tensorflow/compiler/xla/literal.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/core/framework/bounds_check.h"
34 #include "tensorflow/core/framework/op_kernel.h"
35 #include "tensorflow/core/framework/register_types.h"
36 #include "tensorflow/core/framework/tensor.h"
37 #include "tensorflow/core/platform/errors.h"
38 #include "tensorflow/core/util/tensor_format.h"
39 
40 namespace tensorflow {
41 namespace {
42 
43 // Superclass of pooling ops.
44 class PoolingOp : public XlaOpKernel {
45  public:
PoolingOp(OpKernelConstruction * ctx,int num_spatial_dims,const DataType reduction_type)46   PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims,
47             const DataType reduction_type)
48       : XlaOpKernel(ctx),
49         num_spatial_dims_(num_spatial_dims),
50         reduction_type_(reduction_type) {
51     if (ctx->num_inputs() == 1) {
52       std::vector<int32> ksize_int;
53       std::vector<int32> stride_int;
54       OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int));
55       OP_REQUIRES(ctx, ksize_int.size() == num_dims(),
56                   errors::InvalidArgument("Sliding window ksize field must "
57                                           "specify ",
58                                           num_dims(), " dimensions"));
59       OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int));
60       OP_REQUIRES(ctx, stride_int.size() == num_dims(),
61                   errors::InvalidArgument("Sliding window stride field must "
62                                           "specify ",
63                                           num_dims(), " dimensions"));
64       for (int i = 0; i < num_dims(); ++i) {
65         ksize_.push_back(ksize_int[i]);
66         stride_.push_back(stride_int[i]);
67       }
68     }
69     Padding padding;
70     OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding));
71     OP_REQUIRES(ctx, padding != EXPLICIT,
72                 errors::Unimplemented(
73                     "XLA does not support pooling ops with explicit padding."));
74     padding_ = (padding == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
75 
76     OP_REQUIRES_OK(
77         ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_));
78   }
79 
num_dims() const80   int num_dims() const { return num_spatial_dims_ + 2; }
81 
82  protected:
GetKernelSize(XlaOpKernelContext * ctx)83   StatusOr<std::vector<int64>> GetKernelSize(XlaOpKernelContext* ctx) {
84     if (ctx->num_inputs() == 1) {
85       return ksize_;
86     }
87     const TensorShape ksize_shape = ctx->InputShape(1);
88     // Validate input sizes.
89     if (!TensorShapeUtils::IsVector(ksize_shape)) {
90       return errors::InvalidArgument("ksize must be a vector, not shape ",
91                                      ksize_shape.DebugString());
92     }
93     if (ksize_shape.num_elements() != num_dims()) {
94       return errors::InvalidArgument(
95           "Sliding window ksize field must "
96           "specify ",
97           num_dims(), " dimensions");
98     }
99     std::vector<int64> ksize;
100     auto status = ctx->ConstantInputAsIntVector(1, &ksize);
101     if (!status.ok()) {
102       return status;
103     }
104     return ksize;
105   }
106 
GetStride(XlaOpKernelContext * ctx)107   StatusOr<std::vector<int64>> GetStride(XlaOpKernelContext* ctx) {
108     if (ctx->num_inputs() == 1) {
109       return stride_;
110     }
111     const TensorShape stride_shape = ctx->InputShape(2);
112     // Validate input sizes.
113     if (!TensorShapeUtils::IsVector(stride_shape)) {
114       return errors::InvalidArgument("stride must be a vector, not shape ",
115                                      stride_shape.DebugString());
116     }
117     if (stride_shape.num_elements() != num_dims()) {
118       return errors::InvalidArgument(
119           "Sliding window stride field must "
120           "specify ",
121           num_dims(), " dimensions");
122     }
123     std::vector<int64> stride;
124     auto status = ctx->ConstantInputAsIntVector(2, &stride);
125     if (!status.ok()) {
126       return status;
127     }
128     return stride;
129   }
130 
131  protected:
132   const int num_spatial_dims_;
133   std::vector<int64> ksize_;
134   std::vector<int64> stride_;
135   xla::Padding padding_;
136   TensorFormat data_format_ = FORMAT_NHWC;
137   DataType reduction_type_;
138   xla::PrimitiveType xla_reduction_type_;
139 };
140 
141 // Converts the tensor data format to the one required by the XLA pooling
142 // library.
XlaTensorFormat(tensorflow::TensorFormat data_format,int num_spatial_dims)143 xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format,
144                                   int num_spatial_dims) {
145   int num_dims = num_spatial_dims + 2;
146   int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format);
147   int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format);
148   absl::InlinedVector<int64, 4> spatial_dimensions(num_spatial_dims);
149   for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) {
150     spatial_dimensions[spatial_dim] =
151         GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim);
152   }
153   return xla::TensorFormat(/*batch_dimension=*/batch_dimension,
154                            /*feature_dimension=*/feature_dimension,
155                            /*spatial_dimensions=*/spatial_dimensions);
156 }
157 
158 class MaxPoolOp : public PoolingOp {
159  public:
MaxPoolOp(OpKernelConstruction * ctx,int num_spatial_dims)160   MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
161       : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims,
162                   /*reduction_type=*/ctx->input_type(0)) {
163     std::string data_format_str;
164     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
165     OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
166                 errors::InvalidArgument("Invalid data format"));
167     OP_REQUIRES(ctx, data_format_ != FORMAT_NHWC_VECT_W,
168                 errors::Unimplemented(
169                     "XLA does not support the VECT_NHWC_VECT_W data format. "
170                     "Returning unimplemented from MaxPool to keep "
171                     "Tensorflow's intended optimized MaxPool here."));
172   }
173 
Compile(XlaOpKernelContext * ctx)174   void Compile(XlaOpKernelContext* ctx) override {
175     auto ksize_or_error = GetKernelSize(ctx);
176     OP_REQUIRES_OK(ctx, ksize_or_error.status());
177     std::vector<int64> ksize = ksize_or_error.ValueOrDie();
178 
179     auto stride_or_error = GetStride(ctx);
180     OP_REQUIRES_OK(ctx, stride_or_error.status());
181     std::vector<int64> stride = stride_or_error.ValueOrDie();
182 
183     xla::XlaOp input = ctx->Input(0);
184 
185     StatusOr<xla::Shape> input_shape = ctx->builder()->GetShape(input);
186     OP_REQUIRES_OK(ctx, input_shape.status());
187 
188     // For VECT_C max-pool ops, transpose to plain NCHW, do the max-pool, and
189     // transpose back.  This isn't necessarily the most efficient algorithm, but
190     // it's ok for starters.
191     absl::optional<int64> vect_width;
192     if (data_format_ == FORMAT_NCHW_VECT_C) {
193       vect_width = input_shape->dimensions().back();
194       input = xla::Collapse(xla::Transpose(input, {0, 1, 4, 2, 3}), {1, 2});
195 
196       input_shape = ctx->builder()->GetShape(input);
197       OP_REQUIRES_OK(ctx, input_shape.status());
198     }
199 
200     OP_REQUIRES(ctx, input_shape->dimensions_size() == num_dims(),
201                 errors::InvalidArgument("Input to ", type_string(),
202                                         " operator must have ", num_dims(),
203                                         " dimensions"));
204     auto pooling = xla::MaxPool(
205         input, ksize, stride, padding_,
206         XlaTensorFormat(
207             data_format_ == FORMAT_NCHW_VECT_C ? FORMAT_NCHW : data_format_,
208             input_shape->dimensions_size() - 2));
209 
210     if (data_format_ == FORMAT_NCHW_VECT_C) {
211       StatusOr<xla::Shape> result_shape = ctx->builder()->GetShape(pooling);
212       OP_REQUIRES_OK(ctx, result_shape.status());
213 
214       int64 num_channels = result_shape->dimensions(1);
215       OP_REQUIRES(
216           ctx, num_channels % *vect_width == 0,
217           errors::FailedPrecondition("Result of NCHW_VECT_C op must have "
218                                      "channels multiple of ",
219                                      *vect_width, ", but was ", num_channels));
220 
221       absl::InlinedVector<int64, 5> new_dims(result_shape->dimensions().begin(),
222                                              result_shape->dimensions().end());
223       new_dims[1] /= *vect_width;
224       new_dims.insert(new_dims.begin() + 2, *vect_width);
225       pooling =
226           xla::Transpose(xla::Reshape(pooling, new_dims), {0, 1, 3, 4, 2});
227     }
228 
229     ctx->SetOutput(0, pooling);
230   }
231 };
232 
// 2D max pooling, registered for both the attr-based (MaxPool) and the
// input-based (MaxPoolV2) ops.
class MaxPool2DOp : public MaxPoolOp {
 public:
  explicit MaxPool2DOp(OpKernelConstruction* ctx)
      : MaxPoolOp(ctx, /*num_spatial_dims=*/2) {}
};
REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp);
// MaxPoolV2 takes ksize/strides as tensor inputs; they must be compile-time
// constants so the window shape can be baked into the XLA computation.
REGISTER_XLA_OP(Name("MaxPoolV2")
                    .CompileTimeConstantInput("ksize")
                    .CompileTimeConstantInput("strides"),
                MaxPool2DOp);
243 
// 3D max pooling.
class MaxPool3DOp : public MaxPoolOp {
 public:
  explicit MaxPool3DOp(OpKernelConstruction* ctx)
      : MaxPoolOp(ctx, /*num_spatial_dims=*/3) {}
};
REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp);
250 
251 class AvgPoolOp : public PoolingOp {
252  public:
AvgPoolOp(OpKernelConstruction * ctx,int num_spatial_dims)253   AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
254       : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims,
255                   /*reduction_type=*/
256                   XlaHelpers::SumAccumulationType(ctx->input_type(0))) {
257     string data_format_str;
258     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
259     OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
260                 errors::InvalidArgument("Invalid data format"));
261   }
262 
Compile(XlaOpKernelContext * ctx)263   void Compile(XlaOpKernelContext* ctx) override {
264     auto ksize_or_error = GetKernelSize(ctx);
265     OP_REQUIRES_OK(ctx, ksize_or_error.status());
266     std::vector<int64> ksize = ksize_or_error.ValueOrDie();
267 
268     auto stride_or_error = GetStride(ctx);
269     OP_REQUIRES_OK(ctx, stride_or_error.status());
270     std::vector<int64> stride = stride_or_error.ValueOrDie();
271 
272     const TensorShape input_shape = ctx->InputShape(0);
273     OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
274                 errors::InvalidArgument("Input to ", type_string(),
275                                         " operator must have ", num_dims(),
276                                         " dimensions"));
277 
278     auto xla_data_format =
279         XlaTensorFormat(data_format_, input_shape.dims() - 2);
280     auto spatial_padding = MakeSpatialPadding(
281         input_shape.dim_sizes(), ksize, stride, padding_, xla_data_format);
282 
283     // Convert the input to the reduction type.
284     auto converted_input =
285         ConvertElementType(ctx->Input(0), xla_reduction_type_);
286     auto pooling =
287         xla::AvgPool(converted_input, ksize, stride, spatial_padding,
288                      xla_data_format, padding_ == xla::Padding::kValid);
289     // Convert the pooling result back to the input type before returning it.
290     ctx->SetOutput(0, ConvertElementType(pooling, ctx->input_xla_type(0)));
291   }
292 };
293 
// 2D average pooling.
class AvgPool2DOp : public AvgPoolOp {
 public:
  explicit AvgPool2DOp(OpKernelConstruction* ctx)
      : AvgPoolOp(ctx, /*num_spatial_dims=*/2) {}
};
REGISTER_XLA_OP(Name("AvgPool"), AvgPool2DOp);

// AvgPool3D is lowered via the MLIR-based kernel rather than a hand-written
// XLA kernel.
REGISTER_XLA_OP(Name("AvgPool3D"), MlirXlaOpKernel);
302 
// The operation to compute MaxPool gradients.
// It takes three inputs:
//   - The original input tensor
//   - The original output tensor
//   - Backprop tensor for output
// It produces one output: backprop tensor for input.
class MaxPoolGradOp : public XlaOpKernel {
 public:
  MaxPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    // With 3 inputs (MaxPoolGrad), ksize/strides are attributes; with 5 inputs
    // (MaxPoolGradV2) they arrive as constant inputs and are read in Compile.
    if (ctx->num_inputs() == 3) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    }
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
    OP_REQUIRES(ctx, padding_ != EXPLICIT,
                errors::Unimplemented(
                    "XLA does not support maxpoolgrad with explicit padding."));
  }

  // Rank of the pooled tensors: spatial dims plus batch and feature dims.
  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    if (ctx->num_inputs() != 3) {
      // V2 variant: ksize/strides come in as inputs 3 and 4.
      OP_REQUIRES(
          ctx, ctx->num_inputs() == 5,
          errors::InvalidArgument("Must supply ksize and stride arguments."));
      const TensorShape ksize_shape = ctx->InputShape(3);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
                  errors::InvalidArgument("ksize must be a vector, not shape ",
                                          ksize_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));

      const TensorShape stride_shape = ctx->InputShape(4);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
                  errors::InvalidArgument("stride must be a vector, not shape ",
                                          stride_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
    }

    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));

    const TensorShape tensor_in_shape = ctx->InputShape(0);
    const TensorShape tensor_out_shape = ctx->InputShape(1);
    const TensorShape out_backprop_shape = ctx->InputShape(2);

    // For maxpooling, tensor_in should have num_dims() dimensions.
    OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_in must be ", num_dims(),
                                        "-dimensional"));
    OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_out must be ", num_dims(),
                                        "-dimensional"));
    // For maxpooling, out_backprop should have num_dims() dimensions.
    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
                errors::InvalidArgument("out_backprop must be ", num_dims(),
                                        "-dimensional"));

    // TODO(phawkins): The XLA version doesn't need tensor_out. Investigate
    // whether this is a good time/space tradeoff.
    auto input = ctx->Input(0);
    auto out_backprop = ctx->Input(2);
    // We ensured padding_ is not EXPLICIT in the constructor.
    xla::Padding xla_padding =
        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;

    // Create a MaxPool operation to check the expected resulting shape, and
    // then throw away the operation because we don't actually need it here.
    TensorShape expected_out_shape;
    auto pooling =
        xla::MaxPool(ctx->Input(0), ksize_, stride_, xla_padding,
                     XlaTensorFormat(data_format_, tensor_in_shape.dims() - 2));
    auto status_or_shape = pooling.builder()->GetShape(pooling);
    OP_REQUIRES_OK(ctx, status_or_shape.status());
    OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(status_or_shape.ValueOrDie(),
                                              &expected_out_shape));
    OP_REQUIRES(ctx, expected_out_shape == out_backprop_shape,
                errors::Unimplemented("The output dimensions do not match the "
                                      "other input values."));

    // The gradient itself is computed with SelectAndScatter: for each window,
    // select (via >=) the position of the max in `input` and scatter-add the
    // corresponding out_backprop value there.
    xla::PrimitiveType element_type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
    xla::XlaOp init_value = XlaHelpers::Zero(ctx->builder(), input_type(2));
    auto select = CreateScalarGeComputation(element_type, ctx->builder());
    auto scatter = CreateScalarAddComputation(element_type, ctx->builder());
    xla::XlaOp gradients =
        xla::SelectAndScatter(input, select, ksize_, stride_, xla_padding,
                              out_backprop, init_value, scatter);

    ctx->SetOutput(0, gradients);
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64> ksize_;
  std::vector<int64> stride_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
};
411 
412 class MaxPool2DGradOp : public MaxPoolGradOp {
413  public:
MaxPool2DGradOp(OpKernelConstruction * ctx)414   explicit MaxPool2DGradOp(OpKernelConstruction* ctx)
415       : MaxPoolGradOp(ctx, /*num_spatial_dims=*/2) {
416     string data_format;
417     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
418     OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
419                 errors::InvalidArgument("Invalid data format"));
420   }
421 };
422 REGISTER_XLA_OP(Name("MaxPoolGrad"), MaxPool2DGradOp);
423 REGISTER_XLA_OP(Name("MaxPoolGradV2")
424                     .CompileTimeConstantInput("ksize")
425                     .CompileTimeConstantInput("strides"),
426                 MaxPool2DGradOp);
427 
428 REGISTER_XLA_OP(Name("MaxPool3DGrad"), MlirXlaOpKernel);
429 
// Average-pooling gradient
//
// Takes the original input shape (as a constant input) and out_backprop, and
// produces in_backprop by distributing each output gradient uniformly over
// its pooling window via xla::AvgPoolGrad.
class AvgPoolGradOp : public XlaOpKernel {
 public:
  AvgPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
    OP_REQUIRES(ctx, padding_ != EXPLICIT,
                errors::Unimplemented(
                    "XLA does not support avgpoolgrad with explicit padding."));
    // NOTE(review): this checks element 0 of ksize/strides, which is the
    // batch dimension for NHWC/NCHW layouts.
    OP_REQUIRES(ctx, ksize_[0] == 1 && stride_[0] == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));

    string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }

  // Rank of the pooled tensors: spatial dims plus batch and feature dims.
  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    // Input 0 is the original input *shape*, supplied as a compile-time
    // constant; it determines the output (in_backprop) shape.
    TensorShape gradients_shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &gradients_shape));

    const TensorShape out_backprop_shape = ctx->InputShape(1);

    // For avgpooling, tensor_in_shape should have num_dims() dimensions.
    OP_REQUIRES(ctx, gradients_shape.dims() == num_dims(),
                errors::InvalidArgument("orig_input_shape must be ", num_dims(),
                                        "-dimensional"));

    // For avgpooling, out_backprop should have num_dims() dimensions.
    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
                errors::InvalidArgument("out_backprop must be ", num_dims(),
                                        "-dimensional"));

    auto out_backprop = ctx->Input(1);
    // stride_ is stored as int32 (attr type); widen for the XLA API.
    std::vector<int64> stride_int64s(stride_.begin(), stride_.end());
    xla::Padding xla_padding =
        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
    // Accumulate in the sum-accumulation type to limit precision loss.
    xla::PrimitiveType xla_reduction_type;
    auto reduction_type = XlaHelpers::SumAccumulationType(ctx->input_type(1));
    OP_REQUIRES_OK(
        ctx, DataTypeToPrimitiveType(reduction_type, &xla_reduction_type));
    auto converted_out_backprop =
        xla::ConvertElementType(out_backprop, xla_reduction_type);
    auto xla_data_format =
        XlaTensorFormat(data_format_, gradients_shape.dims() - 2);
    auto padding_values =
        MakeSpatialPadding(gradients_shape.dim_sizes(), ksize_, stride_int64s,
                           xla_padding, xla_data_format);
    auto in_backprop =
        xla::AvgPoolGrad(converted_out_backprop, gradients_shape.dim_sizes(),
                         ksize_, stride_int64s, padding_values, xla_data_format,
                         /*counts_include_padding=*/padding_ == VALID);
    // Convert the pooling result back to the input type before returning it.
    xla::PrimitiveType xla_out_backprop_type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1),
                                                &xla_out_backprop_type));
    ctx->SetOutput(0,
                   xla::ConvertElementType(in_backprop, xla_out_backprop_type));
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64> ksize_;
  // int32 because that is the attr's declared element type; widened in
  // Compile before being handed to XLA.
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
};
511 
// 2D average-pooling gradient. orig_input_shape must be a compile-time
// constant because it determines the output shape.
class AvgPool2DGradOp : public AvgPoolGradOp {
 public:
  explicit AvgPool2DGradOp(OpKernelConstruction* ctx)
      : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) {}
};
REGISTER_XLA_OP(
    Name("AvgPoolGrad").CompileTimeConstantInput("orig_input_shape"),
    AvgPool2DGradOp);
520 
// 3D average-pooling gradient. orig_input_shape must be a compile-time
// constant because it determines the output shape.
class AvgPool3DGradOp : public AvgPoolGradOp {
 public:
  explicit AvgPool3DGradOp(OpKernelConstruction* ctx)
      : AvgPoolGradOp(ctx, /*num_spatial_dims=*/3) {}
};
REGISTER_XLA_OP(
    Name("AvgPool3DGrad").CompileTimeConstantInput("orig_input_shape"),
    AvgPool3DGradOp);
529 
// Second-order gradient of MaxPool: given x, y = MaxPool(x), and
// xs_grad_grad, computes ys_grad_grad. Registered only for DT_FLOAT because
// the implementation relies on F32 bit patterns (see the long comment in
// Compile).
class MaxPoolGradGradOp : public XlaOpKernel {
 public:
  MaxPoolGradGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    // With 3 inputs, ksize/strides are attributes; with 5 inputs
    // (MaxPoolGradGradV2) they are constant inputs read in Compile.
    if (ctx->num_inputs() == 3) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    }
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
    OP_REQUIRES(
        ctx, padding_ != EXPLICIT,
        errors::Unimplemented(
            "XLA does not support maxpoolgradgrad with explicit padding."));
  }

  // Rank of the pooled tensors: spatial dims plus batch and feature dims.
  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    if (ctx->num_inputs() != 3) {
      // V2 variant: ksize/strides come in as inputs 3 and 4.
      OP_REQUIRES(
          ctx, ctx->num_inputs() == 5,
          errors::InvalidArgument("Must supply ksize and stride arguments."));
      const TensorShape ksize_shape = ctx->InputShape(3);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
                  errors::InvalidArgument("ksize must be a vector, not shape ",
                                          ksize_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));

      const TensorShape stride_shape = ctx->InputShape(4);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
                  errors::InvalidArgument("stride must be a vector, not shape ",
                                          stride_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
    }

    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));

    const TensorShape tensor_in_shape = ctx->InputShape(0);
    const TensorShape tensor_out_shape = ctx->InputShape(1);
    const TensorShape out_backprop_shape = ctx->InputShape(2);

    // For maxpooling, tensor_in should have num_dims() dimensions.
    OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_in must be ", num_dims(),
                                        "-dimensional"));
    OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_out must be ", num_dims(),
                                        "-dimensional"));
    // For maxpooling, out_backprop should have num_dims() dimensions.
    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
                errors::InvalidArgument("out_backprop must be ", num_dims(),
                                        "-dimensional"));

    // What we want to compute:
    // Given y = MaxPool(x), and xs_grad = MaxPoolGrad(x, y, ys_grad)
    // MaxPoolGradGrad computes {ys_grad}_grad given x, y, and {xs_grad}_grad.
    //
    // In the regular TF op, this amounts to selecting for each window the
    // incoming backprop value from xs_grad_grad that corresponds to the maximal
    // value in the corresponding window of x.
    //
    // TODO(b/73062247): What we really want is a ReduceWindow with different
    // arrays for index selection vs return value selection--a select-to-gather.
    //
    // Here, we implement a bitwise hack: we use the hi 16 bits of input for
    // separate max pooling alongside each of the hi and lo 16 bits of
    // out_backprop packed into 16 lo bits, which we then glue back together at
    // the end to get a full 32 bits of gradient.
    //
    // This could select the wrong backprop value for two x values that are
    // equally maximal up to the first 16 bits, in which case we are taking the
    // latter.
    //
    // Note that in principle we could use 32 separate maxpools to recover each
    // of 32 bits of the gradient while preserving 31 bits of input for the max
    // pooling criteria; here, we just truncate to the first 16 bits of input.

    auto input = ctx->Input(0);
    auto out_backprop = ctx->Input(2);

    auto b = ctx->builder();

    auto sixteen = xla::ConstantR0<uint32>(b, 16);
    // in (f32) -> round to 7 mantissa bits (bf16)-> 16-high-bit u32.
    //
    // NOTE: Use a ReducePrecision operation instead of a cast to BF16 and back
    // to F32 since the XLA compiler may ignore narrowing casts to floating
    // point types if the debug option xla_allow_excess_precision is set.
    auto in_hi = xla::BitcastConvertType(
        xla::ReducePrecision(input, /*exponent_bits=*/8, /*mantissa_bits=*/7),
        xla::U32);
    auto bp_int = xla::BitcastConvertType(out_backprop, xla::U32);
    auto bp_hi = xla::ShiftRightLogical(bp_int, sixteen);
    auto bp_lo =
        xla::ShiftRightLogical(xla::ShiftLeft(bp_int, sixteen), sixteen);
    auto in_hi_bp_hi = xla::Add(in_hi, bp_hi);  // Want an unsigned add.
    auto in_hi_bp_lo = xla::Add(in_hi, bp_lo);  // Want an unsigned add.

    auto init_value = xla::MinValue(b, xla::F32);
    // We will reduce by taking the maximal value up to 16 bits (ignoring the lo
    // 16 bits of packed-in hi/lo backprop value).
    auto rb = b->CreateSubBuilder("GreaterOrEqOf_ByFirst16Bits");
    {
      // F32 parameters to satisfy lowering type restriction for reduce opcode.
      const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {});
      auto lhs = xla::Parameter(rb.get(), 0, scalar, "lhs");
      auto rhs = xla::Parameter(rb.get(), 1, scalar, "rhs");
      auto sixteen = xla::ConstantR0<int32>(rb.get(), 16);
      // Mask off the low 16 bits of each operand before comparing, so only
      // the truncated input value participates in the max selection.
      auto lhs_criteria =
          xla::ShiftLeft(xla::ShiftRightLogical(
                             xla::BitcastConvertType(lhs, xla::S32), sixteen),
                         sixteen);
      auto rhs_criteria =
          xla::ShiftLeft(xla::ShiftRightLogical(
                             xla::BitcastConvertType(rhs, xla::S32), sixteen),
                         sixteen);
      // Must use a F32 comparison, because S32 would not work for negatives.
      xla::Select(xla::Ge(xla::BitcastConvertType(lhs_criteria, xla::F32),
                          xla::BitcastConvertType(rhs_criteria, xla::F32)),
                  lhs, rhs);
    }
    auto reduce = rb->BuildAndNoteError();
    xla::Padding xla_padding =
        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
    auto pooled_hi =
        xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_hi, xla::F32),
                          init_value, reduce, ksize_, stride_, xla_padding);
    auto pooled_lo =
        xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_lo, xla::F32),
                          init_value, reduce, ksize_, stride_, xla_padding);
    // Reassemble the full 32-bit gradient from the two pooled 16-bit halves.
    auto grads_hi =
        xla::ShiftLeft(xla::BitcastConvertType(pooled_hi, xla::U32), sixteen);
    auto grads_lo = xla::ShiftRightLogical(
        xla::ShiftLeft(xla::BitcastConvertType(pooled_lo, xla::U32), sixteen),
        sixteen);
    auto grads = xla::Add(grads_hi, grads_lo);  // Want an unsigned add.

    xla::PrimitiveType element_type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
    ctx->SetOutput(0, xla::BitcastConvertType(grads, element_type));
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64> ksize_;
  std::vector<int64> stride_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
};
688 
689 class MaxPool2DGradGradOp : public MaxPoolGradGradOp {
690  public:
MaxPool2DGradGradOp(OpKernelConstruction * ctx)691   explicit MaxPool2DGradGradOp(OpKernelConstruction* ctx)
692       : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/2) {
693     string data_format;
694     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
695     OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
696                 errors::InvalidArgument("Invalid data format"));
697   }
698 };
699 REGISTER_XLA_OP(Name("MaxPoolGradGrad").TypeConstraint("T", DT_FLOAT),
700                 MaxPool2DGradGradOp);
701 REGISTER_XLA_OP(Name("MaxPoolGradGradV2")
702                     .TypeConstraint("T", DT_FLOAT)
703                     .CompileTimeConstantInput("ksize")
704                     .CompileTimeConstantInput("strides"),
705                 MaxPool2DGradGradOp);
706 
707 class MaxPool3DGradGradOp : public MaxPoolGradGradOp {
708  public:
MaxPool3DGradGradOp(OpKernelConstruction * ctx)709   explicit MaxPool3DGradGradOp(OpKernelConstruction* ctx)
710       : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/3) {
711     string data_format;
712     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
713     OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
714                 errors::InvalidArgument("Invalid data format"));
715   }
716 };
717 REGISTER_XLA_OP(Name("MaxPool3DGradGrad").TypeConstraint("T", DT_FLOAT),
718                 MaxPool3DGradGradOp);
719 
720 }  // anonymous namespace
721 }  // namespace tensorflow
722