/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA specific pooling ops.
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/pooling.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/pooling_ops_common.h"

namespace tensorflow {
namespace {

38 // Superclass of pooling ops.
39 class PoolingOp : public XlaOpKernel {
40  public:
PoolingOp(OpKernelConstruction * ctx,int num_spatial_dims,const DataType reduction_type)41   PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims,
42             const DataType reduction_type)
43       : XlaOpKernel(ctx),
44         num_spatial_dims_(num_spatial_dims),
45         reduction_type_(reduction_type) {
46     if (ctx->num_inputs() == 1) {
47       std::vector<int32> ksize_int;
48       std::vector<int32> stride_int;
49       OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int));
50       OP_REQUIRES(ctx, ksize_int.size() == num_dims(),
51                   errors::InvalidArgument("Sliding window ksize field must "
52                                           "specify ",
53                                           num_dims(), " dimensions"));
54       OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int));
55       OP_REQUIRES(ctx, stride_int.size() == num_dims(),
56                   errors::InvalidArgument("Sliding window stride field must "
57                                           "specify ",
58                                           num_dims(), " dimensions"));
59       for (int i = 0; i < num_dims(); ++i) {
60         ksize_.push_back(ksize_int[i]);
61         stride_.push_back(stride_int[i]);
62       }
63     }
64     Padding padding;
65     OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding));
66     padding_ = (padding == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
67 
68     OP_REQUIRES_OK(
69         ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_));
70   }
71 
num_dims() const72   int num_dims() const { return num_spatial_dims_ + 2; }
73 
74  protected:
GetKernelSize(XlaOpKernelContext * ctx)75   xla::StatusOr<std::vector<int64>> GetKernelSize(XlaOpKernelContext* ctx) {
76     if (ctx->num_inputs() == 1) {
77       return ksize_;
78     }
79     const TensorShape ksize_shape = ctx->InputShape(1);
80     // Validate input sizes.
81     if (!TensorShapeUtils::IsVector(ksize_shape)) {
82       return errors::InvalidArgument("ksize must be a vector, not shape ",
83                                      ksize_shape.DebugString());
84     }
85     if (ksize_shape.num_elements() != num_dims()) {
86       return errors::InvalidArgument(
87           "Sliding window ksize field must "
88           "specify ",
89           num_dims(), " dimensions");
90     }
91     std::vector<int64> ksize;
92     auto status = ctx->ConstantInputAsIntVector(1, &ksize);
93     if (!status.ok()) {
94       return status;
95     }
96     return ksize;
97   }
98 
GetStride(XlaOpKernelContext * ctx)99   xla::StatusOr<std::vector<int64>> GetStride(XlaOpKernelContext* ctx) {
100     if (ctx->num_inputs() == 1) {
101       return stride_;
102     }
103     const TensorShape stride_shape = ctx->InputShape(2);
104     // Validate input sizes.
105     if (!TensorShapeUtils::IsVector(stride_shape)) {
106       return errors::InvalidArgument("stride must be a vector, not shape ",
107                                      stride_shape.DebugString());
108     }
109     if (stride_shape.num_elements() != num_dims()) {
110       return errors::InvalidArgument(
111           "Sliding window stride field must "
112           "specify ",
113           num_dims(), " dimensions");
114     }
115     std::vector<int64> stride;
116     auto status = ctx->ConstantInputAsIntVector(2, &stride);
117     if (!status.ok()) {
118       return status;
119     }
120     return stride;
121   }
122 
123  protected:
124   const int num_spatial_dims_;
125   std::vector<int64> ksize_;
126   std::vector<int64> stride_;
127   xla::Padding padding_;
128   TensorFormat data_format_ = FORMAT_NHWC;
129   DataType reduction_type_;
130   xla::PrimitiveType xla_reduction_type_;
131 };
132 
133 // Converts the tensor data format to the one required by the XLA pooling
134 // library.
XlaTensorFormat(tensorflow::TensorFormat data_format,int num_spatial_dims)135 xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format,
136                                   int num_spatial_dims) {
137   int num_dims = num_spatial_dims + 2;
138   int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format);
139   int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format);
140   absl::InlinedVector<int64, 4> spatial_dimensions(num_spatial_dims);
141   for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) {
142     spatial_dimensions[spatial_dim] =
143         GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim);
144   }
145   return xla::TensorFormat(/*batch_dimension=*/batch_dimension,
146                            /*feature_dimension=*/feature_dimension,
147                            /*spatial_dimensions=*/spatial_dimensions);
148 }
149 
150 class MaxPoolOp : public PoolingOp {
151  public:
MaxPoolOp(OpKernelConstruction * ctx,int num_spatial_dims)152   MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
153       : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims,
154                   /*reduction_type=*/ctx->input_type(0)) {
155     string data_format_str;
156     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
157     OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
158                 errors::InvalidArgument("Invalid data format"));
159   }
160 
Compile(XlaOpKernelContext * ctx)161   void Compile(XlaOpKernelContext* ctx) override {
162     auto ksize_or_error = GetKernelSize(ctx);
163     OP_REQUIRES_OK(ctx, ksize_or_error.status());
164     std::vector<int64> ksize = ksize_or_error.ValueOrDie();
165 
166     auto stride_or_error = GetStride(ctx);
167     OP_REQUIRES_OK(ctx, stride_or_error.status());
168     std::vector<int64> stride = stride_or_error.ValueOrDie();
169 
170     const TensorShape input_shape = ctx->InputShape(0);
171     OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
172                 errors::InvalidArgument("Input to ", type_string(),
173                                         " operator must have ", num_dims(),
174                                         " dimensions"));
175 
176     auto pooling =
177         xla::MaxPool(ctx->Input(0), ksize, stride, padding_,
178                      XlaTensorFormat(data_format_, input_shape.dims() - 2));
179     ctx->SetOutput(0, pooling);
180   }
181 };
182 
183 class MaxPool2DOp : public MaxPoolOp {
184  public:
MaxPool2DOp(OpKernelConstruction * ctx)185   explicit MaxPool2DOp(OpKernelConstruction* ctx)
186       : MaxPoolOp(ctx, /*num_spatial_dims=*/2) {
187   }
188 };
189 REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp);
190 REGISTER_XLA_OP(Name("MaxPoolV2")
191                     .CompileTimeConstantInput("ksize")
192                     .CompileTimeConstantInput("strides"),
193                 MaxPool2DOp);
194 
195 class MaxPool3DOp : public MaxPoolOp {
196  public:
MaxPool3DOp(OpKernelConstruction * ctx)197   explicit MaxPool3DOp(OpKernelConstruction* ctx)
198       : MaxPoolOp(ctx, /*num_spatial_dims=*/3) {}
199 };
200 REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp);
201 
202 class AvgPoolOp : public PoolingOp {
203  public:
AvgPoolOp(OpKernelConstruction * ctx,int num_spatial_dims)204   AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
205       : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims,
206                   /*reduction_type=*/
207                   XlaHelpers::SumAccumulationType(ctx->input_type(0))) {
208     string data_format_str;
209     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
210     OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
211                 errors::InvalidArgument("Invalid data format"));
212   }
213 
Compile(XlaOpKernelContext * ctx)214   void Compile(XlaOpKernelContext* ctx) override {
215     auto ksize_or_error = GetKernelSize(ctx);
216     OP_REQUIRES_OK(ctx, ksize_or_error.status());
217     std::vector<int64> ksize = ksize_or_error.ValueOrDie();
218 
219     auto stride_or_error = GetStride(ctx);
220     OP_REQUIRES_OK(ctx, stride_or_error.status());
221     std::vector<int64> stride = stride_or_error.ValueOrDie();
222 
223     const TensorShape input_shape = ctx->InputShape(0);
224     OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
225                 errors::InvalidArgument("Input to ", type_string(),
226                                         " operator must have ", num_dims(),
227                                         " dimensions"));
228 
229     auto xla_data_format =
230         XlaTensorFormat(data_format_, input_shape.dims() - 2);
231     auto spatial_padding = MakeSpatialPadding(
232         input_shape.dim_sizes(), ksize, stride, padding_, xla_data_format);
233 
234     // Convert the input to the reduction type.
235     auto converted_input =
236         ConvertElementType(ctx->Input(0), xla_reduction_type_);
237     auto pooling =
238         xla::AvgPool(converted_input, ksize, stride, spatial_padding,
239                      xla_data_format, padding_ == xla::Padding::kValid);
240     // Convert the pooling result back to the input type before returning it.
241     ctx->SetOutput(0, ConvertElementType(pooling, ctx->input_xla_type(0)));
242   }
243 };
244 
245 class AvgPool2DOp : public AvgPoolOp {
246  public:
AvgPool2DOp(OpKernelConstruction * ctx)247   explicit AvgPool2DOp(OpKernelConstruction* ctx)
248       : AvgPoolOp(ctx, /*num_spatial_dims=*/2) {
249   }
250 };
251 REGISTER_XLA_OP(Name("AvgPool"), AvgPool2DOp);
252 
253 class AvgPool3DOp : public AvgPoolOp {
254  public:
AvgPool3DOp(OpKernelConstruction * ctx)255   explicit AvgPool3DOp(OpKernelConstruction* ctx)
256       : AvgPoolOp(ctx, /*num_spatial_dims=*/3) {}
257 };
258 REGISTER_XLA_OP(Name("AvgPool3D"), AvgPool3DOp);
259 
260 // The operation to compute MaxPool gradients.
261 // It takes three inputs:
262 //   - The original input tensor
263 //   - The original output tensor
264 //   - Backprop tensor for output
265 // It produces one output: backprop tensor for input.
266 class MaxPoolGradOp : public XlaOpKernel {
267  public:
MaxPoolGradOp(OpKernelConstruction * ctx,int num_spatial_dims)268   MaxPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
269       : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
270     if (ctx->num_inputs() == 3) {
271       OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
272       OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
273     }
274     OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
275   }
276 
num_dims() const277   int num_dims() const { return num_spatial_dims_ + 2; }
278 
Compile(XlaOpKernelContext * ctx)279   void Compile(XlaOpKernelContext* ctx) override {
280     if (ctx->num_inputs() != 3) {
281       OP_REQUIRES(
282           ctx, ctx->num_inputs() == 5,
283           errors::InvalidArgument("Must supply ksize and stride arguments."));
284       const TensorShape ksize_shape = ctx->InputShape(3);
285       // Validate input sizes.
286       OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
287                   errors::InvalidArgument("ksize must be a vector, not shape ",
288                                           ksize_shape.DebugString()));
289       OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));
290 
291       const TensorShape stride_shape = ctx->InputShape(4);
292       // Validate input sizes.
293       OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
294                   errors::InvalidArgument("stride must be a vector, not shape ",
295                                           stride_shape.DebugString()));
296       OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
297     }
298 
299     OP_REQUIRES(ctx, ksize_.size() == num_dims(),
300                 errors::InvalidArgument("Sliding window ksize field must "
301                                         "specify ",
302                                         num_dims(), " dimensions"));
303     OP_REQUIRES(ctx, stride_.size() == num_dims(),
304                 errors::InvalidArgument("Sliding window strides field must "
305                                         "specify ",
306                                         num_dims(), " dimensions"));
307 
308     const TensorShape tensor_in_shape = ctx->InputShape(0);
309     const TensorShape tensor_out_shape = ctx->InputShape(1);
310     const TensorShape out_backprop_shape = ctx->InputShape(2);
311 
312     // For maxpooling, tensor_in should have num_dims() dimensions.
313     OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(),
314                 errors::InvalidArgument("tensor_in must be ", num_dims(),
315                                         "-dimensional"));
316     OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(),
317                 errors::InvalidArgument("tensor_out must be ", num_dims(),
318                                         "-dimensional"));
319     // For maxpooling, out_backprop should have num_dims() dimensions.
320     OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
321                 errors::InvalidArgument("out_backprop must be ", num_dims(),
322                                         "-dimensional"));
323 
324     // TODO(phawkins): The XLA version doesn't need tensor_out. Investigate
325     // whether this is a good time/space tradeoff.
326     auto input = ctx->Input(0);
327     auto out_backprop = ctx->Input(2);
328 
329     xla::Padding xla_padding =
330         (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
331 
332     xla::PrimitiveType element_type;
333     OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
334     xla::XlaOp init_value = XlaHelpers::Zero(ctx->builder(), input_type(2));
335     auto select = CreateScalarGeComputation(element_type, ctx->builder());
336     auto scatter = CreateScalarAddComputation(element_type, ctx->builder());
337     xla::XlaOp gradients =
338         xla::SelectAndScatter(input, select, ksize_, stride_, xla_padding,
339                               out_backprop, init_value, scatter);
340 
341     ctx->SetOutput(0, gradients);
342   }
343 
344  protected:
345   const int num_spatial_dims_;
346   std::vector<int64> ksize_;
347   std::vector<int64> stride_;
348   Padding padding_;
349   TensorFormat data_format_ = FORMAT_NHWC;
350 };
351 
352 class MaxPool2DGradOp : public MaxPoolGradOp {
353  public:
MaxPool2DGradOp(OpKernelConstruction * ctx)354   explicit MaxPool2DGradOp(OpKernelConstruction* ctx)
355       : MaxPoolGradOp(ctx, /*num_spatial_dims=*/2) {
356     string data_format;
357     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
358     OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
359                 errors::InvalidArgument("Invalid data format"));
360   }
361 };
362 REGISTER_XLA_OP(Name("MaxPoolGrad"), MaxPool2DGradOp);
363 REGISTER_XLA_OP(Name("MaxPoolGradV2")
364                     .CompileTimeConstantInput("ksize")
365                     .CompileTimeConstantInput("strides"),
366                 MaxPool2DGradOp);
367 
368 class MaxPool3DGradOp : public MaxPoolGradOp {
369  public:
MaxPool3DGradOp(OpKernelConstruction * ctx)370   explicit MaxPool3DGradOp(OpKernelConstruction* ctx)
371       : MaxPoolGradOp(ctx, /*num_spatial_dims=*/3) {}
372 };
373 REGISTER_XLA_OP(Name("MaxPool3DGrad"), MaxPool3DGradOp);
374 
375 // Average-pooling gradient
376 class AvgPoolGradOp : public XlaOpKernel {
377  public:
AvgPoolGradOp(OpKernelConstruction * ctx,int num_spatial_dims)378   AvgPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
379       : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
380     OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
381     OP_REQUIRES(ctx, ksize_.size() == num_dims(),
382                 errors::InvalidArgument("Sliding window ksize field must "
383                                         "specify ",
384                                         num_dims(), " dimensions"));
385     OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
386     OP_REQUIRES(ctx, stride_.size() == num_dims(),
387                 errors::InvalidArgument("Sliding window strides field must "
388                                         "specify ",
389                                         num_dims(), " dimensions"));
390     OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
391     OP_REQUIRES(ctx, ksize_[0] == 1 && stride_[0] == 1,
392                 errors::Unimplemented(
393                     "Pooling is not yet supported on the batch dimension."));
394 
395     string data_format;
396     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
397     OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
398                 errors::InvalidArgument("Invalid data format"));
399   }
400 
num_dims() const401   int num_dims() const { return num_spatial_dims_ + 2; }
402 
Compile(XlaOpKernelContext * ctx)403   void Compile(XlaOpKernelContext* ctx) override {
404     TensorShape gradients_shape;
405     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &gradients_shape));
406 
407     const TensorShape out_backprop_shape = ctx->InputShape(1);
408 
409     // For avgpooling, tensor_in_shape should have num_dims() dimensions.
410     OP_REQUIRES(ctx, gradients_shape.dims() == num_dims(),
411                 errors::InvalidArgument("orig_input_shape must be ", num_dims(),
412                                         "-dimensional"));
413 
414     // For avgpooling, out_backprop should have num_dims() dimensions.
415     OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
416                 errors::InvalidArgument("out_backprop must be ", num_dims(),
417                                         "-dimensional"));
418 
419     auto out_backprop = ctx->Input(1);
420     std::vector<int64> stride_int64s(stride_.begin(), stride_.end());
421     xla::Padding xla_padding =
422         (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
423     xla::PrimitiveType xla_reduction_type;
424     auto reduction_type = XlaHelpers::SumAccumulationType(ctx->input_type(1));
425     OP_REQUIRES_OK(
426         ctx, DataTypeToPrimitiveType(reduction_type, &xla_reduction_type));
427     auto converted_out_backprop =
428         xla::ConvertElementType(out_backprop, xla_reduction_type);
429     auto xla_data_format =
430         XlaTensorFormat(data_format_, gradients_shape.dims() - 2);
431     auto padding_values =
432         MakeSpatialPadding(gradients_shape.dim_sizes(), ksize_, stride_int64s,
433                            xla_padding, xla_data_format);
434     auto in_backprop =
435         xla::AvgPoolGrad(converted_out_backprop, gradients_shape.dim_sizes(),
436                          ksize_, stride_int64s, padding_values, xla_data_format,
437                          /*counts_include_padding=*/padding_ == VALID);
438     // Convert the pooling result back to the input type before returning it.
439     xla::PrimitiveType xla_out_backprop_type;
440     OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1),
441                                                 &xla_out_backprop_type));
442     ctx->SetOutput(0,
443                    xla::ConvertElementType(in_backprop, xla_out_backprop_type));
444   }
445 
446  protected:
447   const int num_spatial_dims_;
448   std::vector<int64> ksize_;
449   std::vector<int32> stride_;
450   Padding padding_;
451   TensorFormat data_format_ = FORMAT_NHWC;
452 };
453 
454 class AvgPool2DGradOp : public AvgPoolGradOp {
455  public:
AvgPool2DGradOp(OpKernelConstruction * ctx)456   explicit AvgPool2DGradOp(OpKernelConstruction* ctx)
457       : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) {
458   }
459 };
460 REGISTER_XLA_OP(
461     Name("AvgPoolGrad").CompileTimeConstantInput("orig_input_shape"),
462     AvgPool2DGradOp);
463 
464 class AvgPool3DGradOp : public AvgPoolGradOp {
465  public:
AvgPool3DGradOp(OpKernelConstruction * ctx)466   explicit AvgPool3DGradOp(OpKernelConstruction* ctx)
467       : AvgPoolGradOp(ctx, /*num_spatial_dims=*/3) {}
468 };
469 REGISTER_XLA_OP(
470     Name("AvgPool3DGrad").CompileTimeConstantInput("orig_input_shape"),
471     AvgPool3DGradOp);
472 
// Second-order gradient of MaxPool (see the comment inside Compile for the
// bit-packing trick used to implement it with a single ReduceWindow pair).
// Inputs: original input x, original output y, and the incoming
// grad-of-grad tensor; output: the backprop for ys_grad. F32 only.
class MaxPoolGradGradOp : public XlaOpKernel {
 public:
  MaxPoolGradGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    // With three inputs, ksize/strides are attributes; the V2 variant
    // supplies them as constant inputs read in Compile instead.
    if (ctx->num_inputs() == 3) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    }
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
  }

  // Operand rank: the spatial dimensions plus batch and feature.
  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    if (ctx->num_inputs() != 3) {
      OP_REQUIRES(
          ctx, ctx->num_inputs() == 5,
          errors::InvalidArgument("Must supply ksize and stride arguments."));
      const TensorShape ksize_shape = ctx->InputShape(3);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
                  errors::InvalidArgument("ksize must be a vector, not shape ",
                                          ksize_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));

      const TensorShape stride_shape = ctx->InputShape(4);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
                  errors::InvalidArgument("stride must be a vector, not shape ",
                                          stride_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
    }

    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));

    const TensorShape tensor_in_shape = ctx->InputShape(0);
    const TensorShape tensor_out_shape = ctx->InputShape(1);
    const TensorShape out_backprop_shape = ctx->InputShape(2);

    // For maxpooling, tensor_in should have num_dims() dimensions.
    OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_in must be ", num_dims(),
                                        "-dimensional"));
    OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_out must be ", num_dims(),
                                        "-dimensional"));
    // For maxpooling, out_backprop should have num_dims() dimensions.
    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
                errors::InvalidArgument("out_backprop must be ", num_dims(),
                                        "-dimensional"));

    // What we want to compute:
    // Given y = MaxPool(x), and xs_grad = MaxPoolGrad(x, y, ys_grad)
    // MaxPoolGradGrad computes {ys_grad}_grad given x, y, and {xs_grad}_grad.
    //
    // In the regular TF op, this amounts to selecting for each window the
    // incoming backprop value from xs_grad_grad that corresponds to the maximal
    // value in the corresponding window of x.
    //
    // TODO(b/73062247): What we really want is a ReduceWindow with different
    // arrays for index selection vs return value selection--a select-to-gather.
    //
    // Here, we implement a bitwise hack: we use the hi 16 bits of input for
    // separate max pooling alongside each of the hi and lo 16 bits of
    // out_backprop packed into 16 lo bits, which we then glue back together at
    // the end to get a full 32 bits of gradient.
    //
    // This could select the wrong backprop value for two x values that are
    // equally maximal up to the first 16 bits, in which case we are taking the
    // latter.
    //
    // Note that in principle we could use 32 separate maxpools to recover each
    // of 32 bits of the gradient while preserving 31 bits of input for the max
    // pooling criteria; here, we just truncate to the first 16 bits of input.

    auto input = ctx->Input(0);
    auto out_backprop = ctx->Input(2);

    auto b = ctx->builder();

    auto sixteen = xla::ConstantR0<uint32>(b, 16);
    // in (f32) -> round to bf16 -> f32 for correct bitwidth -> 16-high-bit u32
    auto in_hi = xla::BitcastConvertType(
        xla::ConvertElementType(xla::ConvertElementType(input, xla::BF16),
                                xla::F32),
        xla::U32);
    // Split out_backprop's 32 bits into its hi and lo 16-bit halves, each
    // carried in the low bits of a u32.
    auto bp_int = xla::BitcastConvertType(out_backprop, xla::U32);
    auto bp_hi = xla::ShiftRightLogical(bp_int, sixteen);
    auto bp_lo =
        xla::ShiftRightLogical(xla::ShiftLeft(bp_int, sixteen), sixteen);
    // Pack the truncated input (hi 16 bits, zero lo bits after the bf16
    // round-trip) with each backprop half; the adds cannot carry into the hi
    // half because in_hi's lo 16 bits are zero.
    auto in_hi_bp_hi = xla::Add(in_hi, bp_hi);  // Want an unsigned add.
    auto in_hi_bp_lo = xla::Add(in_hi, bp_lo);  // Want an unsigned add.

    auto init_value = xla::MinValue(b, xla::F32);
    // We will reduce by taking the maximal value up to 16 bits (ignoring the lo
    // 16 bits of packed-in hi/lo backprop value).
    auto rb = b->CreateSubBuilder("GreaterOrEqOf_ByFirst16Bits");
    {
      // F32 parameters to satisfy lowering type restriction for reduce opcode.
      const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {});
      auto lhs = xla::Parameter(rb.get(), 0, scalar, "lhs");
      auto rhs = xla::Parameter(rb.get(), 1, scalar, "rhs");
      auto sixteen = xla::ConstantR0<int32>(rb.get(), 16);
      // Mask off the lo 16 bits (the packed backprop halves) so only the
      // truncated input participates in the comparison.
      auto lhs_criteria =
          xla::ShiftLeft(xla::ShiftRightLogical(
                             xla::BitcastConvertType(lhs, xla::S32), sixteen),
                         sixteen);
      auto rhs_criteria =
          xla::ShiftLeft(xla::ShiftRightLogical(
                             xla::BitcastConvertType(rhs, xla::S32), sixteen),
                         sixteen);
      // Must use a F32 comparison, because S32 would not work for negatives.
      xla::Select(xla::Ge(xla::BitcastConvertType(lhs_criteria, xla::F32),
                          xla::BitcastConvertType(rhs_criteria, xla::F32)),
                  lhs, rhs);
    }
    auto reduce = rb->BuildAndNoteError();
    xla::Padding xla_padding =
        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
    // One max-pool per packed half, both keyed on the same truncated input.
    auto pooled_hi =
        xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_hi, xla::F32),
                          init_value, reduce, ksize_, stride_, xla_padding);
    auto pooled_lo =
        xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_lo, xla::F32),
                          init_value, reduce, ksize_, stride_, xla_padding);
    // Glue the selected hi and lo 16-bit gradient halves back together.
    auto grads_hi =
        xla::ShiftLeft(xla::BitcastConvertType(pooled_hi, xla::U32), sixteen);
    auto grads_lo = xla::ShiftRightLogical(
        xla::ShiftLeft(xla::BitcastConvertType(pooled_lo, xla::U32), sixteen),
        sixteen);
    auto grads = xla::Add(grads_hi, grads_lo);  // Want an unsigned add.

    xla::PrimitiveType element_type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
    ctx->SetOutput(0, xla::BitcastConvertType(grads, element_type));
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64> ksize_;
  std::vector<int64> stride_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
};

625 class MaxPool2DGradGradOp : public MaxPoolGradGradOp {
626  public:
MaxPool2DGradGradOp(OpKernelConstruction * ctx)627   explicit MaxPool2DGradGradOp(OpKernelConstruction* ctx)
628       : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/2) {
629     string data_format;
630     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
631     OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
632                 errors::InvalidArgument("Invalid data format"));
633   }
634 };
635 REGISTER_XLA_OP(Name("MaxPoolGradGrad").TypeConstraint("T", DT_FLOAT),
636                 MaxPool2DGradGradOp);
637 REGISTER_XLA_OP(Name("MaxPoolGradGradV2")
638                     .TypeConstraint("T", DT_FLOAT)
639                     .CompileTimeConstantInput("ksize")
640                     .CompileTimeConstantInput("strides"),
641                 MaxPool2DGradGradOp);
642 
643 class MaxPool3DGradGradOp : public MaxPoolGradGradOp {
644  public:
MaxPool3DGradGradOp(OpKernelConstruction * ctx)645   explicit MaxPool3DGradGradOp(OpKernelConstruction* ctx)
646       : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/3) {
647     string data_format;
648     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
649     OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
650                 errors::InvalidArgument("Invalid data format"));
651   }
652 };
653 REGISTER_XLA_OP(Name("MaxPool3DGradGrad").TypeConstraint("T", DT_FLOAT),
654                 MaxPool3DGradGradOp);
655 
}  // anonymous namespace
}  // namespace tensorflow