• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // XLA specific pooling ops.
17 
18 #include <string>
19 
20 #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
21 #include "tensorflow/compiler/tf2xla/shape_util.h"
22 #include "tensorflow/compiler/tf2xla/type_util.h"
23 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
24 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
25 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
26 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
27 #include "tensorflow/compiler/xla/client/lib/constants.h"
28 #include "tensorflow/compiler/xla/client/lib/pooling.h"
29 #include "tensorflow/compiler/xla/client/value_inference.h"
30 #include "tensorflow/compiler/xla/client/xla_builder.h"
31 #include "tensorflow/compiler/xla/client/xla_computation.h"
32 #include "tensorflow/compiler/xla/literal.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/core/framework/bounds_check.h"
35 #include "tensorflow/core/framework/op_kernel.h"
36 #include "tensorflow/core/framework/register_types.h"
37 #include "tensorflow/core/framework/tensor.h"
38 #include "tensorflow/core/platform/errors.h"
39 #include "tensorflow/core/util/determinism.h"
40 #include "tensorflow/core/util/tensor_format.h"
41 
42 namespace tensorflow {
43 namespace {
44 
45 // Superclass of pooling ops.
46 class PoolingOp : public XlaOpKernel {
47  public:
PoolingOp(OpKernelConstruction * ctx,int num_spatial_dims,const DataType reduction_type)48   PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims,
49             const DataType reduction_type)
50       : XlaOpKernel(ctx),
51         num_spatial_dims_(num_spatial_dims),
52         reduction_type_(reduction_type) {
53     if (ctx->num_inputs() == 1) {
54       std::vector<int32> ksize_int;
55       std::vector<int32> stride_int;
56       OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int));
57       OP_REQUIRES(ctx, ksize_int.size() == num_dims(),
58                   errors::InvalidArgument("Sliding window ksize field must "
59                                           "specify ",
60                                           num_dims(), " dimensions"));
61       OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int));
62       OP_REQUIRES(ctx, stride_int.size() == num_dims(),
63                   errors::InvalidArgument("Sliding window stride field must "
64                                           "specify ",
65                                           num_dims(), " dimensions"));
66       for (int i = 0; i < num_dims(); ++i) {
67         ksize_.push_back(ksize_int[i]);
68         stride_.push_back(stride_int[i]);
69       }
70     }
71     Padding padding;
72     OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding));
73     OP_REQUIRES(ctx, padding != EXPLICIT,
74                 errors::Unimplemented(
75                     "XLA does not support pooling ops with explicit padding."));
76     padding_ = (padding == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
77 
78     OP_REQUIRES_OK(
79         ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_));
80   }
81 
num_dims() const82   int num_dims() const { return num_spatial_dims_ + 2; }
83 
84  protected:
GetKernelSize(XlaOpKernelContext * ctx)85   StatusOr<std::vector<int64_t>> GetKernelSize(XlaOpKernelContext* ctx) {
86     if (ctx->num_inputs() == 1) {
87       return ksize_;
88     }
89     const TensorShape ksize_shape = ctx->InputShape(1);
90     // Validate input sizes.
91     if (!TensorShapeUtils::IsVector(ksize_shape)) {
92       return errors::InvalidArgument("ksize must be a vector, not shape ",
93                                      ksize_shape.DebugString());
94     }
95     if (ksize_shape.num_elements() != num_dims()) {
96       return errors::InvalidArgument(
97           "Sliding window ksize field must "
98           "specify ",
99           num_dims(), " dimensions");
100     }
101     std::vector<int64_t> ksize;
102     auto status = ctx->ConstantInputAsIntVector(1, &ksize);
103     if (!status.ok()) {
104       return status;
105     }
106     return ksize;
107   }
108 
GetStride(XlaOpKernelContext * ctx)109   StatusOr<std::vector<int64_t>> GetStride(XlaOpKernelContext* ctx) {
110     if (ctx->num_inputs() == 1) {
111       return stride_;
112     }
113     const TensorShape stride_shape = ctx->InputShape(2);
114     // Validate input sizes.
115     if (!TensorShapeUtils::IsVector(stride_shape)) {
116       return errors::InvalidArgument("stride must be a vector, not shape ",
117                                      stride_shape.DebugString());
118     }
119     if (stride_shape.num_elements() != num_dims()) {
120       return errors::InvalidArgument(
121           "Sliding window stride field must "
122           "specify ",
123           num_dims(), " dimensions");
124     }
125     std::vector<int64_t> stride;
126     auto status = ctx->ConstantInputAsIntVector(2, &stride);
127     if (!status.ok()) {
128       return status;
129     }
130     return stride;
131   }
132 
133  protected:
134   const int num_spatial_dims_;
135   std::vector<int64_t> ksize_;
136   std::vector<int64_t> stride_;
137   xla::Padding padding_;
138   TensorFormat data_format_ = FORMAT_NHWC;
139   DataType reduction_type_;
140   xla::PrimitiveType xla_reduction_type_;
141 };
142 
143 // Converts the tensor data format to the one required by the XLA pooling
144 // library.
XlaTensorFormat(tensorflow::TensorFormat data_format,int num_spatial_dims)145 xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format,
146                                   int num_spatial_dims) {
147   int num_dims = num_spatial_dims + 2;
148   int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format);
149   int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format);
150   absl::InlinedVector<int64_t, 4> spatial_dimensions(num_spatial_dims);
151   for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) {
152     spatial_dimensions[spatial_dim] =
153         GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim);
154   }
155   return xla::TensorFormat(/*batch_dimension=*/batch_dimension,
156                            /*feature_dimension=*/feature_dimension,
157                            /*spatial_dimensions=*/spatial_dimensions);
158 }
159 
160 class MaxPoolOp : public PoolingOp {
161  public:
MaxPoolOp(OpKernelConstruction * ctx,int num_spatial_dims)162   MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
163       : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims,
164                   /*reduction_type=*/ctx->input_type(0)) {
165     std::string data_format_str;
166     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
167     OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
168                 errors::InvalidArgument("Invalid data format"));
169     OP_REQUIRES(ctx, data_format_ != FORMAT_NHWC_VECT_W,
170                 errors::Unimplemented(
171                     "XLA does not support the VECT_NHWC_VECT_W data format. "
172                     "Returning unimplemented from MaxPool to keep "
173                     "Tensorflow's intended optimized MaxPool here."));
174   }
175 
Compile(XlaOpKernelContext * ctx)176   void Compile(XlaOpKernelContext* ctx) override {
177     auto ksize_or_error = GetKernelSize(ctx);
178     OP_REQUIRES_OK(ctx, ksize_or_error.status());
179     std::vector<int64_t> ksize = ksize_or_error.ValueOrDie();
180 
181     auto stride_or_error = GetStride(ctx);
182     OP_REQUIRES_OK(ctx, stride_or_error.status());
183     std::vector<int64_t> stride = stride_or_error.ValueOrDie();
184 
185     xla::XlaOp input = ctx->Input(0);
186 
187     StatusOr<xla::Shape> input_shape = ctx->builder()->GetShape(input);
188     OP_REQUIRES_OK(ctx, input_shape.status());
189 
190     // For VECT_C max-pool ops, transpose to plain NCHW, do the max-pool, and
191     // transpose back.  This isn't necessarily the most efficient algorithm, but
192     // it's ok for starters.
193     std::optional<int64_t> vect_width;
194     if (data_format_ == FORMAT_NCHW_VECT_C) {
195       vect_width = input_shape->dimensions().back();
196       input = xla::Collapse(xla::Transpose(input, {0, 1, 4, 2, 3}), {1, 2});
197 
198       input_shape = ctx->builder()->GetShape(input);
199       OP_REQUIRES_OK(ctx, input_shape.status());
200     }
201 
202     OP_REQUIRES(ctx, input_shape->dimensions_size() == num_dims(),
203                 errors::InvalidArgument("Input to ", type_string(),
204                                         " operator must have ", num_dims(),
205                                         " dimensions"));
206     auto pooling = xla::MaxPool(
207         input, ksize, stride, padding_,
208         XlaTensorFormat(
209             data_format_ == FORMAT_NCHW_VECT_C ? FORMAT_NCHW : data_format_,
210             input_shape->dimensions_size() - 2));
211 
212     if (data_format_ == FORMAT_NCHW_VECT_C) {
213       StatusOr<xla::Shape> result_shape = ctx->builder()->GetShape(pooling);
214       OP_REQUIRES_OK(ctx, result_shape.status());
215 
216       int64 num_channels = result_shape->dimensions(1);
217       OP_REQUIRES(
218           ctx, num_channels % *vect_width == 0,
219           errors::FailedPrecondition("Result of NCHW_VECT_C op must have "
220                                      "channels multiple of ",
221                                      *vect_width, ", but was ", num_channels));
222 
223       absl::InlinedVector<int64, 5> new_dims(result_shape->dimensions().begin(),
224                                              result_shape->dimensions().end());
225       new_dims[1] /= *vect_width;
226       new_dims.insert(new_dims.begin() + 2, *vect_width);
227       pooling =
228           xla::Transpose(xla::Reshape(pooling, new_dims), {0, 1, 3, 4, 2});
229     }
230 
231     ctx->SetOutput(0, pooling);
232   }
233 };
234 
// 2D max pooling (spatial dims: height, width).
class MaxPool2DOp : public MaxPoolOp {
 public:
  explicit MaxPool2DOp(OpKernelConstruction* ctx)
      : MaxPoolOp(ctx, /*num_spatial_dims=*/2) {}
};
REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp);
// MaxPoolV2 passes ksize/strides as operands; they must be compile-time
// constants so the pooling window is known while building the XLA graph.
REGISTER_XLA_OP(Name("MaxPoolV2")
                    .CompileTimeConstantInput("ksize")
                    .CompileTimeConstantInput("strides"),
                MaxPool2DOp);
245 
// 3D max pooling (spatial dims: depth, height, width).
class MaxPool3DOp : public MaxPoolOp {
 public:
  explicit MaxPool3DOp(OpKernelConstruction* ctx)
      : MaxPoolOp(ctx, /*num_spatial_dims=*/3) {}
};
REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp);
252 
253 class AvgPoolOp : public PoolingOp {
254  public:
AvgPoolOp(OpKernelConstruction * ctx,int num_spatial_dims)255   AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
256       : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims,
257                   /*reduction_type=*/
258                   XlaHelpers::SumAccumulationType(ctx->input_type(0))) {
259     string data_format_str;
260     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
261     OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
262                 errors::InvalidArgument("Invalid data format"));
263   }
264 
Compile(XlaOpKernelContext * ctx)265   void Compile(XlaOpKernelContext* ctx) override {
266     auto ksize_or_error = GetKernelSize(ctx);
267     OP_REQUIRES_OK(ctx, ksize_or_error.status());
268     std::vector<int64_t> ksize = ksize_or_error.ValueOrDie();
269 
270     auto stride_or_error = GetStride(ctx);
271     OP_REQUIRES_OK(ctx, stride_or_error.status());
272     std::vector<int64_t> stride = stride_or_error.ValueOrDie();
273 
274     const TensorShape input_shape = ctx->InputShape(0);
275     OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
276                 errors::InvalidArgument("Input to ", type_string(),
277                                         " operator must have ", num_dims(),
278                                         " dimensions"));
279 
280     auto xla_data_format =
281         XlaTensorFormat(data_format_, input_shape.dims() - 2);
282     auto spatial_padding = MakeSpatialPadding(
283         input_shape.dim_sizes(), ksize, stride, padding_, xla_data_format);
284 
285     // Convert the input to the reduction type.
286     auto converted_input =
287         ConvertElementType(ctx->Input(0), xla_reduction_type_);
288     auto pooling =
289         xla::AvgPool(converted_input, ksize, stride, spatial_padding,
290                      xla_data_format, padding_ == xla::Padding::kValid);
291     // Convert the pooling result back to the input type before returning it.
292     ctx->SetOutput(0, ConvertElementType(pooling, ctx->input_xla_type(0)));
293   }
294 };
295 
// 2D average pooling.
class AvgPool2DOp : public AvgPoolOp {
 public:
  explicit AvgPool2DOp(OpKernelConstruction* ctx)
      : AvgPoolOp(ctx, /*num_spatial_dims=*/2) {}
};
REGISTER_XLA_OP(Name("AvgPool"), AvgPool2DOp);

// AvgPool3D is lowered through the MLIR-based bridge rather than a
// hand-written kernel in this file.
REGISTER_XLA_OP(Name("AvgPool3D"), MlirXlaOpKernel);
304 
// The operation to compute MaxPool gradients.
// It takes three inputs:
//   - The original input tensor
//   - The original output tensor
//   - Backprop tensor for output
// It produces one output: backprop tensor for input.
class MaxPoolGradOp : public XlaOpKernel {
 public:
  MaxPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    // With 3 inputs, ksize/strides come from attributes; the 5-input variant
    // (MaxPoolGradV2) supplies them as operands, read in Compile().
    if (ctx->num_inputs() == 3) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    }
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
    OP_REQUIRES(ctx, padding_ != EXPLICIT,
                errors::Unimplemented(
                    "XLA does not support maxpoolgrad with explicit padding."));
    // When determinism is enabled, the use of SelectAndScatter causes a generic
    // error to be raised. We raise a more informative error here before
    // SelectAndScatter is used.
    OP_REQUIRES(
        ctx, !tensorflow::OpDeterminismRequired(),
        errors::Unimplemented("GPU MaxPool gradient ops do not yet have a "
                              "deterministic XLA implementation."));
  }

  // Total rank of pooled tensors: the spatial dims plus batch and feature.
  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    // 5-input variant: read ksize/strides from compile-time constant inputs
    // 3 and 4 instead of attributes.
    if (ctx->num_inputs() != 3) {
      OP_REQUIRES(
          ctx, ctx->num_inputs() == 5,
          errors::InvalidArgument("Must supply ksize and stride arguments."));
      const TensorShape ksize_shape = ctx->InputShape(3);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
                  errors::InvalidArgument("ksize must be a vector, not shape ",
                                          ksize_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));

      const TensorShape stride_shape = ctx->InputShape(4);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
                  errors::InvalidArgument("stride must be a vector, not shape ",
                                          stride_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
    }

    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));

    const TensorShape tensor_in_shape = ctx->InputShape(0);
    const TensorShape tensor_out_shape = ctx->InputShape(1);
    const TensorShape out_backprop_shape = ctx->InputShape(2);

    // For maxpooling, tensor_in should have num_dims() dimensions.
    OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_in must be ", num_dims(),
                                        "-dimensional"));
    OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_out must be ", num_dims(),
                                        "-dimensional"));
    // For maxpooling, out_backprop should have num_dims() dimensions.
    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
                errors::InvalidArgument("out_backprop must be ", num_dims(),
                                        "-dimensional"));

    // TODO(phawkins): The XLA version doesn't need tensor_out. Investigate
    // whether this is a good time/space tradeoff.
    auto input = ctx->Input(0);
    auto out_backprop = ctx->Input(2);
    // We ensured padding_ is not EXPLICIT in the constructor.
    xla::Padding xla_padding =
        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;

    // Create a MaxPool operation to check the expected resulting shape, and
    // then throw away the operation because we don't actually need it here.
    TensorShape expected_out_shape;
    auto pooling =
        xla::MaxPool(ctx->Input(0), ksize_, stride_, xla_padding,
                     XlaTensorFormat(data_format_, tensor_in_shape.dims() - 2));
    auto status_or_shape = pooling.builder()->GetShape(pooling);
    OP_REQUIRES_OK(ctx, status_or_shape.status());
    OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(status_or_shape.ValueOrDie(),
                                              &expected_out_shape));
    OP_REQUIRES(ctx, expected_out_shape == out_backprop_shape,
                errors::Unimplemented("The output dimensions do not match the "
                                      "other input values."));

    // SelectAndScatter routes each out_backprop element to the window
    // position holding the maximum (select = GE) and sums contributions when
    // windows overlap (scatter = Add).
    xla::PrimitiveType element_type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
    xla::XlaOp init_value = XlaHelpers::Zero(ctx->builder(), input_type(2));
    auto select = CreateScalarGeComputation(element_type, ctx->builder());
    auto scatter = CreateScalarAddComputation(element_type, ctx->builder());
    xla::XlaOp gradients =
        xla::SelectAndScatter(input, select, ksize_, stride_, xla_padding,
                              out_backprop, init_value, scatter);

    ctx->SetOutput(0, gradients);
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64_t> ksize_;
  std::vector<int64_t> stride_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
};
420 
// 2D max-pool gradient; reads the data_format attribute (the base class
// defaults to NHWC).
class MaxPool2DGradOp : public MaxPoolGradOp {
 public:
  explicit MaxPool2DGradOp(OpKernelConstruction* ctx)
      : MaxPoolGradOp(ctx, /*num_spatial_dims=*/2) {
    string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }
};
REGISTER_XLA_OP(Name("MaxPoolGrad"), MaxPool2DGradOp);
// MaxPoolGradV2 passes ksize/strides as operands; they must be compile-time
// constants.
REGISTER_XLA_OP(Name("MaxPoolGradV2")
                    .CompileTimeConstantInput("ksize")
                    .CompileTimeConstantInput("strides"),
                MaxPool2DGradOp);

// MaxPool3DGrad is lowered through the MLIR-based bridge.
REGISTER_XLA_OP(Name("MaxPool3DGrad"), MlirXlaOpKernel);
438 
439 // Average-pooling gradient
440 class AvgPoolGradOp : public XlaOpKernel {
441  public:
AvgPoolGradOp(OpKernelConstruction * ctx,int num_spatial_dims)442   AvgPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
443       : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
444     OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
445     OP_REQUIRES(ctx, ksize_.size() == num_dims(),
446                 errors::InvalidArgument("Sliding window ksize field must "
447                                         "specify ",
448                                         num_dims(), " dimensions"));
449     OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
450     OP_REQUIRES(ctx, stride_.size() == num_dims(),
451                 errors::InvalidArgument("Sliding window strides field must "
452                                         "specify ",
453                                         num_dims(), " dimensions"));
454     OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
455     OP_REQUIRES(ctx, padding_ != EXPLICIT,
456                 errors::Unimplemented(
457                     "XLA does not support avgpoolgrad with explicit padding."));
458     OP_REQUIRES(ctx, ksize_[0] == 1 && stride_[0] == 1,
459                 errors::Unimplemented(
460                     "Pooling is not yet supported on the batch dimension."));
461 
462     string data_format;
463     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
464     OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
465                 errors::InvalidArgument("Invalid data format"));
466   }
467 
num_dims() const468   int num_dims() const { return num_spatial_dims_ + 2; }
469 
Compile(XlaOpKernelContext * ctx)470   void Compile(XlaOpKernelContext* ctx) override {
471     TensorShape gradients_shape;
472     OP_REQUIRES_OK(
473         ctx, ctx->ConstantInputAsShape(0, &gradients_shape,
474                                        xla::ValueInferenceMode::kUpperBound));
475 
476     const TensorShape out_backprop_shape = ctx->InputShape(1);
477 
478     // For avgpooling, tensor_in_shape should have num_dims() dimensions.
479     OP_REQUIRES(ctx, gradients_shape.dims() == num_dims(),
480                 errors::InvalidArgument("orig_input_shape must be ", num_dims(),
481                                         "-dimensional"));
482 
483     // For avgpooling, out_backprop should have num_dims() dimensions.
484     OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
485                 errors::InvalidArgument("out_backprop must be ", num_dims(),
486                                         "-dimensional"));
487 
488     auto out_backprop = ctx->Input(1);
489     std::vector<int64_t> stride_int64s(stride_.begin(), stride_.end());
490     xla::Padding xla_padding =
491         (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
492     xla::PrimitiveType xla_reduction_type;
493     auto reduction_type = XlaHelpers::SumAccumulationType(ctx->input_type(1));
494     OP_REQUIRES_OK(
495         ctx, DataTypeToPrimitiveType(reduction_type, &xla_reduction_type));
496     auto converted_out_backprop =
497         xla::ConvertElementType(out_backprop, xla_reduction_type);
498     auto xla_data_format =
499         XlaTensorFormat(data_format_, gradients_shape.dims() - 2);
500     auto padding_values =
501         MakeSpatialPadding(gradients_shape.dim_sizes(), ksize_, stride_int64s,
502                            xla_padding, xla_data_format);
503     auto in_backprop =
504         xla::AvgPoolGrad(converted_out_backprop, gradients_shape.dim_sizes(),
505                          ksize_, stride_int64s, padding_values, xla_data_format,
506                          /*counts_include_padding=*/padding_ == VALID);
507     // Convert the pooling result back to the input type before returning it.
508     xla::PrimitiveType xla_out_backprop_type;
509     OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1),
510                                                 &xla_out_backprop_type));
511     ctx->SetOutput(0,
512                    xla::ConvertElementType(in_backprop, xla_out_backprop_type));
513   }
514 
515  protected:
516   const int num_spatial_dims_;
517   std::vector<int64_t> ksize_;
518   std::vector<int32> stride_;
519   Padding padding_;
520   TensorFormat data_format_ = FORMAT_NHWC;
521 };
522 
// 2D average-pool gradient.
class AvgPool2DGradOp : public AvgPoolGradOp {
 public:
  explicit AvgPool2DGradOp(OpKernelConstruction* ctx)
      : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) {}
};
// orig_input_shape must be a compile-time constant: it fixes the static
// shape of the gradient result.
REGISTER_XLA_OP(
    Name("AvgPoolGrad").CompileTimeConstantInput("orig_input_shape"),
    AvgPool2DGradOp);
531 
// 3D average-pool gradient.
class AvgPool3DGradOp : public AvgPoolGradOp {
 public:
  explicit AvgPool3DGradOp(OpKernelConstruction* ctx)
      : AvgPoolGradOp(ctx, /*num_spatial_dims=*/3) {}
};
// orig_input_shape must be a compile-time constant: it fixes the static
// shape of the gradient result.
REGISTER_XLA_OP(
    Name("AvgPool3DGrad").CompileTimeConstantInput("orig_input_shape"),
    AvgPool3DGradOp);
540 
// Second-order gradient of MaxPool, implemented via a 16-bit packing trick
// (see the long comment inside Compile()). Registered only for DT_FLOAT,
// since the bit manipulation assumes 32-bit IEEE floats.
class MaxPoolGradGradOp : public XlaOpKernel {
 public:
  MaxPoolGradGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    // With 3 inputs, ksize/strides come from attributes; the 5-input variant
    // (MaxPoolGradGradV2) supplies them as operands, read in Compile().
    if (ctx->num_inputs() == 3) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    }
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
    OP_REQUIRES(
        ctx, padding_ != EXPLICIT,
        errors::Unimplemented(
            "XLA does not support maxpoolgradgrad with explicit padding."));
  }

  // Total rank of pooled tensors: the spatial dims plus batch and feature.
  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    // 5-input variant: read ksize/strides from compile-time constant inputs
    // 3 and 4 instead of attributes.
    if (ctx->num_inputs() != 3) {
      OP_REQUIRES(
          ctx, ctx->num_inputs() == 5,
          errors::InvalidArgument("Must supply ksize and stride arguments."));
      const TensorShape ksize_shape = ctx->InputShape(3);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
                  errors::InvalidArgument("ksize must be a vector, not shape ",
                                          ksize_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));

      const TensorShape stride_shape = ctx->InputShape(4);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
                  errors::InvalidArgument("stride must be a vector, not shape ",
                                          stride_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
    }

    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));

    const TensorShape tensor_in_shape = ctx->InputShape(0);
    const TensorShape tensor_out_shape = ctx->InputShape(1);
    const TensorShape out_backprop_shape = ctx->InputShape(2);

    // For maxpooling, tensor_in should have num_dims() dimensions.
    OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_in must be ", num_dims(),
                                        "-dimensional"));
    OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_out must be ", num_dims(),
                                        "-dimensional"));
    // For maxpooling, out_backprop should have num_dims() dimensions.
    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
                errors::InvalidArgument("out_backprop must be ", num_dims(),
                                        "-dimensional"));

    // What we want to compute:
    // Given y = MaxPool(x), and xs_grad = MaxPoolGrad(x, y, ys_grad)
    // MaxPoolGradGrad computes {ys_grad}_grad given x, y, and {xs_grad}_grad.
    //
    // In the regular TF op, this amounts to selecting for each window the
    // incoming backprop value from xs_grad_grad that corresponds to the maximal
    // value in the corresponding window of x.
    //
    // TODO(b/73062247): What we really want is a ReduceWindow with different
    // arrays for index selection vs return value selection--a select-to-gather.
    //
    // Here, we implement a bitwise hack: we use the hi 16 bits of input for
    // separate max pooling alongside each of the hi and lo 16 bits of
    // out_backprop packed into 16 lo bits, which we then glue back together at
    // the end to get a full 32 bits of gradient.
    //
    // This could select the wrong backprop value for two x values that are
    // equally maximal up to the first 16 bits, in which case we are taking the
    // latter.
    //
    // Note that in principle we could use 32 separate maxpools to recover each
    // of 32 bits of the gradient while preserving 31 bits of input for the max
    // pooling criteria; here, we just truncate to the first 16 bits of input.

    auto input = ctx->Input(0);
    auto out_backprop = ctx->Input(2);

    auto b = ctx->builder();

    auto sixteen = xla::ConstantR0<uint32>(b, 16);
    // in (f32) -> round to 7 mantissa bits (bf16)-> 16-high-bit u32.
    //
    // NOTE: Use a ReducePrecision operation instead of a cast to BF16 and back
    // to F32 since the XLA compiler may ignore narrowing casts to floating
    // point types if the debug option xla_allow_excess_precision is set.
    auto in_hi = xla::BitcastConvertType(
        xla::ReducePrecision(input, /*exponent_bits=*/8, /*mantissa_bits=*/7),
        xla::U32);
    auto bp_int = xla::BitcastConvertType(out_backprop, xla::U32);
    auto bp_hi = xla::ShiftRightLogical(bp_int, sixteen);
    auto bp_lo =
        xla::ShiftRightLogical(xla::ShiftLeft(bp_int, sixteen), sixteen);
    auto in_hi_bp_hi = xla::Add(in_hi, bp_hi);  // Want an unsigned add.
    auto in_hi_bp_lo = xla::Add(in_hi, bp_lo);  // Want an unsigned add.

    auto init_value = xla::MinValue(b, xla::F32);
    // We will reduce by taking the maximal value up to 16 bits (ignoring the lo
    // 16 bits of packed-in hi/lo backprop value).
    auto rb = b->CreateSubBuilder("GreaterOrEqOf_ByFirst16Bits");
    {
      // F32 parameters to satisfy lowering type restriction for reduce opcode.
      const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {});
      auto lhs = xla::Parameter(rb.get(), 0, scalar, "lhs");
      auto rhs = xla::Parameter(rb.get(), 1, scalar, "rhs");
      auto sixteen = xla::ConstantR0<int32>(rb.get(), 16);
      // Mask off the low 16 bits (the packed backprop payload) so only the
      // truncated input value participates in the comparison.
      auto lhs_criteria =
          xla::ShiftLeft(xla::ShiftRightLogical(
                             xla::BitcastConvertType(lhs, xla::S32), sixteen),
                         sixteen);
      auto rhs_criteria =
          xla::ShiftLeft(xla::ShiftRightLogical(
                             xla::BitcastConvertType(rhs, xla::S32), sixteen),
                         sixteen);
      // Must use a F32 comparison, because S32 would not work for negatives.
      xla::Select(xla::Ge(xla::BitcastConvertType(lhs_criteria, xla::F32),
                          xla::BitcastConvertType(rhs_criteria, xla::F32)),
                  lhs, rhs);
    }
    auto reduce = rb->BuildAndNoteError();
    xla::Padding xla_padding =
        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
    auto pooled_hi =
        xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_hi, xla::F32),
                          init_value, reduce, ksize_, stride_, xla_padding);
    auto pooled_lo =
        xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_lo, xla::F32),
                          init_value, reduce, ksize_, stride_, xla_padding);
    // Reassemble the full 32-bit gradient from the two pooled 16-bit halves.
    auto grads_hi =
        xla::ShiftLeft(xla::BitcastConvertType(pooled_hi, xla::U32), sixteen);
    auto grads_lo = xla::ShiftRightLogical(
        xla::ShiftLeft(xla::BitcastConvertType(pooled_lo, xla::U32), sixteen),
        sixteen);
    auto grads = xla::Add(grads_hi, grads_lo);  // Want an unsigned add.

    xla::PrimitiveType element_type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
    ctx->SetOutput(0, xla::BitcastConvertType(grads, element_type));
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64_t> ksize_;
  std::vector<int64_t> stride_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
};
699 
// 2D second-order max-pool gradient; reads the data_format attribute (the
// base class defaults to NHWC).
class MaxPool2DGradGradOp : public MaxPoolGradGradOp {
 public:
  explicit MaxPool2DGradGradOp(OpKernelConstruction* ctx)
      : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/2) {
    string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }
};
// Restricted to DT_FLOAT: the implementation bit-manipulates 32-bit floats.
REGISTER_XLA_OP(Name("MaxPoolGradGrad").TypeConstraint("T", DT_FLOAT),
                MaxPool2DGradGradOp);
// MaxPoolGradGradV2 passes ksize/strides as operands; they must be
// compile-time constants.
REGISTER_XLA_OP(Name("MaxPoolGradGradV2")
                    .TypeConstraint("T", DT_FLOAT)
                    .CompileTimeConstantInput("ksize")
                    .CompileTimeConstantInput("strides"),
                MaxPool2DGradGradOp);
717 
// 3D second-order max-pool gradient; reads the data_format attribute (the
// base class defaults to NHWC).
class MaxPool3DGradGradOp : public MaxPoolGradGradOp {
 public:
  explicit MaxPool3DGradGradOp(OpKernelConstruction* ctx)
      : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/3) {
    string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }
};
// Restricted to DT_FLOAT: the implementation bit-manipulates 32-bit floats.
REGISTER_XLA_OP(Name("MaxPool3DGradGrad").TypeConstraint("T", DT_FLOAT),
                MaxPool3DGradGradOp);
730 
731 }  // anonymous namespace
732 }  // namespace tensorflow
733