• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // XLA-specific Ops for 2D and 3D convolution (including depthwise 2D
// convolution) and their input/filter gradients.
17 
18 #include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
19 #include "tensorflow/compiler/tf2xla/shape_util.h"
20 #include "tensorflow/compiler/tf2xla/type_util.h"
21 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
22 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
23 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
24 #include "tensorflow/compiler/xla/client/lib/constants.h"
25 #include "tensorflow/compiler/xla/client/lib/matrix.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/compiler/xla/literal_util.h"
28 #include "tensorflow/core/framework/bounds_check.h"
29 #include "tensorflow/core/framework/node_def_util.h"
30 #include "tensorflow/core/framework/numeric_op.h"
31 #include "tensorflow/core/framework/op_kernel.h"
32 #include "tensorflow/core/framework/ops_util.h"
33 #include "tensorflow/core/framework/tensor.h"
34 #include "tensorflow/core/framework/tensor_shape.h"
35 #include "tensorflow/core/framework/tensor_slice.h"
36 #include "tensorflow/core/framework/types.pb.h"
37 #include "tensorflow/core/util/padding.h"
38 #include "tensorflow/core/util/tensor_format.h"
39 
40 namespace tensorflow {
41 namespace {
42 
43 class ConvOp : public XlaOpKernel {
44  public:
ConvOp(OpKernelConstruction * ctx,int num_spatial_dims,bool depthwise)45   explicit ConvOp(OpKernelConstruction* ctx, int num_spatial_dims,
46                   bool depthwise)
47       : XlaOpKernel(ctx) {
48     StatusOr<ConvOpAttrs> attrs =
49         ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx);
50     OP_REQUIRES_OK(ctx, attrs.status());
51     attrs_ = attrs.ValueOrDie();
52   }
53 
Compile(XlaOpKernelContext * ctx)54   void Compile(XlaOpKernelContext* ctx) override {
55     StatusOr<xla::XlaOp> conv = MakeXlaForwardConvOp(
56         ctx->op_kernel().type_string(), ctx->Input(0), ctx->Input(1), attrs_);
57     OP_REQUIRES_OK(ctx, conv.status());
58     ctx->SetOutput(0, conv.ValueOrDie());
59   }
60 
61  protected:
62   ConvOpAttrs attrs_;
63 
64  private:
65   TF_DISALLOW_COPY_AND_ASSIGN(ConvOp);
66 };
67 
68 class Conv2DOp : public ConvOp {
69  public:
Conv2DOp(OpKernelConstruction * ctx)70   explicit Conv2DOp(OpKernelConstruction* ctx)
71       : ConvOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {}
72 };
73 REGISTER_XLA_OP(Name("Conv2D").TypeConstraint("T", GetXlaConvTypes()),
74                 Conv2DOp);
75 
76 class Conv3DOp : public ConvOp {
77  public:
Conv3DOp(OpKernelConstruction * ctx)78   explicit Conv3DOp(OpKernelConstruction* ctx)
79       : ConvOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {}
80 };
81 REGISTER_XLA_OP(Name("Conv3D").TypeConstraint("T", GetXlaConvTypes()),
82                 Conv3DOp);
83 
84 class DepthwiseConv2DOp : public ConvOp {
85  public:
DepthwiseConv2DOp(OpKernelConstruction * ctx)86   explicit DepthwiseConv2DOp(OpKernelConstruction* ctx)
87       : ConvOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {}
88 };
89 REGISTER_XLA_OP(
90     Name("DepthwiseConv2dNative").TypeConstraint("T", GetXlaConvTypes()),
91     DepthwiseConv2DOp);
92 
93 // Backprop for input.
94 class ConvBackpropInputOp : public XlaOpKernel {
95  public:
ConvBackpropInputOp(OpKernelConstruction * ctx,int num_spatial_dims,bool depthwise)96   explicit ConvBackpropInputOp(OpKernelConstruction* ctx, int num_spatial_dims,
97                                bool depthwise)
98       : XlaOpKernel(ctx) {
99     StatusOr<ConvOpAttrs> attrs =
100         ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx);
101     OP_REQUIRES_OK(ctx, attrs.status());
102     attrs_ = attrs.ValueOrDie();
103   }
104 
Compile(XlaOpKernelContext * ctx)105   void Compile(XlaOpKernelContext* ctx) override {
106     TensorShape input_tensor_shape;
107     OP_REQUIRES_OK(
108         ctx, ctx->ConstantInputAsShape(0, &input_tensor_shape,
109                                        xla::ValueInferenceMode::kUpperBound));
110     xla::Shape input_shape =
111         TensorShapeToXLAShape(ctx->input_xla_type(1), input_tensor_shape);
112     OP_REQUIRES(ctx, input_shape.rank() == attrs_.num_spatial_dims + 2,
113                 errors::InvalidArgument(
114                     "The rank of the specified input shape must be "
115                     "num_spatial_dims + 2. Expected ",
116                     attrs_.num_spatial_dims + 2, " got ", input_shape.rank()));
117     xla::XlaOp input_sizes = ctx->Input(0);
118     StatusOr<xla::XlaOp> in_backprop = MakeXlaBackpropInputConvOp(
119         ctx->op_kernel().type_string(), input_shape, ctx->Input(1),
120         ctx->Input(2), attrs_, nullptr, &input_sizes);
121     OP_REQUIRES_OK(ctx, in_backprop.status());
122     ctx->SetOutput(0, in_backprop.ValueOrDie());
123   }
124 
125  protected:
126   ConvOpAttrs attrs_;
127 
128  private:
129   TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropInputOp);
130 };
131 
132 class Conv2DBackpropInputOp : public ConvBackpropInputOp {
133  public:
Conv2DBackpropInputOp(OpKernelConstruction * ctx)134   explicit Conv2DBackpropInputOp(OpKernelConstruction* ctx)
135       : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {}
136 };
137 REGISTER_XLA_OP(Name("Conv2DBackpropInput")
138                     .CompileTimeConstantInput("input_sizes")
139                     .TypeConstraint("T", GetXlaConvTypes()),
140                 Conv2DBackpropInputOp);
141 
142 class Conv3DBackpropInputOp : public ConvBackpropInputOp {
143  public:
Conv3DBackpropInputOp(OpKernelConstruction * ctx)144   explicit Conv3DBackpropInputOp(OpKernelConstruction* ctx)
145       : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {}
146 };
147 REGISTER_XLA_OP(Name("Conv3DBackpropInputV2")
148                     .CompileTimeConstantInput("input_sizes")
149                     .TypeConstraint("T", GetXlaConvTypes()),
150                 Conv3DBackpropInputOp);
151 
152 class DepthwiseConv2DBackpropInputOp : public ConvBackpropInputOp {
153  public:
DepthwiseConv2DBackpropInputOp(OpKernelConstruction * ctx)154   explicit DepthwiseConv2DBackpropInputOp(OpKernelConstruction* ctx)
155       : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {}
156 };
157 REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropInput")
158                     .CompileTimeConstantInput("input_sizes")
159                     .TypeConstraint("T", GetXlaConvTypes()),
160                 DepthwiseConv2DBackpropInputOp);
161 
162 class ConvBackpropFilterOp : public XlaOpKernel {
163  public:
ConvBackpropFilterOp(OpKernelConstruction * ctx,int num_spatial_dims,bool depthwise)164   explicit ConvBackpropFilterOp(OpKernelConstruction* ctx, int num_spatial_dims,
165                                 bool depthwise)
166       : XlaOpKernel(ctx) {
167     StatusOr<ConvOpAttrs> attrs =
168         ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx);
169     OP_REQUIRES_OK(ctx, attrs.status());
170     attrs_ = attrs.ValueOrDie();
171   }
172 
Compile(XlaOpKernelContext * ctx)173   void Compile(XlaOpKernelContext* ctx) override {
174     TensorShape filter_tensor_shape;
175     OP_REQUIRES_OK(
176         ctx, ctx->ConstantInputAsShape(1, &filter_tensor_shape,
177                                        xla::ValueInferenceMode::kUpperBound));
178     xla::Shape filter_shape =
179         TensorShapeToXLAShape(ctx->input_xla_type(0), filter_tensor_shape);
180 
181     StatusOr<xla::XlaOp> filter_backprop = MakeXlaBackpropFilterConvOp(
182         ctx->op_kernel().type_string(), ctx->Input(0), filter_shape,
183         ctx->Input(2), attrs_);
184     OP_REQUIRES_OK(ctx, filter_backprop.status());
185     ctx->SetOutput(0, filter_backprop.ValueOrDie());
186   }
187 
188  protected:
189   ConvOpAttrs attrs_;
190 
191  private:
192   TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropFilterOp);
193 };
194 
195 class Conv2DBackpropFilterOp : public ConvBackpropFilterOp {
196  public:
Conv2DBackpropFilterOp(OpKernelConstruction * ctx)197   explicit Conv2DBackpropFilterOp(OpKernelConstruction* ctx)
198       : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {
199   }
200 };
201 REGISTER_XLA_OP(Name("Conv2DBackpropFilter")
202                     .CompileTimeConstantInput("filter_sizes")
203                     .TypeConstraint("T", GetXlaConvTypes()),
204                 Conv2DBackpropFilterOp);
205 
206 class Conv3DBackpropFilterOp : public ConvBackpropFilterOp {
207  public:
Conv3DBackpropFilterOp(OpKernelConstruction * ctx)208   explicit Conv3DBackpropFilterOp(OpKernelConstruction* ctx)
209       : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {
210   }
211 };
212 REGISTER_XLA_OP(Name("Conv3DBackpropFilterV2")
213                     .CompileTimeConstantInput("filter_sizes")
214                     .TypeConstraint("T", GetXlaConvTypes()),
215                 Conv3DBackpropFilterOp);
216 
217 class DepthwiseConv2DBackpropFilterOp : public ConvBackpropFilterOp {
218  public:
DepthwiseConv2DBackpropFilterOp(OpKernelConstruction * ctx)219   explicit DepthwiseConv2DBackpropFilterOp(OpKernelConstruction* ctx)
220       : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {}
221 };
222 REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropFilter")
223                     .CompileTimeConstantInput("filter_sizes")
224                     .TypeConstraint("T", GetXlaConvTypes()),
225                 DepthwiseConv2DBackpropFilterOp);
226 
227 }  // namespace
228 }  // namespace tensorflow
229