/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA specific pooling ops.

#include <string>

#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/pooling.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {
namespace {

// Superclass of pooling ops.
class PoolingOp : public XlaOpKernel {
 public:
  PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims,
            const DataType reduction_type)
      : XlaOpKernel(ctx),
        num_spatial_dims_(num_spatial_dims),
        reduction_type_(reduction_type) {
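    // For MaxPool/AvgPool, ksize and strides arrive as attributes; the V2
    // variants pass them as (compile-time constant) inputs instead, in which
    // case they are read in GetKernelSize()/GetStride() at compile time.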
    if (ctx->num_inputs() == 1) {
      std::vector<int32> ksize_int;
      std::vector<int32> stride_int;
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int));
      OP_REQUIRES(ctx, ksize_int.size() == num_dims(),
                  errors::InvalidArgument("Sliding window ksize field must "
                                          "specify ",
                                          num_dims(), " dimensions"));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int));
      OP_REQUIRES(ctx, stride_int.size() == num_dims(),
                  errors::InvalidArgument("Sliding window stride field must "
                                          "specify ",
                                          num_dims(), " dimensions"));
      for (int i = 0; i < num_dims(); ++i) {
        ksize_.push_back(ksize_int[i]);
        stride_.push_back(stride_int[i]);
      }
    }
    Padding padding;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding));
    OP_REQUIRES(ctx, padding != EXPLICIT,
                errors::Unimplemented(
                    "XLA does not support pooling ops with explicit padding."));
    padding_ = (padding == VALID) ? xla::Padding::kValid : xla::Padding::kSame;

    OP_REQUIRES_OK(
        ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_));
  }

  int num_dims() const { return num_spatial_dims_ + 2; }

 protected:
  StatusOr<std::vector<int64>> GetKernelSize(XlaOpKernelContext* ctx) {
    if (ctx->num_inputs() == 1) {
      return ksize_;
    }
    const TensorShape ksize_shape = ctx->InputShape(1);
    // Validate input sizes.
    if (!TensorShapeUtils::IsVector(ksize_shape)) {
      return errors::InvalidArgument("ksize must be a vector, not shape ",
                                     ksize_shape.DebugString());
    }
    if (ksize_shape.num_elements() != num_dims()) {
      return errors::InvalidArgument(
          "Sliding window ksize field must "
          "specify ",
          num_dims(), " dimensions");
    }
    std::vector<int64> ksize;
    auto status = ctx->ConstantInputAsIntVector(1, &ksize);
    if (!status.ok()) {
      return status;
    }
    return ksize;
  }

  StatusOr<std::vector<int64>> GetStride(XlaOpKernelContext* ctx) {
    if (ctx->num_inputs() == 1) {
      return stride_;
    }
    const TensorShape stride_shape = ctx->InputShape(2);
    // Validate input sizes.
    if (!TensorShapeUtils::IsVector(stride_shape)) {
      return errors::InvalidArgument("stride must be a vector, not shape ",
                                     stride_shape.DebugString());
    }
    if (stride_shape.num_elements() != num_dims()) {
      return errors::InvalidArgument(
          "Sliding window stride field must "
          "specify ",
          num_dims(), " dimensions");
    }
    std::vector<int64> stride;
    auto status = ctx->ConstantInputAsIntVector(2, &stride);
    if (!status.ok()) {
      return status;
    }
    return stride;
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64> ksize_;
  std::vector<int64> stride_;
  xla::Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
  DataType reduction_type_;
  xla::PrimitiveType xla_reduction_type_;
};

// Converts the tensor data format to the one required by the XLA pooling
// library.
xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format,
                                  int num_spatial_dims) {
  int num_dims = num_spatial_dims + 2;
  int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format);
  int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format);
  absl::InlinedVector<int64, 4> spatial_dimensions(num_spatial_dims);
  for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) {
    spatial_dimensions[spatial_dim] =
        GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim);
  }
  return xla::TensorFormat(/*batch_dimension=*/batch_dimension,
                           /*feature_dimension=*/feature_dimension,
                           /*spatial_dimensions=*/spatial_dimensions);
}
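// For example, with FORMAT_NHWC and two spatial dimensions this yields
// batch_dimension = 0, feature_dimension = 3, spatial_dimensions = {1, 2};
// with FORMAT_NCHW it yields 0, 1, and {2, 3} respectively.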

class MaxPoolOp : public PoolingOp {
 public:
  MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims,
                  /*reduction_type=*/ctx->input_type(0)) {
    std::string data_format_str;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(ctx, data_format_ != FORMAT_NHWC_VECT_W,
                errors::Unimplemented(
                    "XLA does not support the NHWC_VECT_W data format. "
                    "Returning unimplemented from MaxPool to keep "
                    "TensorFlow's intended optimized MaxPool here."));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    auto ksize_or_error = GetKernelSize(ctx);
    OP_REQUIRES_OK(ctx, ksize_or_error.status());
    std::vector<int64> ksize = ksize_or_error.ValueOrDie();

    auto stride_or_error = GetStride(ctx);
    OP_REQUIRES_OK(ctx, stride_or_error.status());
    std::vector<int64> stride = stride_or_error.ValueOrDie();

    xla::XlaOp input = ctx->Input(0);

    StatusOr<xla::Shape> input_shape = ctx->builder()->GetShape(input);
    OP_REQUIRES_OK(ctx, input_shape.status());

    // For VECT_C max-pool ops, transpose to plain NCHW, do the max-pool, and
    // transpose back. This isn't necessarily the most efficient algorithm, but
    // it's ok for starters.
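    // Concretely, an NCHW_VECT_C input of shape [N, C/v, H, W, v] (where v is
    // the vectorization width, read into vect_width below) is transposed to
    // [N, C/v, v, H, W] and collapsed to [N, C, H, W] before pooling.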
    absl::optional<int64> vect_width;
    if (data_format_ == FORMAT_NCHW_VECT_C) {
      vect_width = input_shape->dimensions().back();
      input = xla::Collapse(xla::Transpose(input, {0, 1, 4, 2, 3}), {1, 2});

      input_shape = ctx->builder()->GetShape(input);
      OP_REQUIRES_OK(ctx, input_shape.status());
    }

    OP_REQUIRES(ctx, input_shape->dimensions_size() == num_dims(),
                errors::InvalidArgument("Input to ", type_string(),
                                        " operator must have ", num_dims(),
                                        " dimensions"));
    auto pooling = xla::MaxPool(
        input, ksize, stride, padding_,
        XlaTensorFormat(
            data_format_ == FORMAT_NCHW_VECT_C ? FORMAT_NCHW : data_format_,
            input_shape->dimensions_size() - 2));

    if (data_format_ == FORMAT_NCHW_VECT_C) {
      StatusOr<xla::Shape> result_shape = ctx->builder()->GetShape(pooling);
      OP_REQUIRES_OK(ctx, result_shape.status());

      int64 num_channels = result_shape->dimensions(1);
      OP_REQUIRES(
          ctx, num_channels % *vect_width == 0,
          errors::FailedPrecondition("Result of NCHW_VECT_C op must have a "
                                     "channel count that is a multiple of ",
                                     *vect_width, ", but was ", num_channels));

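      // Undo the collapse: reshape [N, C, H', W'] back to [N, C/v, v, H', W']
      // and transpose to [N, C/v, H', W', v] to restore NCHW_VECT_C.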
      absl::InlinedVector<int64, 5> new_dims(result_shape->dimensions().begin(),
                                             result_shape->dimensions().end());
      new_dims[1] /= *vect_width;
      new_dims.insert(new_dims.begin() + 2, *vect_width);
      pooling =
          xla::Transpose(xla::Reshape(pooling, new_dims), {0, 1, 3, 4, 2});
    }

    ctx->SetOutput(0, pooling);
  }
};

class MaxPool2DOp : public MaxPoolOp {
 public:
  explicit MaxPool2DOp(OpKernelConstruction* ctx)
      : MaxPoolOp(ctx, /*num_spatial_dims=*/2) {}
};
REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp);
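// MaxPoolV2 takes ksize and strides as tensor inputs rather than attributes;
// XLA window dimensions must be static, so both are required to be
// compile-time constants and are read with ConstantInputAsIntVector above.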
REGISTER_XLA_OP(Name("MaxPoolV2")
                    .CompileTimeConstantInput("ksize")
                    .CompileTimeConstantInput("strides"),
                MaxPool2DOp);

class MaxPool3DOp : public MaxPoolOp {
 public:
  explicit MaxPool3DOp(OpKernelConstruction* ctx)
      : MaxPoolOp(ctx, /*num_spatial_dims=*/3) {}
};
REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp);

class AvgPoolOp : public PoolingOp {
 public:
  AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims,
                  /*reduction_type=*/
                  XlaHelpers::SumAccumulationType(ctx->input_type(0))) {
    string data_format_str;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    auto ksize_or_error = GetKernelSize(ctx);
    OP_REQUIRES_OK(ctx, ksize_or_error.status());
    std::vector<int64> ksize = ksize_or_error.ValueOrDie();

    auto stride_or_error = GetStride(ctx);
    OP_REQUIRES_OK(ctx, stride_or_error.status());
    std::vector<int64> stride = stride_or_error.ValueOrDie();

    const TensorShape input_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
                errors::InvalidArgument("Input to ", type_string(),
                                        " operator must have ", num_dims(),
                                        " dimensions"));

    auto xla_data_format =
        XlaTensorFormat(data_format_, input_shape.dims() - 2);
    auto spatial_padding = MakeSpatialPadding(
        input_shape.dim_sizes(), ksize, stride, padding_, xla_data_format);

    // Convert the input to the reduction type.
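    // (For low-precision inputs such as half or bfloat16, SumAccumulationType
    // upgrades the accumulation to float so the average does not lose
    // precision; the result is converted back below.)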
    auto converted_input =
        ConvertElementType(ctx->Input(0), xla_reduction_type_);
    auto pooling =
        xla::AvgPool(converted_input, ksize, stride, spatial_padding,
                     xla_data_format, padding_ == xla::Padding::kValid);
    // Convert the pooling result back to the input type before returning it.
    ctx->SetOutput(0, ConvertElementType(pooling, ctx->input_xla_type(0)));
  }
};

class AvgPool2DOp : public AvgPoolOp {
 public:
  explicit AvgPool2DOp(OpKernelConstruction* ctx)
      : AvgPoolOp(ctx, /*num_spatial_dims=*/2) {}
};
REGISTER_XLA_OP(Name("AvgPool"), AvgPool2DOp);

REGISTER_XLA_OP(Name("AvgPool3D"), MlirXlaOpKernel);

// The operation to compute MaxPool gradients.
// It takes three inputs:
//   - The original input tensor
//   - The original output tensor
//   - Backprop tensor for output
// It produces one output: backprop tensor for input.
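// The gradient is computed with xla::SelectAndScatter: for each pooling
// window, the incoming backprop value is scattered to the position holding
// the window's maximum (chosen with a >= comparison), and contributions from
// overlapping windows are summed.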
class MaxPoolGradOp : public XlaOpKernel {
 public:
  MaxPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    if (ctx->num_inputs() == 3) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    }
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
    OP_REQUIRES(ctx, padding_ != EXPLICIT,
                errors::Unimplemented(
                    "XLA does not support maxpoolgrad with explicit padding."));
  }

  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    if (ctx->num_inputs() != 3) {
      OP_REQUIRES(
          ctx, ctx->num_inputs() == 5,
          errors::InvalidArgument("Must supply ksize and stride arguments."));
      const TensorShape ksize_shape = ctx->InputShape(3);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
                  errors::InvalidArgument("ksize must be a vector, not shape ",
                                          ksize_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));

      const TensorShape stride_shape = ctx->InputShape(4);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
                  errors::InvalidArgument("stride must be a vector, not shape ",
                                          stride_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
    }

    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));

    const TensorShape tensor_in_shape = ctx->InputShape(0);
    const TensorShape tensor_out_shape = ctx->InputShape(1);
    const TensorShape out_backprop_shape = ctx->InputShape(2);

    // For maxpooling, tensor_in should have num_dims() dimensions.
    OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_in must be ", num_dims(),
                                        "-dimensional"));
    OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_out must be ", num_dims(),
                                        "-dimensional"));
    // For maxpooling, out_backprop should have num_dims() dimensions.
    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
                errors::InvalidArgument("out_backprop must be ", num_dims(),
                                        "-dimensional"));

    // TODO(phawkins): The XLA version doesn't need tensor_out. Investigate
    // whether this is a good time/space tradeoff.
    auto input = ctx->Input(0);
    auto out_backprop = ctx->Input(2);
    // We ensured padding_ is not EXPLICIT in the constructor.
    xla::Padding xla_padding =
        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;

    // Create a MaxPool operation to check the expected resulting shape, and
    // then throw away the operation because we don't actually need it here.
    TensorShape expected_out_shape;
    auto pooling =
        xla::MaxPool(ctx->Input(0), ksize_, stride_, xla_padding,
                     XlaTensorFormat(data_format_, tensor_in_shape.dims() - 2));
    auto status_or_shape = pooling.builder()->GetShape(pooling);
    OP_REQUIRES_OK(ctx, status_or_shape.status());
    OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(status_or_shape.ValueOrDie(),
                                              &expected_out_shape));
    OP_REQUIRES(ctx, expected_out_shape == out_backprop_shape,
                errors::Unimplemented("The output dimensions do not match the "
                                      "other input values."));

    xla::PrimitiveType element_type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
    xla::XlaOp init_value = XlaHelpers::Zero(ctx->builder(), input_type(2));
    auto select = CreateScalarGeComputation(element_type, ctx->builder());
    auto scatter = CreateScalarAddComputation(element_type, ctx->builder());
    xla::XlaOp gradients =
        xla::SelectAndScatter(input, select, ksize_, stride_, xla_padding,
                              out_backprop, init_value, scatter);

    ctx->SetOutput(0, gradients);
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64> ksize_;
  std::vector<int64> stride_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
};

class MaxPool2DGradOp : public MaxPoolGradOp {
 public:
  explicit MaxPool2DGradOp(OpKernelConstruction* ctx)
      : MaxPoolGradOp(ctx, /*num_spatial_dims=*/2) {
    string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }
};
REGISTER_XLA_OP(Name("MaxPoolGrad"), MaxPool2DGradOp);
REGISTER_XLA_OP(Name("MaxPoolGradV2")
                    .CompileTimeConstantInput("ksize")
                    .CompileTimeConstantInput("strides"),
                MaxPool2DGradOp);

REGISTER_XLA_OP(Name("MaxPool3DGrad"), MlirXlaOpKernel);

// Average-pooling gradient
class AvgPoolGradOp : public XlaOpKernel {
 public:
  AvgPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
    OP_REQUIRES(ctx, padding_ != EXPLICIT,
                errors::Unimplemented(
                    "XLA does not support avgpoolgrad with explicit padding."));
    OP_REQUIRES(ctx, ksize_[0] == 1 && stride_[0] == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));

    string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }

  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape gradients_shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &gradients_shape));

    const TensorShape out_backprop_shape = ctx->InputShape(1);

    // For avgpooling, tensor_in_shape should have num_dims() dimensions.
    OP_REQUIRES(ctx, gradients_shape.dims() == num_dims(),
                errors::InvalidArgument("orig_input_shape must be ", num_dims(),
                                        "-dimensional"));

    // For avgpooling, out_backprop should have num_dims() dimensions.
    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
                errors::InvalidArgument("out_backprop must be ", num_dims(),
                                        "-dimensional"));

    auto out_backprop = ctx->Input(1);
    std::vector<int64> stride_int64s(stride_.begin(), stride_.end());
    xla::Padding xla_padding =
        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
    xla::PrimitiveType xla_reduction_type;
    auto reduction_type = XlaHelpers::SumAccumulationType(ctx->input_type(1));
    OP_REQUIRES_OK(
        ctx, DataTypeToPrimitiveType(reduction_type, &xla_reduction_type));
    auto converted_out_backprop =
        xla::ConvertElementType(out_backprop, xla_reduction_type);
    auto xla_data_format =
        XlaTensorFormat(data_format_, gradients_shape.dims() - 2);
    auto padding_values =
        MakeSpatialPadding(gradients_shape.dim_sizes(), ksize_, stride_int64s,
                           xla_padding, xla_data_format);
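    // With SAME padding, AvgPool divides each window sum by the number of
    // non-padding elements, so the gradient must use the same counts; hence
    // counts_include_padding is true only for VALID padding, where no padding
    // elements exist.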
    auto in_backprop =
        xla::AvgPoolGrad(converted_out_backprop, gradients_shape.dim_sizes(),
                         ksize_, stride_int64s, padding_values, xla_data_format,
                         /*counts_include_padding=*/padding_ == VALID);
    // Convert the pooling result back to the input type before returning it.
    xla::PrimitiveType xla_out_backprop_type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1),
                                                &xla_out_backprop_type));
    ctx->SetOutput(0,
                   xla::ConvertElementType(in_backprop, xla_out_backprop_type));
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
};

class AvgPool2DGradOp : public AvgPoolGradOp {
 public:
  explicit AvgPool2DGradOp(OpKernelConstruction* ctx)
      : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) {}
};
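// orig_input_shape must be a compile-time constant: it is read with
// ConstantInputAsShape above to determine the gradient's output shape, and
// XLA shapes are static.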
REGISTER_XLA_OP(
    Name("AvgPoolGrad").CompileTimeConstantInput("orig_input_shape"),
    AvgPool2DGradOp);

class AvgPool3DGradOp : public AvgPoolGradOp {
 public:
  explicit AvgPool3DGradOp(OpKernelConstruction* ctx)
      : AvgPoolGradOp(ctx, /*num_spatial_dims=*/3) {}
};
REGISTER_XLA_OP(
    Name("AvgPool3DGrad").CompileTimeConstantInput("orig_input_shape"),
    AvgPool3DGradOp);

class MaxPoolGradGradOp : public XlaOpKernel {
 public:
  MaxPoolGradGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    if (ctx->num_inputs() == 3) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    }
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
    OP_REQUIRES(
        ctx, padding_ != EXPLICIT,
        errors::Unimplemented(
            "XLA does not support maxpoolgradgrad with explicit padding."));
  }

  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    if (ctx->num_inputs() != 3) {
      OP_REQUIRES(
          ctx, ctx->num_inputs() == 5,
          errors::InvalidArgument("Must supply ksize and stride arguments."));
      const TensorShape ksize_shape = ctx->InputShape(3);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
                  errors::InvalidArgument("ksize must be a vector, not shape ",
                                          ksize_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));

      const TensorShape stride_shape = ctx->InputShape(4);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
                  errors::InvalidArgument("stride must be a vector, not shape ",
                                          stride_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
    }

    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));

    const TensorShape tensor_in_shape = ctx->InputShape(0);
    const TensorShape tensor_out_shape = ctx->InputShape(1);
    const TensorShape out_backprop_shape = ctx->InputShape(2);

    // For maxpooling, tensor_in should have num_dims() dimensions.
    OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_in must be ", num_dims(),
                                        "-dimensional"));
    OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_out must be ", num_dims(),
                                        "-dimensional"));
    // For maxpooling, out_backprop should have num_dims() dimensions.
    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
                errors::InvalidArgument("out_backprop must be ", num_dims(),
                                        "-dimensional"));

    // What we want to compute:
    // Given y = MaxPool(x) and xs_grad = MaxPoolGrad(x, y, ys_grad),
    // MaxPoolGradGrad computes {ys_grad}_grad given x, y, and {xs_grad}_grad.
    //
    // In the regular TF op, this amounts to selecting for each window the
    // incoming backprop value from xs_grad_grad that corresponds to the
    // maximal value in the corresponding window of x.
    //
    // TODO(b/73062247): What we really want is a ReduceWindow with different
    // arrays for index selection vs return value selection--a
    // select-to-gather.
    //
    // Here, we implement a bitwise hack: we use the hi 16 bits of the input
    // for max pooling, and pack each of the hi and lo 16 bits of out_backprop
    // into the lo 16 bits alongside them; the two max-pool results are then
    // glued back together at the end to recover a full 32 bits of gradient.
    //
    // This could select the wrong backprop value for two x values that are
    // equally maximal up to the first 16 bits, in which case we are taking the
    // latter.
    //
    // Note that in principle we could use 32 separate maxpools to recover each
    // of 32 bits of the gradient while preserving 31 bits of input for the max
    // pooling criteria; here, we just truncate to the first 16 bits of input.
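    //
    // Schematically, with b = the bits of x rounded to bf16 precision (the top
    // 16 bits of the f32 pattern; lo 16 bits zero) and g = the bits of
    // out_backprop:
    //   packed_hi = b + (g >> 16)      // hi half of g in the lo 16 bits
    //   packed_lo = b + (g & 0xffff)   // lo half of g in the lo 16 bits
    // Both are max-pooled by their hi 16 bits, and the gradient belonging to
    // the max element is reassembled as
    //   grads = (pooled_hi << 16) + (pooled_lo & 0xffff)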

    auto input = ctx->Input(0);
    auto out_backprop = ctx->Input(2);

    auto b = ctx->builder();

    auto sixteen = xla::ConstantR0<uint32>(b, 16);
    // in (f32) -> round to 7 mantissa bits (bf16) -> 16-high-bit u32.
    //
    // NOTE: Use a ReducePrecision operation instead of a cast to BF16 and back
    // to F32 since the XLA compiler may ignore narrowing casts to floating
    // point types if the debug option xla_allow_excess_precision is set.
    auto in_hi = xla::BitcastConvertType(
        xla::ReducePrecision(input, /*exponent_bits=*/8, /*mantissa_bits=*/7),
        xla::U32);
    auto bp_int = xla::BitcastConvertType(out_backprop, xla::U32);
    auto bp_hi = xla::ShiftRightLogical(bp_int, sixteen);
    auto bp_lo =
        xla::ShiftRightLogical(xla::ShiftLeft(bp_int, sixteen), sixteen);
    auto in_hi_bp_hi = xla::Add(in_hi, bp_hi);  // Want an unsigned add.
    auto in_hi_bp_lo = xla::Add(in_hi, bp_lo);  // Want an unsigned add.

    auto init_value = xla::MinValue(b, xla::F32);
    // We will reduce by taking the maximal value up to 16 bits (ignoring the
    // lo 16 bits of packed-in hi/lo backprop value).
    auto rb = b->CreateSubBuilder("GreaterOrEqOf_ByFirst16Bits");
    {
      // F32 parameters to satisfy lowering type restriction for reduce opcode.
      const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {});
      auto lhs = xla::Parameter(rb.get(), 0, scalar, "lhs");
      auto rhs = xla::Parameter(rb.get(), 1, scalar, "rhs");
      auto sixteen = xla::ConstantR0<int32>(rb.get(), 16);
      auto lhs_criteria =
          xla::ShiftLeft(xla::ShiftRightLogical(
                             xla::BitcastConvertType(lhs, xla::S32), sixteen),
                         sixteen);
      auto rhs_criteria =
          xla::ShiftLeft(xla::ShiftRightLogical(
                             xla::BitcastConvertType(rhs, xla::S32), sixteen),
                         sixteen);
      // Must use an F32 comparison, because an S32 one would not work for
      // negatives.
      xla::Select(xla::Ge(xla::BitcastConvertType(lhs_criteria, xla::F32),
                          xla::BitcastConvertType(rhs_criteria, xla::F32)),
                  lhs, rhs);
    }
    auto reduce = rb->BuildAndNoteError();
    xla::Padding xla_padding =
        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
    auto pooled_hi =
        xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_hi, xla::F32),
                          init_value, reduce, ksize_, stride_, xla_padding);
    auto pooled_lo =
        xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_lo, xla::F32),
                          init_value, reduce, ksize_, stride_, xla_padding);
    auto grads_hi =
        xla::ShiftLeft(xla::BitcastConvertType(pooled_hi, xla::U32), sixteen);
    auto grads_lo = xla::ShiftRightLogical(
        xla::ShiftLeft(xla::BitcastConvertType(pooled_lo, xla::U32), sixteen),
        sixteen);
    auto grads = xla::Add(grads_hi, grads_lo);  // Want an unsigned add.

    xla::PrimitiveType element_type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
    ctx->SetOutput(0, xla::BitcastConvertType(grads, element_type));
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64> ksize_;
  std::vector<int64> stride_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
};

class MaxPool2DGradGradOp : public MaxPoolGradGradOp {
 public:
  explicit MaxPool2DGradGradOp(OpKernelConstruction* ctx)
      : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/2) {
    string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }
};
REGISTER_XLA_OP(Name("MaxPoolGradGrad").TypeConstraint("T", DT_FLOAT),
                MaxPool2DGradGradOp);
REGISTER_XLA_OP(Name("MaxPoolGradGradV2")
                    .TypeConstraint("T", DT_FLOAT)
                    .CompileTimeConstantInput("ksize")
                    .CompileTimeConstantInput("strides"),
                MaxPool2DGradGradOp);

class MaxPool3DGradGradOp : public MaxPoolGradGradOp {
 public:
  explicit MaxPool3DGradGradOp(OpKernelConstruction* ctx)
      : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/3) {
    string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }
};
REGISTER_XLA_OP(Name("MaxPool3DGradGrad").TypeConstraint("T", DT_FLOAT),
                MaxPool3DGradGradOp);

}  // anonymous namespace
}  // namespace tensorflow