1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // XLA-specific Ops for 2D convolution. 17 18 #include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h" 19 #include "tensorflow/compiler/tf2xla/shape_util.h" 20 #include "tensorflow/compiler/tf2xla/type_util.h" 21 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 22 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 23 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 24 #include "tensorflow/compiler/xla/client/lib/constants.h" 25 #include "tensorflow/compiler/xla/client/lib/matrix.h" 26 #include "tensorflow/compiler/xla/client/xla_builder.h" 27 #include "tensorflow/compiler/xla/literal_util.h" 28 #include "tensorflow/core/framework/bounds_check.h" 29 #include "tensorflow/core/framework/node_def_util.h" 30 #include "tensorflow/core/framework/numeric_op.h" 31 #include "tensorflow/core/framework/op_kernel.h" 32 #include "tensorflow/core/framework/ops_util.h" 33 #include "tensorflow/core/framework/tensor.h" 34 #include "tensorflow/core/framework/tensor_shape.h" 35 #include "tensorflow/core/framework/tensor_slice.h" 36 #include "tensorflow/core/framework/types.pb.h" 37 #include "tensorflow/core/util/padding.h" 38 #include "tensorflow/core/util/tensor_format.h" 39 40 namespace tensorflow { 41 namespace { 42 43 class ConvOp : public XlaOpKernel { 44 public: ConvOp(OpKernelConstruction * ctx,int num_spatial_dims,bool depthwise)45 explicit ConvOp(OpKernelConstruction* ctx, int num_spatial_dims, 46 bool depthwise) 47 : XlaOpKernel(ctx) { 48 StatusOr<ConvOpAttrs> attrs = 49 ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx); 50 OP_REQUIRES_OK(ctx, attrs.status()); 51 attrs_ = attrs.ValueOrDie(); 52 } 53 Compile(XlaOpKernelContext * ctx)54 void Compile(XlaOpKernelContext* ctx) override { 55 StatusOr<xla::XlaOp> conv = MakeXlaForwardConvOp( 56 ctx->op_kernel().type_string(), ctx->Input(0), ctx->Input(1), attrs_); 57 OP_REQUIRES_OK(ctx, conv.status()); 58 ctx->SetOutput(0, conv.ValueOrDie()); 59 } 60 61 protected: 62 ConvOpAttrs attrs_; 63 64 private: 65 TF_DISALLOW_COPY_AND_ASSIGN(ConvOp); 66 }; 67 68 class Conv2DOp : public ConvOp { 69 public: Conv2DOp(OpKernelConstruction * ctx)70 explicit Conv2DOp(OpKernelConstruction* ctx) 71 : ConvOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {} 72 }; 73 REGISTER_XLA_OP(Name("Conv2D").TypeConstraint("T", GetXlaConvTypes()), 74 Conv2DOp); 75 76 class Conv3DOp : public ConvOp { 77 public: Conv3DOp(OpKernelConstruction * ctx)78 explicit Conv3DOp(OpKernelConstruction* ctx) 79 : ConvOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {} 80 }; 81 REGISTER_XLA_OP(Name("Conv3D").TypeConstraint("T", GetXlaConvTypes()), 82 Conv3DOp); 83 84 class DepthwiseConv2DOp : public ConvOp { 85 public: DepthwiseConv2DOp(OpKernelConstruction * ctx)86 explicit DepthwiseConv2DOp(OpKernelConstruction* ctx) 87 : ConvOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {} 88 }; 89 REGISTER_XLA_OP( 90 Name("DepthwiseConv2dNative").TypeConstraint("T", GetXlaConvTypes()), 91 DepthwiseConv2DOp); 92 93 // Backprop for input. 94 class ConvBackpropInputOp : public XlaOpKernel { 95 public: ConvBackpropInputOp(OpKernelConstruction * ctx,int num_spatial_dims,bool depthwise)96 explicit ConvBackpropInputOp(OpKernelConstruction* ctx, int num_spatial_dims, 97 bool depthwise) 98 : XlaOpKernel(ctx) { 99 StatusOr<ConvOpAttrs> attrs = 100 ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx); 101 OP_REQUIRES_OK(ctx, attrs.status()); 102 attrs_ = attrs.ValueOrDie(); 103 } 104 Compile(XlaOpKernelContext * ctx)105 void Compile(XlaOpKernelContext* ctx) override { 106 TensorShape input_tensor_shape; 107 OP_REQUIRES_OK( 108 ctx, ctx->ConstantInputAsShape(0, &input_tensor_shape, 109 xla::ValueInferenceMode::kUpperBound)); 110 xla::Shape input_shape = 111 TensorShapeToXLAShape(ctx->input_xla_type(1), input_tensor_shape); 112 OP_REQUIRES(ctx, input_shape.rank() == attrs_.num_spatial_dims + 2, 113 errors::InvalidArgument( 114 "The rank of the specified input shape must be " 115 "num_spatial_dims + 2. Expected ", 116 attrs_.num_spatial_dims + 2, " got ", input_shape.rank())); 117 xla::XlaOp input_sizes = ctx->Input(0); 118 StatusOr<xla::XlaOp> in_backprop = MakeXlaBackpropInputConvOp( 119 ctx->op_kernel().type_string(), input_shape, ctx->Input(1), 120 ctx->Input(2), attrs_, nullptr, &input_sizes); 121 OP_REQUIRES_OK(ctx, in_backprop.status()); 122 ctx->SetOutput(0, in_backprop.ValueOrDie()); 123 } 124 125 protected: 126 ConvOpAttrs attrs_; 127 128 private: 129 TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropInputOp); 130 }; 131 132 class Conv2DBackpropInputOp : public ConvBackpropInputOp { 133 public: Conv2DBackpropInputOp(OpKernelConstruction * ctx)134 explicit Conv2DBackpropInputOp(OpKernelConstruction* ctx) 135 : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {} 136 }; 137 REGISTER_XLA_OP(Name("Conv2DBackpropInput") 138 .CompileTimeConstantInput("input_sizes") 139 .TypeConstraint("T", GetXlaConvTypes()), 140 Conv2DBackpropInputOp); 141 142 class Conv3DBackpropInputOp : public ConvBackpropInputOp { 143 public: Conv3DBackpropInputOp(OpKernelConstruction * ctx)144 explicit Conv3DBackpropInputOp(OpKernelConstruction* ctx) 145 : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {} 146 }; 147 REGISTER_XLA_OP(Name("Conv3DBackpropInputV2") 148 .CompileTimeConstantInput("input_sizes") 149 .TypeConstraint("T", GetXlaConvTypes()), 150 Conv3DBackpropInputOp); 151 152 class DepthwiseConv2DBackpropInputOp : public ConvBackpropInputOp { 153 public: DepthwiseConv2DBackpropInputOp(OpKernelConstruction * ctx)154 explicit DepthwiseConv2DBackpropInputOp(OpKernelConstruction* ctx) 155 : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {} 156 }; 157 REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropInput") 158 .CompileTimeConstantInput("input_sizes") 159 .TypeConstraint("T", GetXlaConvTypes()), 160 DepthwiseConv2DBackpropInputOp); 161 162 class ConvBackpropFilterOp : public XlaOpKernel { 163 public: ConvBackpropFilterOp(OpKernelConstruction * ctx,int num_spatial_dims,bool depthwise)164 explicit ConvBackpropFilterOp(OpKernelConstruction* ctx, int num_spatial_dims, 165 bool depthwise) 166 : XlaOpKernel(ctx) { 167 StatusOr<ConvOpAttrs> attrs = 168 ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx); 169 OP_REQUIRES_OK(ctx, attrs.status()); 170 attrs_ = attrs.ValueOrDie(); 171 } 172 Compile(XlaOpKernelContext * ctx)173 void Compile(XlaOpKernelContext* ctx) override { 174 TensorShape filter_tensor_shape; 175 OP_REQUIRES_OK( 176 ctx, ctx->ConstantInputAsShape(1, &filter_tensor_shape, 177 xla::ValueInferenceMode::kUpperBound)); 178 xla::Shape filter_shape = 179 TensorShapeToXLAShape(ctx->input_xla_type(0), filter_tensor_shape); 180 181 StatusOr<xla::XlaOp> filter_backprop = MakeXlaBackpropFilterConvOp( 182 ctx->op_kernel().type_string(), ctx->Input(0), filter_shape, 183 ctx->Input(2), attrs_); 184 OP_REQUIRES_OK(ctx, filter_backprop.status()); 185 ctx->SetOutput(0, filter_backprop.ValueOrDie()); 186 } 187 188 protected: 189 ConvOpAttrs attrs_; 190 191 private: 192 TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropFilterOp); 193 }; 194 195 class Conv2DBackpropFilterOp : public ConvBackpropFilterOp { 196 public: Conv2DBackpropFilterOp(OpKernelConstruction * ctx)197 explicit Conv2DBackpropFilterOp(OpKernelConstruction* ctx) 198 : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) { 199 } 200 }; 201 REGISTER_XLA_OP(Name("Conv2DBackpropFilter") 202 .CompileTimeConstantInput("filter_sizes") 203 .TypeConstraint("T", GetXlaConvTypes()), 204 Conv2DBackpropFilterOp); 205 206 class Conv3DBackpropFilterOp : public ConvBackpropFilterOp { 207 public: Conv3DBackpropFilterOp(OpKernelConstruction * ctx)208 explicit Conv3DBackpropFilterOp(OpKernelConstruction* ctx) 209 : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) { 210 } 211 }; 212 REGISTER_XLA_OP(Name("Conv3DBackpropFilterV2") 213 .CompileTimeConstantInput("filter_sizes") 214 .TypeConstraint("T", GetXlaConvTypes()), 215 Conv3DBackpropFilterOp); 216 217 class DepthwiseConv2DBackpropFilterOp : public ConvBackpropFilterOp { 218 public: DepthwiseConv2DBackpropFilterOp(OpKernelConstruction * ctx)219 explicit DepthwiseConv2DBackpropFilterOp(OpKernelConstruction* ctx) 220 : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {} 221 }; 222 REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropFilter") 223 .CompileTimeConstantInput("filter_sizes") 224 .TypeConstraint("T", GetXlaConvTypes()), 225 DepthwiseConv2DBackpropFilterOp); 226 227 } // namespace 228 } // namespace tensorflow 229