1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "absl/strings/match.h" 17 #include "absl/strings/str_cat.h" 18 #include "absl/strings/string_view.h" 19 #include "tensorflow/compiler/tf2xla/shape_util.h" 20 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 22 #include "tensorflow/compiler/xla/client/lib/constants.h" 23 #include "tensorflow/compiler/xla/client/xla_builder.h" 24 #include "tensorflow/core/framework/kernel_def_builder.h" 25 #include "tensorflow/core/framework/op_kernel.h" 26 #include "tensorflow/core/framework/register_types.h" 27 #include "tensorflow/core/tpu/tpu_defs.h" 28 29 namespace tensorflow { 30 31 class TpuCustomResizeOp : public XlaOpKernel { 32 public: TpuCustomResizeOp(OpKernelConstruction * ctx)33 explicit TpuCustomResizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 34 OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); 35 OP_REQUIRES_OK(ctx, 36 ctx->GetAttr("half_pixel_centers", &half_pixel_centers_)); 37 } 38 GetOutputShape(XlaOpKernelContext * ctx) const39 xla::Shape GetOutputShape(XlaOpKernelContext* ctx) const { 40 std::vector<int64_t> out_size; 41 auto status = ctx->ConstantInputAsIntVector(1, &out_size); 42 CHECK_EQ(out_size.size(), 2) << status.ToString(); 43 xla::Shape output_shape = 44 
TensorShapeToXLAShape(ctx->output_xla_type(0), ctx->InputShape(0)); 45 output_shape.mutable_dimensions()[1] = out_size[0]; 46 output_shape.mutable_dimensions()[2] = out_size[1]; 47 return output_shape; 48 } 49 OpaqueField() const50 string OpaqueField() const { 51 return absl::StrCat("\"", align_corners_, half_pixel_centers_, "\""); 52 } 53 CompileGrad(XlaOpKernelContext * ctx,const char * target,const xla::Shape & output_shape)54 void CompileGrad(XlaOpKernelContext* ctx, const char* target, 55 const xla::Shape& output_shape) { 56 auto input_shape = 57 TensorShapeToXLAShape(ctx->output_xla_type(0), ctx->InputShape(0)); 58 if (ctx->InputShape(1).dim_sizes() == ctx->InputShape(0).dim_sizes()) { 59 ctx->SetOutput( 60 0, xla::ConvertElementType(ctx->Input(0), ctx->output_xla_type(0))); 61 return; 62 } 63 // The gradient should be done in two phases for large resizes. 64 auto input = ctx->Input(0); 65 if (input_shape.dimensions(1) / output_shape.dimensions(1) > 3 && 66 input_shape.dimensions(2) / output_shape.dimensions(2) > 3) { 67 auto intermediate_shape = output_shape; 68 intermediate_shape.mutable_dimensions()[1] = input_shape.dimensions(1); 69 input = xla::CustomCall(ctx->builder(), target, {ctx->Input(0)}, 70 intermediate_shape, OpaqueField()); 71 } 72 ctx->SetOutput(0, xla::CustomCall(ctx->builder(), target, {input}, 73 output_shape, OpaqueField())); 74 } 75 CompileForward(XlaOpKernelContext * ctx,const char * target)76 void CompileForward(XlaOpKernelContext* ctx, const char* target) { 77 auto output_shape = GetOutputShape(ctx); 78 if (ctx->InputShape(0).dim_size(1) == output_shape.dimensions(1) && 79 ctx->InputShape(0).dim_size(2) == output_shape.dimensions(2)) { 80 ctx->SetOutput( 81 0, xla::ConvertElementType(ctx->Input(0), ctx->output_xla_type(0))); 82 return; 83 } 84 if (ctx->InputShape(0).dim_size(1) == 1 && 85 ctx->InputShape(0).dim_size(2) == 1) { 86 ctx->SetOutput(0, 87 ctx->Input(0) + xla::Zeros(ctx->builder(), output_shape)); 88 return; 89 } 90 
ctx->SetOutput(0, xla::CustomCall(ctx->builder(), target, {ctx->Input(0)}, 91 output_shape, OpaqueField())); 92 } 93 94 private: 95 bool align_corners_; 96 bool half_pixel_centers_; 97 }; 98 99 class TpuResizeNearestNeighborOp : public TpuCustomResizeOp { 100 public: TpuResizeNearestNeighborOp(OpKernelConstruction * ctx)101 explicit TpuResizeNearestNeighborOp(OpKernelConstruction* ctx) 102 : TpuCustomResizeOp(ctx) {} Compile(XlaOpKernelContext * ctx)103 void Compile(XlaOpKernelContext* ctx) override { 104 CompileForward(ctx, "ResizeNearest"); 105 } 106 }; 107 108 class TpuResizeBilinearOp : public TpuCustomResizeOp { 109 public: TpuResizeBilinearOp(OpKernelConstruction * ctx)110 explicit TpuResizeBilinearOp(OpKernelConstruction* ctx) 111 : TpuCustomResizeOp(ctx) {} Compile(XlaOpKernelContext * ctx)112 void Compile(XlaOpKernelContext* ctx) override { 113 CompileForward(ctx, "ResizeBilinear"); 114 } 115 }; 116 117 class TpuResizeNearestNeighborGradOp : public TpuCustomResizeOp { 118 public: TpuResizeNearestNeighborGradOp(OpKernelConstruction * ctx)119 explicit TpuResizeNearestNeighborGradOp(OpKernelConstruction* ctx) 120 : TpuCustomResizeOp(ctx) {} Compile(XlaOpKernelContext * ctx)121 void Compile(XlaOpKernelContext* ctx) override { 122 CompileGrad(ctx, "ResizeNearestGrad", GetOutputShape(ctx)); 123 } 124 }; 125 126 class TpuResizeBilinearGradOp : public TpuCustomResizeOp { 127 public: TpuResizeBilinearGradOp(OpKernelConstruction * ctx)128 explicit TpuResizeBilinearGradOp(OpKernelConstruction* ctx) 129 : TpuCustomResizeOp(ctx) {} Compile(XlaOpKernelContext * ctx)130 void Compile(XlaOpKernelContext* ctx) override { 131 auto output_shape = 132 TensorShapeToXLAShape(ctx->output_xla_type(0), ctx->InputShape(1)); 133 CompileGrad(ctx, "ResizeBilinearGrad", output_shape); 134 } 135 }; 136 137 REGISTER_XLA_OP(Name("ResizeNearestNeighbor") 138 .CompileTimeConstantInput("size") 139 .Device(DEVICE_TPU_XLA_JIT), 140 TpuResizeNearestNeighborOp); 141 142 
// TPU registrations for the gradient of nearest-neighbor resize and for
// bilinear resize (forward and gradient). Kernels that read the target size
// from the "size" input mark it as a compile-time constant. ResizeBilinearGrad
// derives its output shape from an input tensor's shape instead, so it has no
// constant input.
REGISTER_XLA_OP(
    Name("ResizeNearestNeighborGrad")
        .Device(DEVICE_TPU_XLA_JIT)
        .CompileTimeConstantInput("size"),
    TpuResizeNearestNeighborGradOp);

REGISTER_XLA_OP(
    Name("ResizeBilinear")
        .Device(DEVICE_TPU_XLA_JIT)
        .CompileTimeConstantInput("size"),
    TpuResizeBilinearOp);

REGISTER_XLA_OP(
    Name("ResizeBilinearGrad")
        .Device(DEVICE_TPU_XLA_JIT),
    TpuResizeBilinearGradOp);

}  // namespace tensorflow