/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific pooling ops.

#include <optional>
#include <string>
#include <vector>

#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/pooling.h"
#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {
namespace {

// Superclass of pooling ops.
class PoolingOp : public XlaOpKernel {
 public:
  PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims,
            const DataType reduction_type)
      : XlaOpKernel(ctx),
        num_spatial_dims_(num_spatial_dims),
        reduction_type_(reduction_type) {
    if (ctx->num_inputs() == 1) {
      std::vector<int32> ksize_int;
      std::vector<int32> stride_int;
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int));
      OP_REQUIRES(ctx, ksize_int.size() == num_dims(),
                  errors::InvalidArgument("Sliding window ksize field must "
                                          "specify ",
                                          num_dims(), " dimensions"));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int));
      OP_REQUIRES(ctx, stride_int.size() == num_dims(),
                  errors::InvalidArgument("Sliding window stride field must "
                                          "specify ",
                                          num_dims(), " dimensions"));
      for (int i = 0; i < num_dims(); ++i) {
        ksize_.push_back(ksize_int[i]);
        stride_.push_back(stride_int[i]);
      }
    }
    Padding padding;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding));
    OP_REQUIRES(ctx, padding != EXPLICIT,
                errors::Unimplemented(
                    "XLA does not support pooling ops with explicit padding."));
    padding_ = (padding == VALID) ? xla::Padding::kValid : xla::Padding::kSame;

    OP_REQUIRES_OK(
        ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_));
  }

  int num_dims() const { return num_spatial_dims_ + 2; }

 protected:
  StatusOr<std::vector<int64_t>> GetKernelSize(XlaOpKernelContext* ctx) {
    if (ctx->num_inputs() == 1) {
      return ksize_;
    }
    const TensorShape ksize_shape = ctx->InputShape(1);
    // Validate input sizes.
    if (!TensorShapeUtils::IsVector(ksize_shape)) {
      return errors::InvalidArgument("ksize must be a vector, not shape ",
                                     ksize_shape.DebugString());
    }
    if (ksize_shape.num_elements() != num_dims()) {
      return errors::InvalidArgument(
          "Sliding window ksize field must "
          "specify ",
          num_dims(), " dimensions");
    }
    std::vector<int64_t> ksize;
    auto status = ctx->ConstantInputAsIntVector(1, &ksize);
    if (!status.ok()) {
      return status;
    }
    return ksize;
  }

  StatusOr<std::vector<int64_t>> GetStride(XlaOpKernelContext* ctx) {
    if (ctx->num_inputs() == 1) {
      return stride_;
    }
    const TensorShape stride_shape = ctx->InputShape(2);
    // Validate input sizes.
    if (!TensorShapeUtils::IsVector(stride_shape)) {
      return errors::InvalidArgument("stride must be a vector, not shape ",
                                     stride_shape.DebugString());
    }
    if (stride_shape.num_elements() != num_dims()) {
      return errors::InvalidArgument(
          "Sliding window stride field must "
          "specify ",
          num_dims(), " dimensions");
    }
    std::vector<int64_t> stride;
    auto status = ctx->ConstantInputAsIntVector(2, &stride);
    if (!status.ok()) {
      return status;
    }
    return stride;
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64_t> ksize_;
  std::vector<int64_t> stride_;
  xla::Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
  DataType reduction_type_;
  xla::PrimitiveType xla_reduction_type_;
};

// Converts the tensor data format to the one required by the XLA pooling
// library.
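// For example, with two spatial dimensions, NHWC maps to batch_dimension=0,
// feature_dimension=3, spatial_dimensions={1, 2}, while NCHW maps to 0, 1,
// and {2, 3}.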
xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format,
                                  int num_spatial_dims) {
  int num_dims = num_spatial_dims + 2;
  int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format);
  int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format);
  absl::InlinedVector<int64_t, 4> spatial_dimensions(num_spatial_dims);
  for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) {
    spatial_dimensions[spatial_dim] =
        GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim);
  }
  return xla::TensorFormat(/*batch_dimension=*/batch_dimension,
                           /*feature_dimension=*/feature_dimension,
                           /*spatial_dimensions=*/spatial_dimensions);
}

class MaxPoolOp : public PoolingOp {
 public:
  MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims,
                  /*reduction_type=*/ctx->input_type(0)) {
    std::string data_format_str;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(ctx, data_format_ != FORMAT_NHWC_VECT_W,
                errors::Unimplemented(
                    "XLA does not support the NHWC_VECT_W data format. "
                    "Returning unimplemented from MaxPool to keep "
                    "TensorFlow's intended optimized MaxPool here."));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    auto ksize_or_error = GetKernelSize(ctx);
    OP_REQUIRES_OK(ctx, ksize_or_error.status());
    std::vector<int64_t> ksize = ksize_or_error.ValueOrDie();

    auto stride_or_error = GetStride(ctx);
    OP_REQUIRES_OK(ctx, stride_or_error.status());
    std::vector<int64_t> stride = stride_or_error.ValueOrDie();

    xla::XlaOp input = ctx->Input(0);

    StatusOr<xla::Shape> input_shape = ctx->builder()->GetShape(input);
    OP_REQUIRES_OK(ctx, input_shape.status());

    // For VECT_C max-pool ops, transpose to plain NCHW, do the max-pool, and
    // transpose back. This isn't necessarily the most efficient algorithm, but
    // it's ok for starters.
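    // For example, an NCHW_VECT_C input of shape [N, C/v, H, W, v] (v being
    // the vectorization width read off the trailing dimension) becomes
    // [N, C/v, v, H, W] after the transpose and [N, C, H, W] after the
    // collapse.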
    std::optional<int64_t> vect_width;
    if (data_format_ == FORMAT_NCHW_VECT_C) {
      vect_width = input_shape->dimensions().back();
      input = xla::Collapse(xla::Transpose(input, {0, 1, 4, 2, 3}), {1, 2});

      input_shape = ctx->builder()->GetShape(input);
      OP_REQUIRES_OK(ctx, input_shape.status());
    }

    OP_REQUIRES(ctx, input_shape->dimensions_size() == num_dims(),
                errors::InvalidArgument("Input to ", type_string(),
                                        " operator must have ", num_dims(),
                                        " dimensions"));
    auto pooling = xla::MaxPool(
        input, ksize, stride, padding_,
        XlaTensorFormat(
            data_format_ == FORMAT_NCHW_VECT_C ? FORMAT_NCHW : data_format_,
            input_shape->dimensions_size() - 2));

    if (data_format_ == FORMAT_NCHW_VECT_C) {
      StatusOr<xla::Shape> result_shape = ctx->builder()->GetShape(pooling);
      OP_REQUIRES_OK(ctx, result_shape.status());

      int64_t num_channels = result_shape->dimensions(1);
      OP_REQUIRES(
          ctx, num_channels % *vect_width == 0,
          errors::FailedPrecondition("Result of NCHW_VECT_C op must have a "
                                     "channel count that is a multiple of ",
                                     *vect_width, ", but was ", num_channels));

      absl::InlinedVector<int64_t, 5> new_dims(
          result_shape->dimensions().begin(), result_shape->dimensions().end());
      new_dims[1] /= *vect_width;
      new_dims.insert(new_dims.begin() + 2, *vect_width);
      pooling =
          xla::Transpose(xla::Reshape(pooling, new_dims), {0, 1, 3, 4, 2});
    }

    ctx->SetOutput(0, pooling);
  }
};

class MaxPool2DOp : public MaxPoolOp {
 public:
  explicit MaxPool2DOp(OpKernelConstruction* ctx)
      : MaxPoolOp(ctx, /*num_spatial_dims=*/2) {}
};
REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp);
REGISTER_XLA_OP(Name("MaxPoolV2")
                    .CompileTimeConstantInput("ksize")
                    .CompileTimeConstantInput("strides"),
                MaxPool2DOp);

class MaxPool3DOp : public MaxPoolOp {
 public:
  explicit MaxPool3DOp(OpKernelConstruction* ctx)
      : MaxPoolOp(ctx, /*num_spatial_dims=*/3) {}
};
REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp);

class AvgPoolOp : public PoolingOp {
 public:
  AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims,
                  /*reduction_type=*/
                  XlaHelpers::SumAccumulationType(ctx->input_type(0))) {
    std::string data_format_str;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    auto ksize_or_error = GetKernelSize(ctx);
    OP_REQUIRES_OK(ctx, ksize_or_error.status());
    std::vector<int64_t> ksize = ksize_or_error.ValueOrDie();

    auto stride_or_error = GetStride(ctx);
    OP_REQUIRES_OK(ctx, stride_or_error.status());
    std::vector<int64_t> stride = stride_or_error.ValueOrDie();

    const TensorShape input_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
                errors::InvalidArgument("Input to ", type_string(),
                                        " operator must have ", num_dims(),
                                        " dimensions"));

    auto xla_data_format =
        XlaTensorFormat(data_format_, input_shape.dims() - 2);
    auto spatial_padding = MakeSpatialPadding(
        input_shape.dim_sizes(), ksize, stride, padding_, xla_data_format);

    // Convert the input to the reduction type.
    auto converted_input =
        ConvertElementType(ctx->Input(0), xla_reduction_type_);
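    // The last argument tells xla::AvgPool whether window counts may include
    // padding: TF's average pooling divides by the number of non-padding
    // elements in each window, so counting padding is only safe with VALID
    // padding, where windows never overlap padding.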
    auto pooling =
        xla::AvgPool(converted_input, ksize, stride, spatial_padding,
                     xla_data_format, padding_ == xla::Padding::kValid);
    // Convert the pooling result back to the input type before returning it.
    ctx->SetOutput(0, ConvertElementType(pooling, ctx->input_xla_type(0)));
  }
};

class AvgPool2DOp : public AvgPoolOp {
 public:
  explicit AvgPool2DOp(OpKernelConstruction* ctx)
      : AvgPoolOp(ctx, /*num_spatial_dims=*/2) {}
};
REGISTER_XLA_OP(Name("AvgPool"), AvgPool2DOp);

REGISTER_XLA_OP(Name("AvgPool3D"), MlirXlaOpKernel);

// The operation to compute MaxPool gradients.
// It takes three inputs:
//   - The original input tensor
//   - The original output tensor
//   - Backprop tensor for output
// It produces one output: backprop tensor for input.
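// Conceptually, with y = MaxPool(x), this computes dL/dx from dL/dy by
// routing each window's output gradient back to the position of that
// window's maximum, which is what the xla::SelectAndScatter below does.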
class MaxPoolGradOp : public XlaOpKernel {
 public:
  MaxPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    if (ctx->num_inputs() == 3) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    }
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
    OP_REQUIRES(ctx, padding_ != EXPLICIT,
                errors::Unimplemented(
                    "XLA does not support maxpoolgrad with explicit padding."));
    // When determinism is enabled, the use of SelectAndScatter causes a
    // generic error to be raised. We raise a more informative error here
    // before SelectAndScatter is used.
    OP_REQUIRES(
        ctx, !tensorflow::OpDeterminismRequired(),
        errors::Unimplemented("GPU MaxPool gradient ops do not yet have a "
                              "deterministic XLA implementation."));
  }

  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    if (ctx->num_inputs() != 3) {
      OP_REQUIRES(
          ctx, ctx->num_inputs() == 5,
          errors::InvalidArgument("Must supply ksize and stride arguments."));
      const TensorShape ksize_shape = ctx->InputShape(3);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
                  errors::InvalidArgument("ksize must be a vector, not shape ",
                                          ksize_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));

      const TensorShape stride_shape = ctx->InputShape(4);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
                  errors::InvalidArgument("stride must be a vector, not shape ",
                                          stride_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
    }

    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));

    const TensorShape tensor_in_shape = ctx->InputShape(0);
    const TensorShape tensor_out_shape = ctx->InputShape(1);
    const TensorShape out_backprop_shape = ctx->InputShape(2);

    // For maxpooling, tensor_in should have num_dims() dimensions.
    OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_in must be ", num_dims(),
                                        "-dimensional"));
    OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_out must be ", num_dims(),
                                        "-dimensional"));
    // For maxpooling, out_backprop should have num_dims() dimensions.
    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
                errors::InvalidArgument("out_backprop must be ", num_dims(),
                                        "-dimensional"));

    // TODO(phawkins): The XLA version doesn't need tensor_out. Investigate
    // whether this is a good time/space tradeoff.
    auto input = ctx->Input(0);
    auto out_backprop = ctx->Input(2);
    // We ensured padding_ is not EXPLICIT in the constructor.
    xla::Padding xla_padding =
        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;

    // Create a MaxPool operation to check the expected resulting shape, and
    // then throw away the operation because we don't actually need it here.
    TensorShape expected_out_shape;
    auto pooling =
        xla::MaxPool(ctx->Input(0), ksize_, stride_, xla_padding,
                     XlaTensorFormat(data_format_, tensor_in_shape.dims() - 2));
    auto status_or_shape = pooling.builder()->GetShape(pooling);
    OP_REQUIRES_OK(ctx, status_or_shape.status());
    OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(status_or_shape.ValueOrDie(),
                                              &expected_out_shape));
    OP_REQUIRES(ctx, expected_out_shape == out_backprop_shape,
                errors::Unimplemented("The output dimensions do not match the "
                                      "other input values."));

    xla::PrimitiveType element_type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
    xla::XlaOp init_value = XlaHelpers::Zero(ctx->builder(), input_type(2));
    auto select = CreateScalarGeComputation(element_type, ctx->builder());
    auto scatter = CreateScalarAddComputation(element_type, ctx->builder());
    xla::XlaOp gradients =
        xla::SelectAndScatter(input, select, ksize_, stride_, xla_padding,
                              out_backprop, init_value, scatter);

    ctx->SetOutput(0, gradients);
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64_t> ksize_;
  std::vector<int64_t> stride_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
};

class MaxPool2DGradOp : public MaxPoolGradOp {
 public:
  explicit MaxPool2DGradOp(OpKernelConstruction* ctx)
      : MaxPoolGradOp(ctx, /*num_spatial_dims=*/2) {
    std::string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }
};
REGISTER_XLA_OP(Name("MaxPoolGrad"), MaxPool2DGradOp);
REGISTER_XLA_OP(Name("MaxPoolGradV2")
                    .CompileTimeConstantInput("ksize")
                    .CompileTimeConstantInput("strides"),
                MaxPool2DGradOp);

REGISTER_XLA_OP(Name("MaxPool3DGrad"), MlirXlaOpKernel);

// Average-pooling gradient.
class AvgPoolGradOp : public XlaOpKernel {
 public:
  AvgPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
    OP_REQUIRES(ctx, padding_ != EXPLICIT,
                errors::Unimplemented(
                    "XLA does not support avgpoolgrad with explicit padding."));
    OP_REQUIRES(ctx, ksize_[0] == 1 && stride_[0] == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));

    std::string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }

  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape gradients_shape;
    OP_REQUIRES_OK(
        ctx, ctx->ConstantInputAsShape(0, &gradients_shape,
                                       xla::ValueInferenceMode::kUpperBound));

    const TensorShape out_backprop_shape = ctx->InputShape(1);

    // For avgpooling, tensor_in_shape should have num_dims() dimensions.
    OP_REQUIRES(ctx, gradients_shape.dims() == num_dims(),
                errors::InvalidArgument("orig_input_shape must be ", num_dims(),
                                        "-dimensional"));

    // For avgpooling, out_backprop should have num_dims() dimensions.
    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
                errors::InvalidArgument("out_backprop must be ", num_dims(),
                                        "-dimensional"));

    auto out_backprop = ctx->Input(1);
    std::vector<int64_t> stride_int64s(stride_.begin(), stride_.end());
    xla::Padding xla_padding =
        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
    xla::PrimitiveType xla_reduction_type;
    auto reduction_type = XlaHelpers::SumAccumulationType(ctx->input_type(1));
    OP_REQUIRES_OK(
        ctx, DataTypeToPrimitiveType(reduction_type, &xla_reduction_type));
    auto converted_out_backprop =
        xla::ConvertElementType(out_backprop, xla_reduction_type);
    auto xla_data_format =
        XlaTensorFormat(data_format_, gradients_shape.dims() - 2);
    auto padding_values =
        MakeSpatialPadding(gradients_shape.dim_sizes(), ksize_, stride_int64s,
                           xla_padding, xla_data_format);
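    // As in the forward op, TF's average pooling divides by the count of
    // non-padding elements per window, so counts_include_padding may only be
    // set when the padding is VALID and windows never overlap padding.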
    auto in_backprop =
        xla::AvgPoolGrad(converted_out_backprop, gradients_shape.dim_sizes(),
                         ksize_, stride_int64s, padding_values, xla_data_format,
                         /*counts_include_padding=*/padding_ == VALID);
    // Convert the pooling result back to the input type before returning it.
    xla::PrimitiveType xla_out_backprop_type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1),
                                                &xla_out_backprop_type));
    ctx->SetOutput(0,
                   xla::ConvertElementType(in_backprop, xla_out_backprop_type));
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64_t> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
};

class AvgPool2DGradOp : public AvgPoolGradOp {
 public:
  explicit AvgPool2DGradOp(OpKernelConstruction* ctx)
      : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) {}
};
REGISTER_XLA_OP(
    Name("AvgPoolGrad").CompileTimeConstantInput("orig_input_shape"),
    AvgPool2DGradOp);

class AvgPool3DGradOp : public AvgPoolGradOp {
 public:
  explicit AvgPool3DGradOp(OpKernelConstruction* ctx)
      : AvgPoolGradOp(ctx, /*num_spatial_dims=*/3) {}
};
REGISTER_XLA_OP(
    Name("AvgPool3DGrad").CompileTimeConstantInput("orig_input_shape"),
    AvgPool3DGradOp);

class MaxPoolGradGradOp : public XlaOpKernel {
 public:
  MaxPoolGradGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    if (ctx->num_inputs() == 3) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    }
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
    OP_REQUIRES(
        ctx, padding_ != EXPLICIT,
        errors::Unimplemented(
            "XLA does not support maxpoolgradgrad with explicit padding."));
  }

  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    if (ctx->num_inputs() != 3) {
      OP_REQUIRES(
          ctx, ctx->num_inputs() == 5,
          errors::InvalidArgument("Must supply ksize and stride arguments."));
      const TensorShape ksize_shape = ctx->InputShape(3);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
                  errors::InvalidArgument("ksize must be a vector, not shape ",
                                          ksize_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));

      const TensorShape stride_shape = ctx->InputShape(4);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
                  errors::InvalidArgument("stride must be a vector, not shape ",
                                          stride_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
    }

    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));

    const TensorShape tensor_in_shape = ctx->InputShape(0);
    const TensorShape tensor_out_shape = ctx->InputShape(1);
    const TensorShape out_backprop_shape = ctx->InputShape(2);

    // For maxpooling, tensor_in should have num_dims() dimensions.
    OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_in must be ", num_dims(),
                                        "-dimensional"));
    OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_out must be ", num_dims(),
                                        "-dimensional"));
    // For maxpooling, out_backprop should have num_dims() dimensions.
    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
                errors::InvalidArgument("out_backprop must be ", num_dims(),
                                        "-dimensional"));

    // What we want to compute:
    // Given y = MaxPool(x) and xs_grad = MaxPoolGrad(x, y, ys_grad),
    // MaxPoolGradGrad computes {ys_grad}_grad given x, y, and {xs_grad}_grad.
    //
    // In the regular TF op, this amounts to selecting for each window the
    // incoming backprop value from xs_grad_grad that corresponds to the
    // maximal value in the corresponding window of x.
    //
    // TODO(b/73062247): What we really want is a ReduceWindow with different
    // arrays for index selection vs return value selection--a
    // select-to-gather.
    //
    // Here, we implement a bitwise hack: we take the high 16 bits of the
    // input and, for each of the high and low 16-bit halves of out_backprop
    // packed into the low 16 bits, run a separate max pooling; at the end we
    // glue the two pooled halves back together into a full 32 bits of
    // gradient.
    //
    // This could select the wrong backprop value for two x values that are
    // equally maximal up to the first 16 bits, in which case we take the
    // latter one.
    //
    // Note that in principle we could use 32 separate maxpools to recover
    // each of the 32 bits of the gradient while preserving 31 bits of input
    // for the max pooling criteria; here, we just truncate to the first 16
    // bits of input.
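    //
    // Concretely, for one f32 input element x with incoming gradient g:
    //   in_hi       = bits(round_to_bf16_precision(x))  // low 16 bits zero
    //   in_hi_bp_hi = in_hi + (bits(g) >> 16)
    //   in_hi_bp_lo = in_hi + (bits(g) & 0xFFFF)
    // Since each packed half fits in 16 bits, the adds never carry into the
    // high half: every max pooling orders windows by x alone while carrying
    // one 16-bit half of g, and the two pooled halves are recombined below.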

    auto input = ctx->Input(0);
    auto out_backprop = ctx->Input(2);

    auto b = ctx->builder();

    auto sixteen = xla::ConstantR0<uint32>(b, 16);
    // in (f32) -> round to 7 mantissa bits (bf16) -> 16-high-bit u32.
    //
    // NOTE: Use a ReducePrecision operation instead of a cast to BF16 and back
    // to F32 since the XLA compiler may ignore narrowing casts to floating
    // point types if the debug option xla_allow_excess_precision is set.
    auto in_hi = xla::BitcastConvertType(
        xla::ReducePrecision(input, /*exponent_bits=*/8, /*mantissa_bits=*/7),
        xla::U32);
    auto bp_int = xla::BitcastConvertType(out_backprop, xla::U32);
    auto bp_hi = xla::ShiftRightLogical(bp_int, sixteen);
    auto bp_lo =
        xla::ShiftRightLogical(xla::ShiftLeft(bp_int, sixteen), sixteen);
    auto in_hi_bp_hi = xla::Add(in_hi, bp_hi);  // Want an unsigned add.
    auto in_hi_bp_lo = xla::Add(in_hi, bp_lo);  // Want an unsigned add.

    auto init_value = xla::MinValue(b, xla::F32);
    // We will reduce by taking the maximal value up to 16 bits (ignoring the
    // low 16 bits of packed-in hi/lo backprop value).
    auto rb = b->CreateSubBuilder("GreaterOrEqOf_ByFirst16Bits");
    {
      // F32 parameters to satisfy lowering type restriction for reduce opcode.
      const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {});
      auto lhs = xla::Parameter(rb.get(), 0, scalar, "lhs");
      auto rhs = xla::Parameter(rb.get(), 1, scalar, "rhs");
      auto sixteen = xla::ConstantR0<int32>(rb.get(), 16);
      auto lhs_criteria =
          xla::ShiftLeft(xla::ShiftRightLogical(
                             xla::BitcastConvertType(lhs, xla::S32), sixteen),
                         sixteen);
      auto rhs_criteria =
          xla::ShiftLeft(xla::ShiftRightLogical(
                             xla::BitcastConvertType(rhs, xla::S32), sixteen),
                         sixteen);
      // Must use an F32 comparison, because S32 would not work for negatives.
      xla::Select(xla::Ge(xla::BitcastConvertType(lhs_criteria, xla::F32),
                          xla::BitcastConvertType(rhs_criteria, xla::F32)),
                  lhs, rhs);
    }
    auto reduce = rb->BuildAndNoteError();
    xla::Padding xla_padding =
        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
    auto pooled_hi =
        xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_hi, xla::F32),
                          init_value, reduce, ksize_, stride_, xla_padding);
    auto pooled_lo =
        xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_lo, xla::F32),
                          init_value, reduce, ksize_, stride_, xla_padding);
    auto grads_hi =
        xla::ShiftLeft(xla::BitcastConvertType(pooled_hi, xla::U32), sixteen);
    auto grads_lo = xla::ShiftRightLogical(
        xla::ShiftLeft(xla::BitcastConvertType(pooled_lo, xla::U32), sixteen),
        sixteen);
    auto grads = xla::Add(grads_hi, grads_lo);  // Want an unsigned add.

    xla::PrimitiveType element_type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
    ctx->SetOutput(0, xla::BitcastConvertType(grads, element_type));
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64_t> ksize_;
  std::vector<int64_t> stride_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
};

class MaxPool2DGradGradOp : public MaxPoolGradGradOp {
 public:
  explicit MaxPool2DGradGradOp(OpKernelConstruction* ctx)
      : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/2) {
    std::string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }
};
REGISTER_XLA_OP(Name("MaxPoolGradGrad").TypeConstraint("T", DT_FLOAT),
                MaxPool2DGradGradOp);
REGISTER_XLA_OP(Name("MaxPoolGradGradV2")
                    .TypeConstraint("T", DT_FLOAT)
                    .CompileTimeConstantInput("ksize")
                    .CompileTimeConstantInput("strides"),
                MaxPool2DGradGradOp);

class MaxPool3DGradGradOp : public MaxPoolGradGradOp {
 public:
  explicit MaxPool3DGradGradOp(OpKernelConstruction* ctx)
      : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/3) {
    std::string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }
};
REGISTER_XLA_OP(Name("MaxPool3DGradGrad").TypeConstraint("T", DT_FLOAT),
                MaxPool3DGradGradOp);

}  // anonymous namespace
}  // namespace tensorflow
733