• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "absl/strings/match.h"
17 #include "absl/strings/str_cat.h"
18 #include "absl/strings/string_view.h"
19 #include "tensorflow/compiler/tf2xla/shape_util.h"
20 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
22 #include "tensorflow/compiler/xla/client/lib/constants.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/core/framework/kernel_def_builder.h"
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/framework/register_types.h"
27 #include "tensorflow/core/tpu/tpu_defs.h"
28 
29 namespace tensorflow {
30 
31 class TpuCustomResizeOp : public XlaOpKernel {
32  public:
TpuCustomResizeOp(OpKernelConstruction * ctx)33   explicit TpuCustomResizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
34     OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
35     OP_REQUIRES_OK(ctx,
36                    ctx->GetAttr("half_pixel_centers", &half_pixel_centers_));
37   }
38 
GetOutputShape(XlaOpKernelContext * ctx) const39   xla::Shape GetOutputShape(XlaOpKernelContext* ctx) const {
40     std::vector<int64_t> out_size;
41     auto status = ctx->ConstantInputAsIntVector(1, &out_size);
42     CHECK_EQ(out_size.size(), 2) << status.ToString();
43     xla::Shape output_shape =
44         TensorShapeToXLAShape(ctx->output_xla_type(0), ctx->InputShape(0));
45     output_shape.mutable_dimensions()[1] = out_size[0];
46     output_shape.mutable_dimensions()[2] = out_size[1];
47     return output_shape;
48   }
49 
OpaqueField() const50   string OpaqueField() const {
51     return absl::StrCat("\"", align_corners_, half_pixel_centers_, "\"");
52   }
53 
CompileGrad(XlaOpKernelContext * ctx,const char * target,const xla::Shape & output_shape)54   void CompileGrad(XlaOpKernelContext* ctx, const char* target,
55                    const xla::Shape& output_shape) {
56     auto input_shape =
57         TensorShapeToXLAShape(ctx->output_xla_type(0), ctx->InputShape(0));
58     if (ctx->InputShape(1).dim_sizes() == ctx->InputShape(0).dim_sizes()) {
59       ctx->SetOutput(
60           0, xla::ConvertElementType(ctx->Input(0), ctx->output_xla_type(0)));
61       return;
62     }
63     // The gradient should be done in two phases for large resizes.
64     auto input = ctx->Input(0);
65     if (input_shape.dimensions(1) / output_shape.dimensions(1) > 3 &&
66         input_shape.dimensions(2) / output_shape.dimensions(2) > 3) {
67       auto intermediate_shape = output_shape;
68       intermediate_shape.mutable_dimensions()[1] = input_shape.dimensions(1);
69       input = xla::CustomCall(ctx->builder(), target, {ctx->Input(0)},
70                               intermediate_shape, OpaqueField());
71     }
72     ctx->SetOutput(0, xla::CustomCall(ctx->builder(), target, {input},
73                                       output_shape, OpaqueField()));
74   }
75 
CompileForward(XlaOpKernelContext * ctx,const char * target)76   void CompileForward(XlaOpKernelContext* ctx, const char* target) {
77     auto output_shape = GetOutputShape(ctx);
78     if (ctx->InputShape(0).dim_size(1) == output_shape.dimensions(1) &&
79         ctx->InputShape(0).dim_size(2) == output_shape.dimensions(2)) {
80       ctx->SetOutput(
81           0, xla::ConvertElementType(ctx->Input(0), ctx->output_xla_type(0)));
82       return;
83     }
84     if (ctx->InputShape(0).dim_size(1) == 1 &&
85         ctx->InputShape(0).dim_size(2) == 1) {
86       ctx->SetOutput(0,
87                      ctx->Input(0) + xla::Zeros(ctx->builder(), output_shape));
88       return;
89     }
90     ctx->SetOutput(0, xla::CustomCall(ctx->builder(), target, {ctx->Input(0)},
91                                       output_shape, OpaqueField()));
92   }
93 
94  private:
95   bool align_corners_;
96   bool half_pixel_centers_;
97 };
98 
99 class TpuResizeNearestNeighborOp : public TpuCustomResizeOp {
100  public:
TpuResizeNearestNeighborOp(OpKernelConstruction * ctx)101   explicit TpuResizeNearestNeighborOp(OpKernelConstruction* ctx)
102       : TpuCustomResizeOp(ctx) {}
Compile(XlaOpKernelContext * ctx)103   void Compile(XlaOpKernelContext* ctx) override {
104     CompileForward(ctx, "ResizeNearest");
105   }
106 };
107 
108 class TpuResizeBilinearOp : public TpuCustomResizeOp {
109  public:
TpuResizeBilinearOp(OpKernelConstruction * ctx)110   explicit TpuResizeBilinearOp(OpKernelConstruction* ctx)
111       : TpuCustomResizeOp(ctx) {}
Compile(XlaOpKernelContext * ctx)112   void Compile(XlaOpKernelContext* ctx) override {
113     CompileForward(ctx, "ResizeBilinear");
114   }
115 };
116 
117 class TpuResizeNearestNeighborGradOp : public TpuCustomResizeOp {
118  public:
TpuResizeNearestNeighborGradOp(OpKernelConstruction * ctx)119   explicit TpuResizeNearestNeighborGradOp(OpKernelConstruction* ctx)
120       : TpuCustomResizeOp(ctx) {}
Compile(XlaOpKernelContext * ctx)121   void Compile(XlaOpKernelContext* ctx) override {
122     CompileGrad(ctx, "ResizeNearestGrad", GetOutputShape(ctx));
123   }
124 };
125 
126 class TpuResizeBilinearGradOp : public TpuCustomResizeOp {
127  public:
TpuResizeBilinearGradOp(OpKernelConstruction * ctx)128   explicit TpuResizeBilinearGradOp(OpKernelConstruction* ctx)
129       : TpuCustomResizeOp(ctx) {}
Compile(XlaOpKernelContext * ctx)130   void Compile(XlaOpKernelContext* ctx) override {
131     auto output_shape =
132         TensorShapeToXLAShape(ctx->output_xla_type(0), ctx->InputShape(1));
133     CompileGrad(ctx, "ResizeBilinearGrad", output_shape);
134   }
135 };
136 
137 REGISTER_XLA_OP(Name("ResizeNearestNeighbor")
138                     .CompileTimeConstantInput("size")
139                     .Device(DEVICE_TPU_XLA_JIT),
140                 TpuResizeNearestNeighborOp);
141 
142 REGISTER_XLA_OP(Name("ResizeNearestNeighborGrad")
143                     .CompileTimeConstantInput("size")
144                     .Device(DEVICE_TPU_XLA_JIT),
145                 TpuResizeNearestNeighborGradOp);
146 
147 REGISTER_XLA_OP(Name("ResizeBilinear")
148                     .CompileTimeConstantInput("size")
149                     .Device(DEVICE_TPU_XLA_JIT),
150                 TpuResizeBilinearOp);
151 
152 REGISTER_XLA_OP(Name("ResizeBilinearGrad").Device(DEVICE_TPU_XLA_JIT),
153                 TpuResizeBilinearGradOp);
154 
155 }  // namespace tensorflow
156