/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA specific pooling ops.

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/pooling.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {
namespace {
// Superclass of pooling ops.
class PoolingOp : public XlaOpKernel {
 public:
  PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims,
            const DataType reduction_type)
      : XlaOpKernel(ctx),
        num_spatial_dims_(num_spatial_dims),
        reduction_type_(reduction_type) {
    if (ctx->num_inputs() == 1) {
      std::vector<int32> ksize_int;
      std::vector<int32> stride_int;
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int));
      OP_REQUIRES(ctx, ksize_int.size() == num_dims(),
                  errors::InvalidArgument("Sliding window ksize field must "
                                          "specify ",
                                          num_dims(), " dimensions"));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int));
      OP_REQUIRES(ctx, stride_int.size() == num_dims(),
                  errors::InvalidArgument("Sliding window stride field must "
                                          "specify ",
                                          num_dims(), " dimensions"));
      for (int i = 0; i < num_dims(); ++i) {
        ksize_.push_back(ksize_int[i]);
        stride_.push_back(stride_int[i]);
      }
    }
    Padding padding;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding));
    padding_ = (padding == VALID) ? xla::Padding::kValid : xla::Padding::kSame;

    OP_REQUIRES_OK(
        ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_));
  }

  int num_dims() const { return num_spatial_dims_ + 2; }

 protected:
  xla::StatusOr<std::vector<int64>> GetKernelSize(XlaOpKernelContext* ctx) {
    if (ctx->num_inputs() == 1) {
      return ksize_;
    }
    const TensorShape ksize_shape = ctx->InputShape(1);
    // Validate input sizes.
    if (!TensorShapeUtils::IsVector(ksize_shape)) {
      return errors::InvalidArgument("ksize must be a vector, not shape ",
                                     ksize_shape.DebugString());
    }
    if (ksize_shape.num_elements() != num_dims()) {
      return errors::InvalidArgument(
          "Sliding window ksize field must "
          "specify ",
          num_dims(), " dimensions");
    }
    std::vector<int64> ksize;
    auto status = ctx->ConstantInputAsIntVector(1, &ksize);
    if (!status.ok()) {
      return status;
    }
    return ksize;
  }

  xla::StatusOr<std::vector<int64>> GetStride(XlaOpKernelContext* ctx) {
    if (ctx->num_inputs() == 1) {
      return stride_;
    }
    const TensorShape stride_shape = ctx->InputShape(2);
    // Validate input sizes.
    if (!TensorShapeUtils::IsVector(stride_shape)) {
      return errors::InvalidArgument("stride must be a vector, not shape ",
                                     stride_shape.DebugString());
    }
    if (stride_shape.num_elements() != num_dims()) {
      return errors::InvalidArgument(
          "Sliding window stride field must "
          "specify ",
          num_dims(), " dimensions");
    }
    std::vector<int64> stride;
    auto status = ctx->ConstantInputAsIntVector(2, &stride);
    if (!status.ok()) {
      return status;
    }
    return stride;
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64> ksize_;
  std::vector<int64> stride_;
  xla::Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
  DataType reduction_type_;
  xla::PrimitiveType xla_reduction_type_;
};

// Converts the tensor data format to the one required by the XLA pooling
// library.
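// For example, with FORMAT_NHWC and two spatial dimensions this yields
// batch_dimension=0, feature_dimension=3, and spatial_dimensions={1, 2};
// with FORMAT_NCHW it yields 0, 1, and {2, 3} respectively.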
xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format,
                                  int num_spatial_dims) {
  int num_dims = num_spatial_dims + 2;
  int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format);
  int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format);
  absl::InlinedVector<int64, 4> spatial_dimensions(num_spatial_dims);
  for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) {
    spatial_dimensions[spatial_dim] =
        GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim);
  }
  return xla::TensorFormat(/*batch_dimension=*/batch_dimension,
                           /*feature_dimension=*/feature_dimension,
                           /*spatial_dimensions=*/spatial_dimensions);
}

class MaxPoolOp : public PoolingOp {
 public:
  MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims,
                  /*reduction_type=*/ctx->input_type(0)) {
    string data_format_str;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(
        ctx,
        data_format_ != FORMAT_NCHW_VECT_C &&
            data_format_ != FORMAT_NHWC_VECT_W,
        errors::Unimplemented("XLA does not support the VECT_* data formats. "
                              "Returning unimplemented from MaxPool to keep "
                              "TensorFlow's intended optimized MaxPool here."));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    auto ksize_or_error = GetKernelSize(ctx);
    OP_REQUIRES_OK(ctx, ksize_or_error.status());
    std::vector<int64> ksize = ksize_or_error.ValueOrDie();

    auto stride_or_error = GetStride(ctx);
    OP_REQUIRES_OK(ctx, stride_or_error.status());
    std::vector<int64> stride = stride_or_error.ValueOrDie();

    const TensorShape input_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
                errors::InvalidArgument("Input to ", type_string(),
                                        " operator must have ", num_dims(),
                                        " dimensions"));

    auto pooling =
        xla::MaxPool(ctx->Input(0), ksize, stride, padding_,
                     XlaTensorFormat(data_format_, input_shape.dims() - 2));
    ctx->SetOutput(0, pooling);
  }
};

class MaxPool2DOp : public MaxPoolOp {
 public:
  explicit MaxPool2DOp(OpKernelConstruction* ctx)
      : MaxPoolOp(ctx, /*num_spatial_dims=*/2) {}
};
REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp);
REGISTER_XLA_OP(Name("MaxPoolV2")
                    .CompileTimeConstantInput("ksize")
                    .CompileTimeConstantInput("strides"),
                MaxPool2DOp);

class MaxPool3DOp : public MaxPoolOp {
 public:
  explicit MaxPool3DOp(OpKernelConstruction* ctx)
      : MaxPoolOp(ctx, /*num_spatial_dims=*/3) {}
};
REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp);

class AvgPoolOp : public PoolingOp {
 public:
  AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims,
                  /*reduction_type=*/
                  XlaHelpers::SumAccumulationType(ctx->input_type(0))) {
    string data_format_str;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    auto ksize_or_error = GetKernelSize(ctx);
    OP_REQUIRES_OK(ctx, ksize_or_error.status());
    std::vector<int64> ksize = ksize_or_error.ValueOrDie();

    auto stride_or_error = GetStride(ctx);
    OP_REQUIRES_OK(ctx, stride_or_error.status());
    std::vector<int64> stride = stride_or_error.ValueOrDie();

    const TensorShape input_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
                errors::InvalidArgument("Input to ", type_string(),
                                        " operator must have ", num_dims(),
                                        " dimensions"));

    auto xla_data_format =
        XlaTensorFormat(data_format_, input_shape.dims() - 2);
    auto spatial_padding = MakeSpatialPadding(
        input_shape.dim_sizes(), ksize, stride, padding_, xla_data_format);

    // Convert the input to the reduction type.
    auto converted_input =
        ConvertElementType(ctx->Input(0), xla_reduction_type_);
    auto pooling = xla::AvgPool(
        converted_input, ksize, stride, spatial_padding, xla_data_format,
        /*counts_include_padding=*/padding_ == xla::Padding::kValid);
    // Convert the pooling result back to the input type before returning it.
    ctx->SetOutput(0, ConvertElementType(pooling, ctx->input_xla_type(0)));
  }
};

class AvgPool2DOp : public AvgPoolOp {
 public:
  explicit AvgPool2DOp(OpKernelConstruction* ctx)
      : AvgPoolOp(ctx, /*num_spatial_dims=*/2) {}
};
REGISTER_XLA_OP(Name("AvgPool"), AvgPool2DOp);

class AvgPool3DOp : public AvgPoolOp {
 public:
  explicit AvgPool3DOp(OpKernelConstruction* ctx)
      : AvgPoolOp(ctx, /*num_spatial_dims=*/3) {}
};
REGISTER_XLA_OP(Name("AvgPool3D"), AvgPool3DOp);

// The operation to compute MaxPool gradients.
// It takes three inputs:
//   - The original input tensor
//   - The original output tensor
//   - Backprop tensor for output
// It produces one output: backprop tensor for input.
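// In the MaxPoolGradV2 variant, ksize and strides arrive as two additional
// compile-time constant inputs (inputs 3 and 4) instead of attributes;
// Compile() below handles both forms.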
class MaxPoolGradOp : public XlaOpKernel {
 public:
  MaxPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    if (ctx->num_inputs() == 3) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    }
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
  }

  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    if (ctx->num_inputs() != 3) {
      OP_REQUIRES(
          ctx, ctx->num_inputs() == 5,
          errors::InvalidArgument("Must supply ksize and stride arguments."));
      const TensorShape ksize_shape = ctx->InputShape(3);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
                  errors::InvalidArgument("ksize must be a vector, not shape ",
                                          ksize_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));

      const TensorShape stride_shape = ctx->InputShape(4);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
                  errors::InvalidArgument("stride must be a vector, not shape ",
                                          stride_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
    }

    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));

    const TensorShape tensor_in_shape = ctx->InputShape(0);
    const TensorShape tensor_out_shape = ctx->InputShape(1);
    const TensorShape out_backprop_shape = ctx->InputShape(2);

    // For maxpooling, tensor_in should have num_dims() dimensions.
    OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_in must be ", num_dims(),
                                        "-dimensional"));
    OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_out must be ", num_dims(),
                                        "-dimensional"));
    // For maxpooling, out_backprop should have num_dims() dimensions.
    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
                errors::InvalidArgument("out_backprop must be ", num_dims(),
                                        "-dimensional"));

    // TODO(phawkins): The XLA version doesn't need tensor_out. Investigate
    // whether this is a good time/space tradeoff.
    auto input = ctx->Input(0);
    auto out_backprop = ctx->Input(2);

    xla::Padding xla_padding =
        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;

    // Create a MaxPool operation to check the expected resulting shape, and
    // then throw away the operation because we don't actually need it here.
    TensorShape expected_out_shape;
    auto pooling =
        xla::MaxPool(ctx->Input(0), ksize_, stride_, xla_padding,
                     XlaTensorFormat(data_format_, tensor_in_shape.dims() - 2));
    auto status_or_shape = pooling.builder()->GetShape(pooling);
    OP_REQUIRES_OK(ctx, status_or_shape.status());
    OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(status_or_shape.ValueOrDie(),
                                              &expected_out_shape));
    OP_REQUIRES(ctx, expected_out_shape == out_backprop_shape,
                errors::Unimplemented("The output dimensions do not match the "
                                      "other input values."));

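    // Compute the gradient with a select-and-scatter: within each window, the
    // GE computation selects the position of the (first) maximal input
    // element, and the Add computation accumulates the corresponding
    // out_backprop value into that position of the zero-initialized result.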
    xla::PrimitiveType element_type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
    xla::XlaOp init_value = XlaHelpers::Zero(ctx->builder(), input_type(2));
    auto select = CreateScalarGeComputation(element_type, ctx->builder());
    auto scatter = CreateScalarAddComputation(element_type, ctx->builder());
    xla::XlaOp gradients =
        xla::SelectAndScatter(input, select, ksize_, stride_, xla_padding,
                              out_backprop, init_value, scatter);

    ctx->SetOutput(0, gradients);
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64> ksize_;
  std::vector<int64> stride_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
};

class MaxPool2DGradOp : public MaxPoolGradOp {
 public:
  explicit MaxPool2DGradOp(OpKernelConstruction* ctx)
      : MaxPoolGradOp(ctx, /*num_spatial_dims=*/2) {
    string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }
};
REGISTER_XLA_OP(Name("MaxPoolGrad"), MaxPool2DGradOp);
REGISTER_XLA_OP(Name("MaxPoolGradV2")
                    .CompileTimeConstantInput("ksize")
                    .CompileTimeConstantInput("strides"),
                MaxPool2DGradOp);

class MaxPool3DGradOp : public MaxPoolGradOp {
 public:
  explicit MaxPool3DGradOp(OpKernelConstruction* ctx)
      : MaxPoolGradOp(ctx, /*num_spatial_dims=*/3) {}
};
REGISTER_XLA_OP(Name("MaxPool3DGrad"), MaxPool3DGradOp);

// Average-pooling gradient.
class AvgPoolGradOp : public XlaOpKernel {
 public:
  AvgPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
    OP_REQUIRES(ctx, ksize_[0] == 1 && stride_[0] == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));

    string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }

  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape gradients_shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &gradients_shape));

    const TensorShape out_backprop_shape = ctx->InputShape(1);

    // For avgpooling, tensor_in_shape should have num_dims() dimensions.
    OP_REQUIRES(ctx, gradients_shape.dims() == num_dims(),
                errors::InvalidArgument("orig_input_shape must be ", num_dims(),
                                        "-dimensional"));

    // For avgpooling, out_backprop should have num_dims() dimensions.
    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
                errors::InvalidArgument("out_backprop must be ", num_dims(),
                                        "-dimensional"));

    auto out_backprop = ctx->Input(1);
    std::vector<int64> stride_int64s(stride_.begin(), stride_.end());
    xla::Padding xla_padding =
        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
    xla::PrimitiveType xla_reduction_type;
    auto reduction_type = XlaHelpers::SumAccumulationType(ctx->input_type(1));
    OP_REQUIRES_OK(
        ctx, DataTypeToPrimitiveType(reduction_type, &xla_reduction_type));
    auto converted_out_backprop =
        xla::ConvertElementType(out_backprop, xla_reduction_type);
    auto xla_data_format =
        XlaTensorFormat(data_format_, gradients_shape.dims() - 2);
    auto padding_values =
        MakeSpatialPadding(gradients_shape.dim_sizes(), ksize_, stride_int64s,
                           xla_padding, xla_data_format);
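    // TF average pooling excludes padded elements from the divisor, so the
    // full-window divisor (counts_include_padding=true) is only correct for
    // VALID padding, where no padding exists; hence the padding_ == VALID
    // argument below.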
    auto in_backprop =
        xla::AvgPoolGrad(converted_out_backprop, gradients_shape.dim_sizes(),
                         ksize_, stride_int64s, padding_values, xla_data_format,
                         /*counts_include_padding=*/padding_ == VALID);
    // Convert the pooling result back to the input type before returning it.
    xla::PrimitiveType xla_out_backprop_type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1),
                                                &xla_out_backprop_type));
    ctx->SetOutput(0,
                   xla::ConvertElementType(in_backprop, xla_out_backprop_type));
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
};

class AvgPool2DGradOp : public AvgPoolGradOp {
 public:
  explicit AvgPool2DGradOp(OpKernelConstruction* ctx)
      : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) {}
};
REGISTER_XLA_OP(
    Name("AvgPoolGrad").CompileTimeConstantInput("orig_input_shape"),
    AvgPool2DGradOp);

class AvgPool3DGradOp : public AvgPoolGradOp {
 public:
  explicit AvgPool3DGradOp(OpKernelConstruction* ctx)
      : AvgPoolGradOp(ctx, /*num_spatial_dims=*/3) {}
};
REGISTER_XLA_OP(
    Name("AvgPool3DGrad").CompileTimeConstantInput("orig_input_shape"),
    AvgPool3DGradOp);

class MaxPoolGradGradOp : public XlaOpKernel {
 public:
  MaxPoolGradGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    if (ctx->num_inputs() == 3) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    }
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
  }

  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    if (ctx->num_inputs() != 3) {
      OP_REQUIRES(
          ctx, ctx->num_inputs() == 5,
          errors::InvalidArgument("Must supply ksize and stride arguments."));
      const TensorShape ksize_shape = ctx->InputShape(3);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
                  errors::InvalidArgument("ksize must be a vector, not shape ",
                                          ksize_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));

      const TensorShape stride_shape = ctx->InputShape(4);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
                  errors::InvalidArgument("stride must be a vector, not shape ",
                                          stride_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
    }

    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));

    const TensorShape tensor_in_shape = ctx->InputShape(0);
    const TensorShape tensor_out_shape = ctx->InputShape(1);
    const TensorShape out_backprop_shape = ctx->InputShape(2);

    // For maxpooling, tensor_in should have num_dims() dimensions.
    OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_in must be ", num_dims(),
                                        "-dimensional"));
    OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_out must be ", num_dims(),
                                        "-dimensional"));
    // For maxpooling, out_backprop should have num_dims() dimensions.
    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
                errors::InvalidArgument("out_backprop must be ", num_dims(),
                                        "-dimensional"));
    // What we want to compute:
    // Given y = MaxPool(x) and xs_grad = MaxPoolGrad(x, y, ys_grad),
    // MaxPoolGradGrad computes {ys_grad}_grad given x, y, and {xs_grad}_grad.
    //
    // In the regular TF op, this amounts to selecting, for each window, the
    // incoming backprop value from xs_grad_grad that corresponds to the
    // maximal value in the corresponding window of x.
    //
    // TODO(b/73062247): What we really want is a ReduceWindow with different
    // arrays for index selection vs. return value selection--a
    // select-to-gather.
    //
    // Here, we implement a bitwise hack: we use the hi 16 bits of the input
    // for separate max pooling alongside each of the hi and lo 16 bits of
    // out_backprop packed into the lo 16 bits, which we then glue back
    // together at the end to get a full 32 bits of gradient.
    //
    // This could select the wrong backprop value for two x values that are
    // equally maximal up to their first 16 bits, in which case we take the
    // latter.
    //
    // Note that in principle we could use 32 separate maxpools to recover
    // each of the 32 bits of the gradient while preserving 31 bits of input
    // for the max pooling criteria; here, we just truncate to the first 16
    // bits of input.
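    //
    // Worked example of the packing (bit patterns are illustrative only):
    // suppose an input element rounds to f32 bit pattern 0x41A20000 (its lo
    // 16 bits are zeroed by the bf16 rounding below) and the corresponding
    // out_backprop element has bit pattern 0xDEADBEEF. Then bp_hi =
    // 0x0000DEAD and bp_lo = 0x0000BEEF, the packed words are 0x41A2DEAD and
    // 0x41A2BEEF, and the pooled results are reassembled as
    // (0x41A2DEAD << 16) | (0x41A2BEEF & 0xFFFF) == 0xDEADBEEF.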

    auto input = ctx->Input(0);
    auto out_backprop = ctx->Input(2);

    auto b = ctx->builder();

    auto sixteen = xla::ConstantR0<uint32>(b, 16);
    // in (f32) -> round to 7 mantissa bits (bf16) -> 16-high-bit u32.
    //
    // NOTE: Use a ReducePrecision operation instead of a cast to BF16 and back
    // to F32 since the XLA compiler may ignore narrowing casts to floating
    // point types if the debug option xla_allow_excess_precision is set.
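    // (Eight exponent bits and seven mantissa bits are exactly bfloat16's
    // layout.)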
    auto in_hi = xla::BitcastConvertType(
        xla::ReducePrecision(input, /*exponent_bits=*/8, /*mantissa_bits=*/7),
        xla::U32);
    auto bp_int = xla::BitcastConvertType(out_backprop, xla::U32);
    auto bp_hi = xla::ShiftRightLogical(bp_int, sixteen);
    auto bp_lo =
        xla::ShiftRightLogical(xla::ShiftLeft(bp_int, sixteen), sixteen);
    auto in_hi_bp_hi = xla::Add(in_hi, bp_hi);  // Want an unsigned add.
    auto in_hi_bp_lo = xla::Add(in_hi, bp_lo);  // Want an unsigned add.

    auto init_value = xla::MinValue(b, xla::F32);
    // We will reduce by taking the maximal value up to 16 bits (ignoring the
    // lo 16 bits of packed-in hi/lo backprop value).
    auto rb = b->CreateSubBuilder("GreaterOrEqOf_ByFirst16Bits");
    {
      // F32 parameters to satisfy lowering type restriction for reduce opcode.
      const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {});
      auto lhs = xla::Parameter(rb.get(), 0, scalar, "lhs");
      auto rhs = xla::Parameter(rb.get(), 1, scalar, "rhs");
      auto sixteen = xla::ConstantR0<int32>(rb.get(), 16);
      auto lhs_criteria =
          xla::ShiftLeft(xla::ShiftRightLogical(
                             xla::BitcastConvertType(lhs, xla::S32), sixteen),
                         sixteen);
      auto rhs_criteria =
          xla::ShiftLeft(xla::ShiftRightLogical(
                             xla::BitcastConvertType(rhs, xla::S32), sixteen),
                         sixteen);
      // Must use a F32 comparison, because S32 would not work for negatives.
      xla::Select(xla::Ge(xla::BitcastConvertType(lhs_criteria, xla::F32),
                          xla::BitcastConvertType(rhs_criteria, xla::F32)),
                  lhs, rhs);
    }
    auto reduce = rb->BuildAndNoteError();
    xla::Padding xla_padding =
        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
    auto pooled_hi =
        xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_hi, xla::F32),
                          init_value, reduce, ksize_, stride_, xla_padding);
    auto pooled_lo =
        xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_lo, xla::F32),
                          init_value, reduce, ksize_, stride_, xla_padding);
    auto grads_hi =
        xla::ShiftLeft(xla::BitcastConvertType(pooled_hi, xla::U32), sixteen);
    auto grads_lo = xla::ShiftRightLogical(
        xla::ShiftLeft(xla::BitcastConvertType(pooled_lo, xla::U32), sixteen),
        sixteen);
    auto grads = xla::Add(grads_hi, grads_lo);  // Want an unsigned add.

    xla::PrimitiveType element_type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
    ctx->SetOutput(0, xla::BitcastConvertType(grads, element_type));
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64> ksize_;
  std::vector<int64> stride_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
};

class MaxPool2DGradGradOp : public MaxPoolGradGradOp {
 public:
  explicit MaxPool2DGradGradOp(OpKernelConstruction* ctx)
      : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/2) {
    string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }
};
REGISTER_XLA_OP(Name("MaxPoolGradGrad").TypeConstraint("T", DT_FLOAT),
                MaxPool2DGradGradOp);
REGISTER_XLA_OP(Name("MaxPoolGradGradV2")
                    .TypeConstraint("T", DT_FLOAT)
                    .CompileTimeConstantInput("ksize")
                    .CompileTimeConstantInput("strides"),
                MaxPool2DGradGradOp);

class MaxPool3DGradGradOp : public MaxPoolGradGradOp {
 public:
  explicit MaxPool3DGradGradOp(OpKernelConstruction* ctx)
      : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/3) {
    string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }
};
REGISTER_XLA_OP(Name("MaxPool3DGradGrad").TypeConstraint("T", DT_FLOAT),
                MaxPool3DGradGradOp);

}  // anonymous namespace
}  // namespace tensorflow