/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA specific pooling ops.

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/pooling.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {
namespace {

40 // Superclass of pooling ops.
41 class PoolingOp : public XlaOpKernel {
42  public:
PoolingOp(OpKernelConstruction * ctx,int num_spatial_dims,const DataType reduction_type)43   PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims,
44             const DataType reduction_type)
45       : XlaOpKernel(ctx),
46         num_spatial_dims_(num_spatial_dims),
47         reduction_type_(reduction_type) {
48     if (ctx->num_inputs() == 1) {
49       std::vector<int32> ksize_int;
50       std::vector<int32> stride_int;
51       OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int));
52       OP_REQUIRES(ctx, ksize_int.size() == num_dims(),
53                   errors::InvalidArgument("Sliding window ksize field must "
54                                           "specify ",
55                                           num_dims(), " dimensions"));
56       OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int));
57       OP_REQUIRES(ctx, stride_int.size() == num_dims(),
58                   errors::InvalidArgument("Sliding window stride field must "
59                                           "specify ",
60                                           num_dims(), " dimensions"));
61       for (int i = 0; i < num_dims(); ++i) {
62         ksize_.push_back(ksize_int[i]);
63         stride_.push_back(stride_int[i]);
64       }
65     }
66     Padding padding;
67     OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding));
68     padding_ = (padding == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
69 
70     OP_REQUIRES_OK(
71         ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_));
72   }
73 
num_dims() const74   int num_dims() const { return num_spatial_dims_ + 2; }
75 
76  protected:
GetKernelSize(XlaOpKernelContext * ctx)77   xla::StatusOr<std::vector<int64>> GetKernelSize(XlaOpKernelContext* ctx) {
78     if (ctx->num_inputs() == 1) {
79       return ksize_;
80     }
81     const TensorShape ksize_shape = ctx->InputShape(1);
82     // Validate input sizes.
83     if (!TensorShapeUtils::IsVector(ksize_shape)) {
84       return errors::InvalidArgument("ksize must be a vector, not shape ",
85                                      ksize_shape.DebugString());
86     }
87     if (ksize_shape.num_elements() != num_dims()) {
88       return errors::InvalidArgument(
89           "Sliding window ksize field must "
90           "specify ",
91           num_dims(), " dimensions");
92     }
93     std::vector<int64> ksize;
94     auto status = ctx->ConstantInputAsIntVector(1, &ksize);
95     if (!status.ok()) {
96       return status;
97     }
98     return ksize;
99   }
100 
GetStride(XlaOpKernelContext * ctx)101   xla::StatusOr<std::vector<int64>> GetStride(XlaOpKernelContext* ctx) {
102     if (ctx->num_inputs() == 1) {
103       return stride_;
104     }
105     const TensorShape stride_shape = ctx->InputShape(2);
106     // Validate input sizes.
107     if (!TensorShapeUtils::IsVector(stride_shape)) {
108       return errors::InvalidArgument("stride must be a vector, not shape ",
109                                      stride_shape.DebugString());
110     }
111     if (stride_shape.num_elements() != num_dims()) {
112       return errors::InvalidArgument(
113           "Sliding window stride field must "
114           "specify ",
115           num_dims(), " dimensions");
116     }
117     std::vector<int64> stride;
118     auto status = ctx->ConstantInputAsIntVector(2, &stride);
119     if (!status.ok()) {
120       return status;
121     }
122     return stride;
123   }
124 
125  protected:
126   const int num_spatial_dims_;
127   std::vector<int64> ksize_;
128   std::vector<int64> stride_;
129   xla::Padding padding_;
130   TensorFormat data_format_ = FORMAT_NHWC;
131   DataType reduction_type_;
132   xla::PrimitiveType xla_reduction_type_;
133 };
134 
135 // Converts the tensor data format to the one required by the XLA pooling
136 // library.
XlaTensorFormat(tensorflow::TensorFormat data_format,int num_spatial_dims)137 xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format,
138                                   int num_spatial_dims) {
139   int num_dims = num_spatial_dims + 2;
140   int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format);
141   int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format);
142   absl::InlinedVector<int64, 4> spatial_dimensions(num_spatial_dims);
143   for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) {
144     spatial_dimensions[spatial_dim] =
145         GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim);
146   }
147   return xla::TensorFormat(/*batch_dimension=*/batch_dimension,
148                            /*feature_dimension=*/feature_dimension,
149                            /*spatial_dimensions=*/spatial_dimensions);
150 }
151 
152 class MaxPoolOp : public PoolingOp {
153  public:
MaxPoolOp(OpKernelConstruction * ctx,int num_spatial_dims)154   MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
155       : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims,
156                   /*reduction_type=*/ctx->input_type(0)) {
157     string data_format_str;
158     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
159     OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
160                 errors::InvalidArgument("Invalid data format"));
161     OP_REQUIRES(
162         ctx,
163         data_format_ != FORMAT_NCHW_VECT_C &&
164             data_format_ != FORMAT_NHWC_VECT_W,
165         errors::Unimplemented("XLA does not support the VECT_* data formats. "
166                               "Returning unimplemented from MaxPool to keep "
167                               "Tensorflow's intended optimized MaxPool here."));
168   }
169 
Compile(XlaOpKernelContext * ctx)170   void Compile(XlaOpKernelContext* ctx) override {
171     auto ksize_or_error = GetKernelSize(ctx);
172     OP_REQUIRES_OK(ctx, ksize_or_error.status());
173     std::vector<int64> ksize = ksize_or_error.ValueOrDie();
174 
175     auto stride_or_error = GetStride(ctx);
176     OP_REQUIRES_OK(ctx, stride_or_error.status());
177     std::vector<int64> stride = stride_or_error.ValueOrDie();
178 
179     const TensorShape input_shape = ctx->InputShape(0);
180     OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
181                 errors::InvalidArgument("Input to ", type_string(),
182                                         " operator must have ", num_dims(),
183                                         " dimensions"));
184 
185     auto pooling =
186         xla::MaxPool(ctx->Input(0), ksize, stride, padding_,
187                      XlaTensorFormat(data_format_, input_shape.dims() - 2));
188     ctx->SetOutput(0, pooling);
189   }
190 };
191 
192 class MaxPool2DOp : public MaxPoolOp {
193  public:
MaxPool2DOp(OpKernelConstruction * ctx)194   explicit MaxPool2DOp(OpKernelConstruction* ctx)
195       : MaxPoolOp(ctx, /*num_spatial_dims=*/2) {}
196 };
197 REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp);
198 REGISTER_XLA_OP(Name("MaxPoolV2")
199                     .CompileTimeConstantInput("ksize")
200                     .CompileTimeConstantInput("strides"),
201                 MaxPool2DOp);
202 
203 class MaxPool3DOp : public MaxPoolOp {
204  public:
MaxPool3DOp(OpKernelConstruction * ctx)205   explicit MaxPool3DOp(OpKernelConstruction* ctx)
206       : MaxPoolOp(ctx, /*num_spatial_dims=*/3) {}
207 };
208 REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp);
209 
210 class AvgPoolOp : public PoolingOp {
211  public:
AvgPoolOp(OpKernelConstruction * ctx,int num_spatial_dims)212   AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
213       : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims,
214                   /*reduction_type=*/
215                   XlaHelpers::SumAccumulationType(ctx->input_type(0))) {
216     string data_format_str;
217     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
218     OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
219                 errors::InvalidArgument("Invalid data format"));
220   }
221 
Compile(XlaOpKernelContext * ctx)222   void Compile(XlaOpKernelContext* ctx) override {
223     auto ksize_or_error = GetKernelSize(ctx);
224     OP_REQUIRES_OK(ctx, ksize_or_error.status());
225     std::vector<int64> ksize = ksize_or_error.ValueOrDie();
226 
227     auto stride_or_error = GetStride(ctx);
228     OP_REQUIRES_OK(ctx, stride_or_error.status());
229     std::vector<int64> stride = stride_or_error.ValueOrDie();
230 
231     const TensorShape input_shape = ctx->InputShape(0);
232     OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
233                 errors::InvalidArgument("Input to ", type_string(),
234                                         " operator must have ", num_dims(),
235                                         " dimensions"));
236 
237     auto xla_data_format =
238         XlaTensorFormat(data_format_, input_shape.dims() - 2);
239     auto spatial_padding = MakeSpatialPadding(
240         input_shape.dim_sizes(), ksize, stride, padding_, xla_data_format);
241 
242     // Convert the input to the reduction type.
243     auto converted_input =
244         ConvertElementType(ctx->Input(0), xla_reduction_type_);
245     auto pooling =
246         xla::AvgPool(converted_input, ksize, stride, spatial_padding,
247                      xla_data_format, padding_ == xla::Padding::kValid);
248     // Convert the pooling result back to the input type before returning it.
249     ctx->SetOutput(0, ConvertElementType(pooling, ctx->input_xla_type(0)));
250   }
251 };
252 
253 class AvgPool2DOp : public AvgPoolOp {
254  public:
AvgPool2DOp(OpKernelConstruction * ctx)255   explicit AvgPool2DOp(OpKernelConstruction* ctx)
256       : AvgPoolOp(ctx, /*num_spatial_dims=*/2) {}
257 };
258 REGISTER_XLA_OP(Name("AvgPool"), AvgPool2DOp);
259 
260 class AvgPool3DOp : public AvgPoolOp {
261  public:
AvgPool3DOp(OpKernelConstruction * ctx)262   explicit AvgPool3DOp(OpKernelConstruction* ctx)
263       : AvgPoolOp(ctx, /*num_spatial_dims=*/3) {}
264 };
265 REGISTER_XLA_OP(Name("AvgPool3D"), AvgPool3DOp);
266 
267 // The operation to compute MaxPool gradients.
268 // It takes three inputs:
269 //   - The original input tensor
270 //   - The original output tensor
271 //   - Backprop tensor for output
272 // It produces one output: backprop tensor for input.
273 class MaxPoolGradOp : public XlaOpKernel {
274  public:
MaxPoolGradOp(OpKernelConstruction * ctx,int num_spatial_dims)275   MaxPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
276       : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
277     if (ctx->num_inputs() == 3) {
278       OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
279       OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
280     }
281     OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
282   }
283 
num_dims() const284   int num_dims() const { return num_spatial_dims_ + 2; }
285 
Compile(XlaOpKernelContext * ctx)286   void Compile(XlaOpKernelContext* ctx) override {
287     if (ctx->num_inputs() != 3) {
288       OP_REQUIRES(
289           ctx, ctx->num_inputs() == 5,
290           errors::InvalidArgument("Must supply ksize and stride arguments."));
291       const TensorShape ksize_shape = ctx->InputShape(3);
292       // Validate input sizes.
293       OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
294                   errors::InvalidArgument("ksize must be a vector, not shape ",
295                                           ksize_shape.DebugString()));
296       OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));
297 
298       const TensorShape stride_shape = ctx->InputShape(4);
299       // Validate input sizes.
300       OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
301                   errors::InvalidArgument("stride must be a vector, not shape ",
302                                           stride_shape.DebugString()));
303       OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
304     }
305 
306     OP_REQUIRES(ctx, ksize_.size() == num_dims(),
307                 errors::InvalidArgument("Sliding window ksize field must "
308                                         "specify ",
309                                         num_dims(), " dimensions"));
310     OP_REQUIRES(ctx, stride_.size() == num_dims(),
311                 errors::InvalidArgument("Sliding window strides field must "
312                                         "specify ",
313                                         num_dims(), " dimensions"));
314 
315     const TensorShape tensor_in_shape = ctx->InputShape(0);
316     const TensorShape tensor_out_shape = ctx->InputShape(1);
317     const TensorShape out_backprop_shape = ctx->InputShape(2);
318 
319     // For maxpooling, tensor_in should have num_dims() dimensions.
320     OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(),
321                 errors::InvalidArgument("tensor_in must be ", num_dims(),
322                                         "-dimensional"));
323     OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(),
324                 errors::InvalidArgument("tensor_out must be ", num_dims(),
325                                         "-dimensional"));
326     // For maxpooling, out_backprop should have num_dims() dimensions.
327     OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
328                 errors::InvalidArgument("out_backprop must be ", num_dims(),
329                                         "-dimensional"));
330 
331     // TODO(phawkins): The XLA version doesn't need tensor_out. Investigate
332     // whether this is a good time/space tradeoff.
333     auto input = ctx->Input(0);
334     auto out_backprop = ctx->Input(2);
335 
336     xla::Padding xla_padding =
337         (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
338 
339     // Create a MaxPool operation to check the expected resulting shape, and
340     // then throw away the operation because we don't actually need it here.
341     TensorShape expected_out_shape;
342     auto pooling =
343         xla::MaxPool(ctx->Input(0), ksize_, stride_, xla_padding,
344                      XlaTensorFormat(data_format_, tensor_in_shape.dims() - 2));
345     auto status_or_shape = pooling.builder()->GetShape(pooling);
346     OP_REQUIRES_OK(ctx, status_or_shape.status());
347     OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(status_or_shape.ValueOrDie(),
348                                               &expected_out_shape));
349     OP_REQUIRES(ctx, expected_out_shape == out_backprop_shape,
350                 errors::Unimplemented("The output dimensions do not match the "
351                                       "other input values."));
352 
353     xla::PrimitiveType element_type;
354     OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
355     xla::XlaOp init_value = XlaHelpers::Zero(ctx->builder(), input_type(2));
356     auto select = CreateScalarGeComputation(element_type, ctx->builder());
357     auto scatter = CreateScalarAddComputation(element_type, ctx->builder());
358     xla::XlaOp gradients =
359         xla::SelectAndScatter(input, select, ksize_, stride_, xla_padding,
360                               out_backprop, init_value, scatter);
361 
362     ctx->SetOutput(0, gradients);
363   }
364 
365  protected:
366   const int num_spatial_dims_;
367   std::vector<int64> ksize_;
368   std::vector<int64> stride_;
369   Padding padding_;
370   TensorFormat data_format_ = FORMAT_NHWC;
371 };
372 
373 class MaxPool2DGradOp : public MaxPoolGradOp {
374  public:
MaxPool2DGradOp(OpKernelConstruction * ctx)375   explicit MaxPool2DGradOp(OpKernelConstruction* ctx)
376       : MaxPoolGradOp(ctx, /*num_spatial_dims=*/2) {
377     string data_format;
378     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
379     OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
380                 errors::InvalidArgument("Invalid data format"));
381   }
382 };
383 REGISTER_XLA_OP(Name("MaxPoolGrad"), MaxPool2DGradOp);
384 REGISTER_XLA_OP(Name("MaxPoolGradV2")
385                     .CompileTimeConstantInput("ksize")
386                     .CompileTimeConstantInput("strides"),
387                 MaxPool2DGradOp);
388 
389 class MaxPool3DGradOp : public MaxPoolGradOp {
390  public:
MaxPool3DGradOp(OpKernelConstruction * ctx)391   explicit MaxPool3DGradOp(OpKernelConstruction* ctx)
392       : MaxPoolGradOp(ctx, /*num_spatial_dims=*/3) {}
393 };
394 REGISTER_XLA_OP(Name("MaxPool3DGrad"), MaxPool3DGradOp);
395 
396 // Average-pooling gradient
397 class AvgPoolGradOp : public XlaOpKernel {
398  public:
AvgPoolGradOp(OpKernelConstruction * ctx,int num_spatial_dims)399   AvgPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
400       : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
401     OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
402     OP_REQUIRES(ctx, ksize_.size() == num_dims(),
403                 errors::InvalidArgument("Sliding window ksize field must "
404                                         "specify ",
405                                         num_dims(), " dimensions"));
406     OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
407     OP_REQUIRES(ctx, stride_.size() == num_dims(),
408                 errors::InvalidArgument("Sliding window strides field must "
409                                         "specify ",
410                                         num_dims(), " dimensions"));
411     OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
412     OP_REQUIRES(ctx, ksize_[0] == 1 && stride_[0] == 1,
413                 errors::Unimplemented(
414                     "Pooling is not yet supported on the batch dimension."));
415 
416     string data_format;
417     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
418     OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
419                 errors::InvalidArgument("Invalid data format"));
420   }
421 
num_dims() const422   int num_dims() const { return num_spatial_dims_ + 2; }
423 
Compile(XlaOpKernelContext * ctx)424   void Compile(XlaOpKernelContext* ctx) override {
425     TensorShape gradients_shape;
426     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &gradients_shape));
427 
428     const TensorShape out_backprop_shape = ctx->InputShape(1);
429 
430     // For avgpooling, tensor_in_shape should have num_dims() dimensions.
431     OP_REQUIRES(ctx, gradients_shape.dims() == num_dims(),
432                 errors::InvalidArgument("orig_input_shape must be ", num_dims(),
433                                         "-dimensional"));
434 
435     // For avgpooling, out_backprop should have num_dims() dimensions.
436     OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
437                 errors::InvalidArgument("out_backprop must be ", num_dims(),
438                                         "-dimensional"));
439 
440     auto out_backprop = ctx->Input(1);
441     std::vector<int64> stride_int64s(stride_.begin(), stride_.end());
442     xla::Padding xla_padding =
443         (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
444     xla::PrimitiveType xla_reduction_type;
445     auto reduction_type = XlaHelpers::SumAccumulationType(ctx->input_type(1));
446     OP_REQUIRES_OK(
447         ctx, DataTypeToPrimitiveType(reduction_type, &xla_reduction_type));
448     auto converted_out_backprop =
449         xla::ConvertElementType(out_backprop, xla_reduction_type);
450     auto xla_data_format =
451         XlaTensorFormat(data_format_, gradients_shape.dims() - 2);
452     auto padding_values =
453         MakeSpatialPadding(gradients_shape.dim_sizes(), ksize_, stride_int64s,
454                            xla_padding, xla_data_format);
455     auto in_backprop =
456         xla::AvgPoolGrad(converted_out_backprop, gradients_shape.dim_sizes(),
457                          ksize_, stride_int64s, padding_values, xla_data_format,
458                          /*counts_include_padding=*/padding_ == VALID);
459     // Convert the pooling result back to the input type before returning it.
460     xla::PrimitiveType xla_out_backprop_type;
461     OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1),
462                                                 &xla_out_backprop_type));
463     ctx->SetOutput(0,
464                    xla::ConvertElementType(in_backprop, xla_out_backprop_type));
465   }
466 
467  protected:
468   const int num_spatial_dims_;
469   std::vector<int64> ksize_;
470   std::vector<int32> stride_;
471   Padding padding_;
472   TensorFormat data_format_ = FORMAT_NHWC;
473 };
474 
475 class AvgPool2DGradOp : public AvgPoolGradOp {
476  public:
AvgPool2DGradOp(OpKernelConstruction * ctx)477   explicit AvgPool2DGradOp(OpKernelConstruction* ctx)
478       : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) {}
479 };
480 REGISTER_XLA_OP(
481     Name("AvgPoolGrad").CompileTimeConstantInput("orig_input_shape"),
482     AvgPool2DGradOp);
483 
484 class AvgPool3DGradOp : public AvgPoolGradOp {
485  public:
AvgPool3DGradOp(OpKernelConstruction * ctx)486   explicit AvgPool3DGradOp(OpKernelConstruction* ctx)
487       : AvgPoolGradOp(ctx, /*num_spatial_dims=*/3) {}
488 };
489 REGISTER_XLA_OP(
490     Name("AvgPool3DGrad").CompileTimeConstantInput("orig_input_shape"),
491     AvgPool3DGradOp);
492 
493 class MaxPoolGradGradOp : public XlaOpKernel {
494  public:
MaxPoolGradGradOp(OpKernelConstruction * ctx,int num_spatial_dims)495   MaxPoolGradGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
496       : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
497     if (ctx->num_inputs() == 3) {
498       OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
499       OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
500     }
501     OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
502   }
503 
num_dims() const504   int num_dims() const { return num_spatial_dims_ + 2; }
505 
Compile(XlaOpKernelContext * ctx)506   void Compile(XlaOpKernelContext* ctx) override {
507     if (ctx->num_inputs() != 3) {
508       OP_REQUIRES(
509           ctx, ctx->num_inputs() == 5,
510           errors::InvalidArgument("Must supply ksize and stride arguments."));
511       const TensorShape ksize_shape = ctx->InputShape(3);
512       // Validate input sizes.
513       OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
514                   errors::InvalidArgument("ksize must be a vector, not shape ",
515                                           ksize_shape.DebugString()));
516       OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));
517 
518       const TensorShape stride_shape = ctx->InputShape(4);
519       // Validate input sizes.
520       OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
521                   errors::InvalidArgument("stride must be a vector, not shape ",
522                                           stride_shape.DebugString()));
523       OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
524     }
525 
526     OP_REQUIRES(ctx, ksize_.size() == num_dims(),
527                 errors::InvalidArgument("Sliding window ksize field must "
528                                         "specify ",
529                                         num_dims(), " dimensions"));
530     OP_REQUIRES(ctx, stride_.size() == num_dims(),
531                 errors::InvalidArgument("Sliding window strides field must "
532                                         "specify ",
533                                         num_dims(), " dimensions"));
534 
535     const TensorShape tensor_in_shape = ctx->InputShape(0);
536     const TensorShape tensor_out_shape = ctx->InputShape(1);
537     const TensorShape out_backprop_shape = ctx->InputShape(2);
538 
539     // For maxpooling, tensor_in should have num_dims() dimensions.
540     OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(),
541                 errors::InvalidArgument("tensor_in must be ", num_dims(),
542                                         "-dimensional"));
543     OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(),
544                 errors::InvalidArgument("tensor_out must be ", num_dims(),
545                                         "-dimensional"));
546     // For maxpooling, out_backprop should have num_dims() dimensions.
547     OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
548                 errors::InvalidArgument("out_backprop must be ", num_dims(),
549                                         "-dimensional"));
550 
551     // What we want to compute:
552     // Given y = MaxPool(x), and xs_grad = MaxPoolGrad(x, y, ys_grad)
553     // MaxPoolGradGrad computes {ys_grad}_grad given x, y, and {xs_grad}_grad.
554     //
555     // In the regular TF op, this amounts to selecting for each window the
556     // incoming backprop value from xs_grad_grad that corresponds to the maximal
557     // value in the corresponding window of x.
558     //
559     // TODO(b/73062247): What we really want is a ReduceWindow with different
560     // arrays for index selection vs return value selection--a select-to-gather.
561     //
562     // Here, we implement a bitwise hack: we use the hi 16 bits of input for
563     // separate max pooling alongside each of the hi and lo 16 bits of
564     // out_backprop packed into 16 lo bits, which we then glue back together at
565     // the end to get a full 32 bits of gradient.
566     //
567     // This could select the wrong backprop value for two x values that are
568     // equally maximal up to the first 16 bits, in which case we are taking the
569     // latter.
570     //
571     // Note that in principle we could use 32 separate maxpools to recover each
572     // of 32 bits of the gradient while preserving 31 bits of input for the max
573     // pooling criteria; here, we just truncate to the first 16 bits of input.
574 
575     auto input = ctx->Input(0);
576     auto out_backprop = ctx->Input(2);
577 
578     auto b = ctx->builder();
579 
580     auto sixteen = xla::ConstantR0<uint32>(b, 16);
581     // in (f32) -> round to 7 mantissa bits (bf16)-> 16-high-bit u32.
582     //
583     // NOTE: Use a ReducePrecision operation instead of a cast to BF16 and back
584     // to F32 since the XLA compiler may ignore narrowing casts to floating
585     // point types if the debug option xla_allow_excess_precision is set.
586     auto in_hi = xla::BitcastConvertType(
587         xla::ReducePrecision(input, /*exponent_bits=*/8, /*mantissa_bits=*/7),
588         xla::U32);
589     auto bp_int = xla::BitcastConvertType(out_backprop, xla::U32);
590     auto bp_hi = xla::ShiftRightLogical(bp_int, sixteen);
591     auto bp_lo =
592         xla::ShiftRightLogical(xla::ShiftLeft(bp_int, sixteen), sixteen);
593     auto in_hi_bp_hi = xla::Add(in_hi, bp_hi);  // Want an unsigned add.
594     auto in_hi_bp_lo = xla::Add(in_hi, bp_lo);  // Want an unsigned add.
595 
596     auto init_value = xla::MinValue(b, xla::F32);
597     // We will reduce by taking the maximal value up to 16 bits (ignoring the lo
598     // 16 bits of packed-in hi/lo backprop value).
599     auto rb = b->CreateSubBuilder("GreaterOrEqOf_ByFirst16Bits");
600     {
601       // F32 parameters to satisfy lowering type restriction for reduce opcode.
602       const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {});
603       auto lhs = xla::Parameter(rb.get(), 0, scalar, "lhs");
604       auto rhs = xla::Parameter(rb.get(), 1, scalar, "rhs");
605       auto sixteen = xla::ConstantR0<int32>(rb.get(), 16);
606       auto lhs_criteria =
607           xla::ShiftLeft(xla::ShiftRightLogical(
608                              xla::BitcastConvertType(lhs, xla::S32), sixteen),
609                          sixteen);
610       auto rhs_criteria =
611           xla::ShiftLeft(xla::ShiftRightLogical(
612                              xla::BitcastConvertType(rhs, xla::S32), sixteen),
613                          sixteen);
614       // Must use a F32 comparison, because S32 would not work for negatives.
615       xla::Select(xla::Ge(xla::BitcastConvertType(lhs_criteria, xla::F32),
616                           xla::BitcastConvertType(rhs_criteria, xla::F32)),
617                   lhs, rhs);
618     }
619     auto reduce = rb->BuildAndNoteError();
620     xla::Padding xla_padding =
621         (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
622     auto pooled_hi =
623         xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_hi, xla::F32),
624                           init_value, reduce, ksize_, stride_, xla_padding);
625     auto pooled_lo =
626         xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_lo, xla::F32),
627                           init_value, reduce, ksize_, stride_, xla_padding);
628     auto grads_hi =
629         xla::ShiftLeft(xla::BitcastConvertType(pooled_hi, xla::U32), sixteen);
630     auto grads_lo = xla::ShiftRightLogical(
631         xla::ShiftLeft(xla::BitcastConvertType(pooled_lo, xla::U32), sixteen),
632         sixteen);
633     auto grads = xla::Add(grads_hi, grads_lo);  // Want an unsigned add.
634 
635     xla::PrimitiveType element_type;
636     OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
637     ctx->SetOutput(0, xla::BitcastConvertType(grads, element_type));
638   }
639 
640  protected:
641   const int num_spatial_dims_;
642   std::vector<int64> ksize_;
643   std::vector<int64> stride_;
644   Padding padding_;
645   TensorFormat data_format_ = FORMAT_NHWC;
646 };
647 
648 class MaxPool2DGradGradOp : public MaxPoolGradGradOp {
649  public:
MaxPool2DGradGradOp(OpKernelConstruction * ctx)650   explicit MaxPool2DGradGradOp(OpKernelConstruction* ctx)
651       : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/2) {
652     string data_format;
653     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
654     OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
655                 errors::InvalidArgument("Invalid data format"));
656   }
657 };
658 REGISTER_XLA_OP(Name("MaxPoolGradGrad").TypeConstraint("T", DT_FLOAT),
659                 MaxPool2DGradGradOp);
660 REGISTER_XLA_OP(Name("MaxPoolGradGradV2")
661                     .TypeConstraint("T", DT_FLOAT)
662                     .CompileTimeConstantInput("ksize")
663                     .CompileTimeConstantInput("strides"),
664                 MaxPool2DGradGradOp);
665 
666 class MaxPool3DGradGradOp : public MaxPoolGradGradOp {
667  public:
MaxPool3DGradGradOp(OpKernelConstruction * ctx)668   explicit MaxPool3DGradGradOp(OpKernelConstruction* ctx)
669       : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/3) {
670     string data_format;
671     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
672     OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
673                 errors::InvalidArgument("Invalid data format"));
674   }
675 };
676 REGISTER_XLA_OP(Name("MaxPool3DGradGrad").TypeConstraint("T", DT_FLOAT),
677                 MaxPool3DGradGradOp);
678 
}  // anonymous namespace
}  // namespace tensorflow