/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA specific pooling ops.

#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/pooling.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/pooling_ops_common.h"

namespace tensorflow {
namespace {

// Superclass of pooling ops.
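// The window ("ksize") and strides span all num_spatial_dims + 2 tensor
// dimensions; they come from attrs for the classic ops and, for the V2
// variants, from compile-time constant inputs via GetKernelSize/GetStride
// below. TF's VALID/SAME padding attr is mapped onto xla::Padding.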
class PoolingOp : public XlaOpKernel {
 public:
  PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims,
            const DataType reduction_type)
      : XlaOpKernel(ctx),
        num_spatial_dims_(num_spatial_dims),
        reduction_type_(reduction_type) {
    if (ctx->num_inputs() == 1) {
      std::vector<int32> ksize_int;
      std::vector<int32> stride_int;
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int));
      OP_REQUIRES(ctx, ksize_int.size() == num_dims(),
                  errors::InvalidArgument("Sliding window ksize field must "
                                          "specify ",
                                          num_dims(), " dimensions"));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int));
      OP_REQUIRES(ctx, stride_int.size() == num_dims(),
                  errors::InvalidArgument("Sliding window stride field must "
                                          "specify ",
                                          num_dims(), " dimensions"));
      for (int i = 0; i < num_dims(); ++i) {
        ksize_.push_back(ksize_int[i]);
        stride_.push_back(stride_int[i]);
      }
    }
    Padding padding;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding));
    padding_ = (padding == VALID) ? xla::Padding::kValid : xla::Padding::kSame;

    OP_REQUIRES_OK(
        ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_));
  }

  int num_dims() const { return num_spatial_dims_ + 2; }

 protected:
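  // Returns the pooling window dimensions: from the "ksize" attr captured in
  // the constructor when there is a single input, otherwise from the
  // compile-time constant input 1 (GetStride below reads the strides
  // analogously from input 2).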
  xla::StatusOr<std::vector<int64>> GetKernelSize(XlaOpKernelContext* ctx) {
    if (ctx->num_inputs() == 1) {
      return ksize_;
    }
    const TensorShape ksize_shape = ctx->InputShape(1);
    // Validate input sizes.
    if (!TensorShapeUtils::IsVector(ksize_shape)) {
      return errors::InvalidArgument("ksize must be a vector, not shape ",
                                     ksize_shape.DebugString());
    }
    if (ksize_shape.num_elements() != num_dims()) {
      return errors::InvalidArgument(
          "Sliding window ksize field must "
          "specify ",
          num_dims(), " dimensions");
    }
    std::vector<int64> ksize;
    auto status = ctx->ConstantInputAsIntVector(1, &ksize);
    if (!status.ok()) {
      return status;
    }
    return ksize;
  }

  xla::StatusOr<std::vector<int64>> GetStride(XlaOpKernelContext* ctx) {
    if (ctx->num_inputs() == 1) {
      return stride_;
    }
    const TensorShape stride_shape = ctx->InputShape(2);
    // Validate input sizes.
    if (!TensorShapeUtils::IsVector(stride_shape)) {
      return errors::InvalidArgument("stride must be a vector, not shape ",
                                     stride_shape.DebugString());
    }
    if (stride_shape.num_elements() != num_dims()) {
      return errors::InvalidArgument(
          "Sliding window stride field must "
          "specify ",
          num_dims(), " dimensions");
    }
    std::vector<int64> stride;
    auto status = ctx->ConstantInputAsIntVector(2, &stride);
    if (!status.ok()) {
      return status;
    }
    return stride;
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64> ksize_;
  std::vector<int64> stride_;
  xla::Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
  DataType reduction_type_;
  xla::PrimitiveType xla_reduction_type_;
};

// Converts the tensor data format to the one required by the XLA pooling
// library.
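// For example, NHWC with two spatial dimensions maps to batch_dimension=0,
// feature_dimension=3, and spatial_dimensions={1, 2}.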
xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format,
                                  int num_spatial_dims) {
  int num_dims = num_spatial_dims + 2;
  int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format);
  int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format);
  absl::InlinedVector<int64, 4> spatial_dimensions(num_spatial_dims);
  for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) {
    spatial_dimensions[spatial_dim] =
        GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim);
  }
  return xla::TensorFormat(/*batch_dimension=*/batch_dimension,
                           /*feature_dimension=*/feature_dimension,
                           /*spatial_dimensions=*/spatial_dimensions);
}

class MaxPoolOp : public PoolingOp {
 public:
  MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims,
                  /*reduction_type=*/ctx->input_type(0)) {
    string data_format_str;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    auto ksize_or_error = GetKernelSize(ctx);
    OP_REQUIRES_OK(ctx, ksize_or_error.status());
    std::vector<int64> ksize = ksize_or_error.ValueOrDie();

    auto stride_or_error = GetStride(ctx);
    OP_REQUIRES_OK(ctx, stride_or_error.status());
    std::vector<int64> stride = stride_or_error.ValueOrDie();

    const TensorShape input_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
                errors::InvalidArgument("Input to ", type_string(),
                                        " operator must have ", num_dims(),
                                        " dimensions"));

    auto pooling =
        xla::MaxPool(ctx->Input(0), ksize, stride, padding_,
                     XlaTensorFormat(data_format_, input_shape.dims() - 2));
    ctx->SetOutput(0, pooling);
  }
};

class MaxPool2DOp : public MaxPoolOp {
 public:
  explicit MaxPool2DOp(OpKernelConstruction* ctx)
      : MaxPoolOp(ctx, /*num_spatial_dims=*/2) {}
};
REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp);
REGISTER_XLA_OP(Name("MaxPoolV2")
                    .CompileTimeConstantInput("ksize")
                    .CompileTimeConstantInput("strides"),
                MaxPool2DOp);

class MaxPool3DOp : public MaxPoolOp {
 public:
  explicit MaxPool3DOp(OpKernelConstruction* ctx)
      : MaxPoolOp(ctx, /*num_spatial_dims=*/3) {}
};
REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp);

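// Average pooling. The reduction is accumulated in
// XlaHelpers::SumAccumulationType of the input type (e.g. F32 for F16/BF16
// inputs) to limit rounding error, then converted back to the input type.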
class AvgPoolOp : public PoolingOp {
 public:
  AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims,
                  /*reduction_type=*/
                  XlaHelpers::SumAccumulationType(ctx->input_type(0))) {
    string data_format_str;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    auto ksize_or_error = GetKernelSize(ctx);
    OP_REQUIRES_OK(ctx, ksize_or_error.status());
    std::vector<int64> ksize = ksize_or_error.ValueOrDie();

    auto stride_or_error = GetStride(ctx);
    OP_REQUIRES_OK(ctx, stride_or_error.status());
    std::vector<int64> stride = stride_or_error.ValueOrDie();

    const TensorShape input_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
                errors::InvalidArgument("Input to ", type_string(),
                                        " operator must have ", num_dims(),
                                        " dimensions"));

    auto xla_data_format =
        XlaTensorFormat(data_format_, input_shape.dims() - 2);
    auto spatial_padding = MakeSpatialPadding(
        input_shape.dim_sizes(), ksize, stride, padding_, xla_data_format);

    // Convert the input to the reduction type.
    auto converted_input =
        ConvertElementType(ctx->Input(0), xla_reduction_type_);
    auto pooling =
        xla::AvgPool(converted_input, ksize, stride, spatial_padding,
                     xla_data_format, padding_ == xla::Padding::kValid);
    // Convert the pooling result back to the input type before returning it.
    ctx->SetOutput(0, ConvertElementType(pooling, ctx->input_xla_type(0)));
  }
};

class AvgPool2DOp : public AvgPoolOp {
 public:
  explicit AvgPool2DOp(OpKernelConstruction* ctx)
      : AvgPoolOp(ctx, /*num_spatial_dims=*/2) {}
};
REGISTER_XLA_OP(Name("AvgPool"), AvgPool2DOp);

class AvgPool3DOp : public AvgPoolOp {
 public:
  explicit AvgPool3DOp(OpKernelConstruction* ctx)
      : AvgPoolOp(ctx, /*num_spatial_dims=*/3) {}
};
REGISTER_XLA_OP(Name("AvgPool3D"), AvgPool3DOp);

// The operation to compute MaxPool gradients.
// It takes three inputs:
//   - The original input tensor
//   - The original output tensor
//   - Backprop tensor for output
// It produces one output: backprop tensor for input.
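// The gradient is computed via xla::SelectAndScatter: the select computation
// (scalar >=) picks the maximal element of each window of the original input,
// and each out_backprop value is scattered to that position, with values
// landing on the same element combined by the scalar-add scatter computation.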
class MaxPoolGradOp : public XlaOpKernel {
 public:
  MaxPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    if (ctx->num_inputs() == 3) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    }
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
  }

  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    if (ctx->num_inputs() != 3) {
      OP_REQUIRES(
          ctx, ctx->num_inputs() == 5,
          errors::InvalidArgument("Must supply ksize and stride arguments."));
      const TensorShape ksize_shape = ctx->InputShape(3);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
                  errors::InvalidArgument("ksize must be a vector, not shape ",
                                          ksize_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));

      const TensorShape stride_shape = ctx->InputShape(4);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
                  errors::InvalidArgument("stride must be a vector, not shape ",
                                          stride_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
    }

    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));

    const TensorShape tensor_in_shape = ctx->InputShape(0);
    const TensorShape tensor_out_shape = ctx->InputShape(1);
    const TensorShape out_backprop_shape = ctx->InputShape(2);

    // For maxpooling, tensor_in should have num_dims() dimensions.
    OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_in must be ", num_dims(),
                                        "-dimensional"));
    OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_out must be ", num_dims(),
                                        "-dimensional"));
    // For maxpooling, out_backprop should have num_dims() dimensions.
    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
                errors::InvalidArgument("out_backprop must be ", num_dims(),
                                        "-dimensional"));

    // TODO(phawkins): The XLA version doesn't need tensor_out. Investigate
    // whether this is a good time/space tradeoff.
    auto input = ctx->Input(0);
    auto out_backprop = ctx->Input(2);

    xla::Padding xla_padding =
        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;

    xla::PrimitiveType element_type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
    xla::XlaOp init_value = XlaHelpers::Zero(ctx->builder(), input_type(2));
    auto select = CreateScalarGeComputation(element_type, ctx->builder());
    auto scatter = CreateScalarAddComputation(element_type, ctx->builder());
    xla::XlaOp gradients =
        xla::SelectAndScatter(input, select, ksize_, stride_, xla_padding,
                              out_backprop, init_value, scatter);

    ctx->SetOutput(0, gradients);
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64> ksize_;
  std::vector<int64> stride_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
};

class MaxPool2DGradOp : public MaxPoolGradOp {
 public:
  explicit MaxPool2DGradOp(OpKernelConstruction* ctx)
      : MaxPoolGradOp(ctx, /*num_spatial_dims=*/2) {
    string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }
};
REGISTER_XLA_OP(Name("MaxPoolGrad"), MaxPool2DGradOp);
REGISTER_XLA_OP(Name("MaxPoolGradV2")
                    .CompileTimeConstantInput("ksize")
                    .CompileTimeConstantInput("strides"),
                MaxPool2DGradOp);

class MaxPool3DGradOp : public MaxPoolGradOp {
 public:
  explicit MaxPool3DGradOp(OpKernelConstruction* ctx)
      : MaxPoolGradOp(ctx, /*num_spatial_dims=*/3) {}
};
REGISTER_XLA_OP(Name("MaxPool3DGrad"), MaxPool3DGradOp);

// Average-pooling gradient.
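// Takes the original input shape (as the compile-time constant
// "orig_input_shape" input) and the output backprop tensor, and uses
// xla::AvgPoolGrad to spread each backprop value over its originating window;
// counts_include_padding below controls whether padded elements count toward
// the averaging divisor.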
class AvgPoolGradOp : public XlaOpKernel {
 public:
  AvgPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
    OP_REQUIRES(ctx, ksize_[0] == 1 && stride_[0] == 1,
                errors::Unimplemented(
                    "Pooling is not yet supported on the batch dimension."));

    string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }

  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    TensorShape gradients_shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &gradients_shape));

    const TensorShape out_backprop_shape = ctx->InputShape(1);

    // For avgpooling, tensor_in_shape should have num_dims() dimensions.
    OP_REQUIRES(ctx, gradients_shape.dims() == num_dims(),
                errors::InvalidArgument("orig_input_shape must be ", num_dims(),
                                        "-dimensional"));

    // For avgpooling, out_backprop should have num_dims() dimensions.
    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
                errors::InvalidArgument("out_backprop must be ", num_dims(),
                                        "-dimensional"));

    auto out_backprop = ctx->Input(1);
    std::vector<int64> stride_int64s(stride_.begin(), stride_.end());
    xla::Padding xla_padding =
        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
    xla::PrimitiveType xla_reduction_type;
    auto reduction_type = XlaHelpers::SumAccumulationType(ctx->input_type(1));
    OP_REQUIRES_OK(
        ctx, DataTypeToPrimitiveType(reduction_type, &xla_reduction_type));
    auto converted_out_backprop =
        xla::ConvertElementType(out_backprop, xla_reduction_type);
    auto xla_data_format =
        XlaTensorFormat(data_format_, gradients_shape.dims() - 2);
    auto padding_values =
        MakeSpatialPadding(gradients_shape.dim_sizes(), ksize_, stride_int64s,
                           xla_padding, xla_data_format);
    auto in_backprop =
        xla::AvgPoolGrad(converted_out_backprop, gradients_shape.dim_sizes(),
                         ksize_, stride_int64s, padding_values,
                         xla_data_format,
                         /*counts_include_padding=*/padding_ == VALID);
    // Convert the pooling result back to the input type before returning it.
    xla::PrimitiveType xla_out_backprop_type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1),
                                                &xla_out_backprop_type));
    ctx->SetOutput(0, xla::ConvertElementType(in_backprop,
                                              xla_out_backprop_type));
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
};

class AvgPool2DGradOp : public AvgPoolGradOp {
 public:
  explicit AvgPool2DGradOp(OpKernelConstruction* ctx)
      : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) {}
};
REGISTER_XLA_OP(
    Name("AvgPoolGrad").CompileTimeConstantInput("orig_input_shape"),
    AvgPool2DGradOp);

class AvgPool3DGradOp : public AvgPoolGradOp {
 public:
  explicit AvgPool3DGradOp(OpKernelConstruction* ctx)
      : AvgPoolGradOp(ctx, /*num_spatial_dims=*/3) {}
};
REGISTER_XLA_OP(
    Name("AvgPool3DGrad").CompileTimeConstantInput("orig_input_shape"),
    AvgPool3DGradOp);

class MaxPoolGradGradOp : public XlaOpKernel {
 public:
  MaxPoolGradGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    if (ctx->num_inputs() == 3) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    }
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
  }

  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    if (ctx->num_inputs() != 3) {
      OP_REQUIRES(
          ctx, ctx->num_inputs() == 5,
          errors::InvalidArgument("Must supply ksize and stride arguments."));
      const TensorShape ksize_shape = ctx->InputShape(3);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
                  errors::InvalidArgument("ksize must be a vector, not shape ",
                                          ksize_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));

      const TensorShape stride_shape = ctx->InputShape(4);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
                  errors::InvalidArgument("stride must be a vector, not shape ",
                                          stride_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
    }

    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));

    const TensorShape tensor_in_shape = ctx->InputShape(0);
    const TensorShape tensor_out_shape = ctx->InputShape(1);
    const TensorShape out_backprop_shape = ctx->InputShape(2);

    // For maxpooling, tensor_in should have num_dims() dimensions.
    OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_in must be ", num_dims(),
                                        "-dimensional"));
    OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(),
                errors::InvalidArgument("tensor_out must be ", num_dims(),
                                        "-dimensional"));
    // For maxpooling, out_backprop should have num_dims() dimensions.
    OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
                errors::InvalidArgument("out_backprop must be ", num_dims(),
                                        "-dimensional"));

    // What we want to compute:
    // Given y = MaxPool(x), and xs_grad = MaxPoolGrad(x, y, ys_grad),
    // MaxPoolGradGrad computes {ys_grad}_grad given x, y, and {xs_grad}_grad.
    //
    // In the regular TF op, this amounts to selecting for each window the
    // incoming backprop value from xs_grad_grad that corresponds to the
    // maximal value in the corresponding window of x.
    //
    // TODO(b/73062247): What we really want is a ReduceWindow with different
    // arrays for index selection vs return value selection--a
    // select-to-gather.
    //
    // Here, we implement a bitwise hack: we use the hi 16 bits of input for
    // separate max pooling alongside each of the hi and lo 16 bits of
    // out_backprop packed into the 16 lo bits, which we then glue back
    // together at the end to get a full 32 bits of gradient.
    //
    // This could select the wrong backprop value for two x values that are
    // equally maximal up to the first 16 bits, in which case we are taking the
    // latter.
    //
    // Note that in principle we could use 32 separate maxpools to recover each
    // of 32 bits of the gradient while preserving 31 bits of input for the max
    // pooling criteria; here, we just truncate to the first 16 bits of input.
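    //
    // A worked example with hypothetical bit patterns: suppose an input
    // element rounds to bf16 as bits 0x3FC00000 (its lo 16 bits are zero
    // after the round trip below) and its out_backprop bits are 0xAABBCCDD.
    // Then bp_hi = 0x0000AABB and bp_lo = 0x0000CCDD, so the packed words are
    // 0x3FC0AABB and 0x3FC0CCDD. Both order identically to the input under
    // the 16-bit comparison below, and after pooling the gradient is
    // reassembled as (0x3FC0AABB << 16) + ((0x3FC0CCDD << 16) >> 16) =
    // 0xAABBCCDD.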

    auto input = ctx->Input(0);
    auto out_backprop = ctx->Input(2);

    auto b = ctx->builder();

    auto sixteen = xla::ConstantR0<uint32>(b, 16);
    // in (f32) -> round to bf16 -> f32 for correct bitwidth -> 16-high-bit u32
    auto in_hi = xla::BitcastConvertType(
        xla::ConvertElementType(xla::ConvertElementType(input, xla::BF16),
                                xla::F32),
        xla::U32);
    auto bp_int = xla::BitcastConvertType(out_backprop, xla::U32);
    auto bp_hi = xla::ShiftRightLogical(bp_int, sixteen);
    auto bp_lo =
        xla::ShiftRightLogical(xla::ShiftLeft(bp_int, sixteen), sixteen);
    auto in_hi_bp_hi = xla::Add(in_hi, bp_hi);  // Want an unsigned add.
    auto in_hi_bp_lo = xla::Add(in_hi, bp_lo);  // Want an unsigned add.

    auto init_value = xla::MinValue(b, xla::F32);
    // We will reduce by taking the maximal value up to 16 bits (ignoring the
    // lo 16 bits of packed-in hi/lo backprop value).
    auto rb = b->CreateSubBuilder("GreaterOrEqOf_ByFirst16Bits");
    {
      // F32 parameters to satisfy lowering type restriction for reduce opcode.
      const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {});
      auto lhs = xla::Parameter(rb.get(), 0, scalar, "lhs");
      auto rhs = xla::Parameter(rb.get(), 1, scalar, "rhs");
      auto sixteen = xla::ConstantR0<int32>(rb.get(), 16);
      auto lhs_criteria =
          xla::ShiftLeft(xla::ShiftRightLogical(
                             xla::BitcastConvertType(lhs, xla::S32), sixteen),
                         sixteen);
      auto rhs_criteria =
          xla::ShiftLeft(xla::ShiftRightLogical(
                             xla::BitcastConvertType(rhs, xla::S32), sixteen),
                         sixteen);
      // Must use a F32 comparison, because S32 would not work for negatives.
      xla::Select(xla::Ge(xla::BitcastConvertType(lhs_criteria, xla::F32),
                          xla::BitcastConvertType(rhs_criteria, xla::F32)),
                  lhs, rhs);
    }
    auto reduce = rb->BuildAndNoteError();
    xla::Padding xla_padding =
        (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
    auto pooled_hi =
        xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_hi, xla::F32),
                          init_value, reduce, ksize_, stride_, xla_padding);
    auto pooled_lo =
        xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_lo, xla::F32),
                          init_value, reduce, ksize_, stride_, xla_padding);
    auto grads_hi =
        xla::ShiftLeft(xla::BitcastConvertType(pooled_hi, xla::U32), sixteen);
    auto grads_lo = xla::ShiftRightLogical(
        xla::ShiftLeft(xla::BitcastConvertType(pooled_lo, xla::U32), sixteen),
        sixteen);
    auto grads = xla::Add(grads_hi, grads_lo);  // Want an unsigned add.

    xla::PrimitiveType element_type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
    ctx->SetOutput(0, xla::BitcastConvertType(grads, element_type));
  }

 protected:
  const int num_spatial_dims_;
  std::vector<int64> ksize_;
  std::vector<int64> stride_;
  Padding padding_;
  TensorFormat data_format_ = FORMAT_NHWC;
};

class MaxPool2DGradGradOp : public MaxPoolGradGradOp {
 public:
  explicit MaxPool2DGradGradOp(OpKernelConstruction* ctx)
      : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/2) {
    string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }
};
REGISTER_XLA_OP(Name("MaxPoolGradGrad").TypeConstraint("T", DT_FLOAT),
                MaxPool2DGradGradOp);
REGISTER_XLA_OP(Name("MaxPoolGradGradV2")
                    .TypeConstraint("T", DT_FLOAT)
                    .CompileTimeConstantInput("ksize")
                    .CompileTimeConstantInput("strides"),
                MaxPool2DGradGradOp);

class MaxPool3DGradGradOp : public MaxPoolGradGradOp {
 public:
  explicit MaxPool3DGradGradOp(OpKernelConstruction* ctx)
      : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/3) {
    string data_format;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
    OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
  }
};
REGISTER_XLA_OP(Name("MaxPool3DGradGrad").TypeConstraint("T", DT_FLOAT),
                MaxPool3DGradGradOp);

}  // anonymous namespace
}  // namespace tensorflow