1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_
16 #define TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_
17
18 #include <array>
19
20 #include "tensorflow/core/framework/shape_inference.h"
21 #include "tensorflow/core/util/padding.h"
22 #include "tensorflow/core/util/tensor_format.h"
23
24 namespace tensorflow {
25
26 // GetWindowedOutputSize(): Given an input tensor, kernel, stride and padding
27 // type, the function computes the output and padding dimensions.
28 //
29 // For example, ignoring batches or multiple features, a 1D convolution
30 // takes as input a 1D tensor of shape (H), and convolves it with a filter of
31 // shape (K).
32 //
33 // It also takes in a few additional parameters:
34 //
35 // Stride (S): the stride with which we apply the filters. This is the offset
36 // between locations where we apply the filters. A larger stride
37 // means that the output will be spatially smaller.
38 //
39 // Padding (P): the padding we apply to the input tensor along each
40 // dimension. This is usually used to make sure that the spatial dimensions
41 // do not shrink when we progress with convolutions. This function supports two
42 // types of padding.
43 // SAME: the pad value is computed so that the output will have size H/S.
44 // VALID: no padding is carried out.
45 // If you want to use EXPLICIT padding, GetWindowedOutputSizeVerbose must be
46 // called instead. Note the padded area is zero-filled.
47 //
48 // The output dimensions for convolution and many other operations, when given
49 // all the parameters above, are as follows:
50 // - When Padding = SAME: the output size is (H'), where
51 // H' = ceil(float(H) / float(S))
52 // where ceil is the ceiling function. The number of padded cells
53 // is computed as:
54 // Pc = ((H' - 1) * S + K - H) / 2
55 // When the stride is 1, the expression simplifies to
56 // H' = H, Pc = (K-1)/2.
57 // This is where SAME comes from - the output has the same size as the input
58 // has.
59 //
60 // - When Padding = VALID: the output size is computed as
61 // H' = ceil(float(H - K + 1) / float(S))
62 // and the number of padded cells is always zero.
63 // When the stride is 1, the expression simplifies to
64 // H' = H-K+1.
65 //
66 // For convolution, mathematically, the output value at location (r')
67 // is the inner product of two vectors: the chunk of input at
68 // ((r'*S-Pr) : (r'*S-Pr+K)),
69 // and the filter.
70 //
71 // For 2D and 3D convolutions, the spatial dimensions are orthogonal, so the
72 // size and padding of each spatial dimension can be computed by calling
73 // GetWindowedOutputSize separately for each dimension.
74 //
75 Status GetWindowedOutputSize(int64 input_size, int64 filter_size, int64 stride,
76 Padding padding_type, int64* output_size,
77 int64* padding_size);
78
79 // The V2 version computes the same outputs with arbitrary dilation_rate.
80 // The output dimensions are computed as follows:
81 // - When adding dilation_rate (D), we compute an effective filter size (K'):
82 // K' = (K - 1) * D + 1
83 // - When Padding = SAME: the output size is (H'), where
84 // H' = ceil(float(H) / float(S))
85 // where ceil is the ceiling function. The number of padded cells
86 // is computed as:
87 // Pc = ((H' - 1) * S + K' - H) / 2
88 // When the stride is 1, the expression simplifies to
89 // H' = H, Pc = (K'-1)/2.
90 // This is where SAME comes from - the output has the same size as the input
91 // has.
92 //
93 // - When Padding = VALID: the output size is computed as
94 // H' = ceil(float(H - K' + 1) / float(S))
95 // and the number of padded cells is always zero.
96 // When the stride is 1, the expression simplifies to
97 // H' = H-K'+1.
98 //
// If you want to use EXPLICIT padding, GetWindowedOutputSizeVerboseV2 must be
// called instead.
101 //
102 // TODO(b/67112639): Merge V2 versions and the original versions eventually.
103 Status GetWindowedOutputSizeV2(int64 input_size, int64 filter_size,
104 int64 dilation_rate, int64 stride,
105 Padding padding_type, int64* output_size,
106 int64* padding_size);
107
108 // Returns the same output dimensions as in GetWindowedOutputSize, but returns
109 // verbose padding dimensions (before/after), and EXPLICIT padding is supported.
110 // When padding_type is EXPLICIT, *padding_before and *padding_after must
111 // already point to initialized integers with the padding amounts. Otherwise,
112 // *padding_before and *padding_after are set by this function, and any
113 // excess padding (caused by an odd padding size value) is added to the
114 // 'padding_after' dimension.
115 Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size,
116 int64 stride, Padding padding_type,
117 int64* output_size, int64* padding_before,
118 int64* padding_after);
119
120 // The V2 version computes the same outputs with arbitrary dilation_rate. For
121 // detailed equations, refer to the comments for GetWindowedOutputSizeV2().
122 Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size,
123 int64 dilation_rate, int64 stride,
124 Padding padding_type, int64* output_size,
125 int64* padding_before,
126 int64* padding_after);
127
128 // Given an input tensor, kernel, stride and padding type, populates the 3D size
129 // of the output tensor and padding to be applied to the input tensor at the
130 // lower end of every dimension. Use for 3D convolutions, where the input data
131 // is padded with zeros, as well as for 3D avg/max pooling, where the input data
132 // is padded with invalid values that are not considered for pooling. EXPLICIT
133 // padding is not supported.
134 Status Get3dOutputSize(const std::array<int64, 3>& input,
135 const std::array<int64, 3>& window,
136 const std::array<int64, 3>& strides,
137 Padding padding_type, std::array<int64, 3>* output_ptr,
138 std::array<int64, 3>* padding_ptr);
139
140 // The V2 version computes the same outputs with arbitrary dilation_rate. For
141 // detailed equations, refer to the comments for GetWindowedOutputSizeV2().
142 Status Get3dOutputSizeV2(const std::array<int64, 3>& input,
143 const std::array<int64, 3>& window,
144 const std::array<int64, 3>& dilations,
145 const std::array<int64, 3>& strides,
146 Padding padding_type, std::array<int64, 3>* output_ptr,
147 std::array<int64, 3>* padding_ptr);
148
149 namespace shape_inference {
150
151 // Like GetWindowedOutputSize, but deals with DimensionHandles. Does not support
152 // EXPLICIT padding.
153 Status GetWindowedOutputSizeFromDims(InferenceContext* c,
154 DimensionHandle input_size,
155 DimensionOrConstant filter_size,
156 int64 stride, Padding padding_type,
157 DimensionHandle* output_size);
158
159 // The V2 version computes the same outputs with arbitrary dilation_rate, and
160 // supports EXPLICIT padding. For detailed equations, refer to the comments
161 // for GetWindowedOutputSizeV2(). The 'padding_before' and 'padding_after'
162 // parameters are only used if padding_type == EXPLICIT.
163 Status GetWindowedOutputSizeFromDimsV2(
164 InferenceContext* c, DimensionHandle input_size,
165 DimensionOrConstant filter_size, int64 dilation_rate, int64 stride,
166 Padding padding_type, int64 padding_before, int64 padding_after,
167 DimensionHandle* output_size);
168
169 // Transfers shape of input(0) to output(0).
170 Status UnchangedShape(shape_inference::InferenceContext* c);
171
172 // Transfers shape of input(0) to output(0), after asserting its rank is <rank>.
UnchangedShapeWithRank(shape_inference::InferenceContext * c,int32 rank)173 inline Status UnchangedShapeWithRank(shape_inference::InferenceContext* c,
174 int32 rank) {
175 ShapeHandle out;
176 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &out));
177 c->set_output(0, out);
178 return Status::OK();
179 }
180
181 // Transfers shape of input(0) to output(0), after asserting its rank >= <rank>.
UnchangedShapeWithRankAtLeast(shape_inference::InferenceContext * c,int32 rank)182 inline Status UnchangedShapeWithRankAtLeast(
183 shape_inference::InferenceContext* c, int32 rank) {
184 ShapeHandle out;
185 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), rank, &out));
186 c->set_output(0, out);
187 return Status::OK();
188 }
189
190 // Transfers shape of input(0) to output(0), after asserting its rank <= <rank>.
UnchangedShapeWithRankAtMost(shape_inference::InferenceContext * c,int32 rank)191 inline Status UnchangedShapeWithRankAtMost(shape_inference::InferenceContext* c,
192 int32 rank) {
193 ShapeHandle out;
194 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), rank, &out));
195 c->set_output(0, out);
196 return Status::OK();
197 }
198
199 // Shape function for use with ops no outputs.
NoOutputs(shape_inference::InferenceContext * c)200 inline Status NoOutputs(shape_inference::InferenceContext* c) {
201 return Status::OK();
202 }
203
204 // Shape function for ops that output a single scalar value.
ScalarShape(shape_inference::InferenceContext * c)205 inline Status ScalarShape(shape_inference::InferenceContext* c) {
206 c->set_output(0, c->Scalar());
207 return Status::OK();
208 }
209
210 // Shape function for binary ops where both inputs and the output match.
MergeBothInputsShapeFn(InferenceContext * c)211 inline Status MergeBothInputsShapeFn(InferenceContext* c) {
212 ShapeHandle out;
213 TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &out));
214 c->set_output(0, out);
215 return Status::OK();
216 }
217
218 // Returns a new shape with the specified dims arranged in the specified
219 // format. The returned value is owned by this context.
220 // Note: if format = "FORMAT_NCHW_VECT_C" then C represents the outer_depth.
221 Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
222 const std::vector<DimensionOrConstant>& spatial,
223 DimensionOrConstant C, ShapeHandle* out,
224 shape_inference::InferenceContext* context);
225
226 // Shape function for MatMul-like operations.
227 Status MatMulShape(shape_inference::InferenceContext* c);
228
229 // Shape function for BiasAdd-like operations.
230 Status BiasAddShape(shape_inference::InferenceContext* c);
231
232 // Shape function for BiasAddGrad-like operations.
233 Status BiasAddGradShape(shape_inference::InferenceContext* c);
234
235 // Shape function for Conv2D-like operations that support explicit padding.
236 Status Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext* c);
237
238 // Shape function for Conv2D-like operations that do not support explicit
239 // padding.
240 Status Conv2DShape(shape_inference::InferenceContext* c);
241
242 // Shape function for Conv3D-like operations.
243 Status Conv3DShape(shape_inference::InferenceContext* c);
244
245 // Shape function for DepthwiseConv2D-like operations.
246 Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c);
247
248 // Shape function for AvgPool-like operations.
249 Status AvgPoolShape(shape_inference::InferenceContext* c);
250
251 // Shape function for FusedBatchNorm and FusedBatchNormV2 operations.
252 Status FusedBatchNormShape(shape_inference::InferenceContext* c);
253
254 // Shape function for FusedBatchNormGrad and FusedBatchNormGradV2 operations.
255 Status FusedBatchNormGradShape(shape_inference::InferenceContext* c);
256
257 // Shape function for MaxPool-like operations.
258 Status MaxPoolShape(shape_inference::InferenceContext* c);
259
260 // Shape function for MaxPoolV2-like operations.
261 Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs);
262
263 // Shape function for 3D Pooling operations.
264 Status Pool3DShape(shape_inference::InferenceContext* c);
265
266 // Shape function for use with ops whose output shapes are unknown.
267 Status UnknownShape(shape_inference::InferenceContext* c);
268
269 // Shape function for reduction operations.
270 Status ReductionShape(shape_inference::InferenceContext* c);
271
272 // Shape function for concat operations.
// <num_inputs_to_concat> is the number of inputs to concatenate; they are
// taken from inputs [1, num_inputs_to_concat] of the op. Input 0 is the
// concat_dim input.
276 Status ConcatShape(shape_inference::InferenceContext* c,
277 int num_inputs_to_concat);
278
279 // Shape function for concat operations.
280 Status ConcatV2Shape(shape_inference::InferenceContext* c);
281
282 Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat);
283
284 // Shape function for binary operators that broadcast their inputs
285 // and with output to output_index.
286 // Note: out cannot be NULL.
287 Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
288 ShapeHandle shape_x,
289 ShapeHandle shape_y,
290 ShapeHandle* out);
291
292 // Shape function for binary operators that broadcast their inputs
293 // and with output to output_index.
BroadcastBinaryOpOutputShapeFn(InferenceContext * c,int output_index)294 inline Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c,
295 int output_index) {
296 ShapeHandle out;
297 TF_RETURN_IF_ERROR(
298 BroadcastBinaryOpOutputShapeFnHelper(c, c->input(0), c->input(1), &out));
299 c->set_output(output_index, out);
300 return Status::OK();
301 }
302
303 // Shape function for binary operators that broadcast their inputs.
304 // Tested by ops/math_ops_test.cc.
BroadcastBinaryOpShapeFn(InferenceContext * c)305 inline Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
306 return BroadcastBinaryOpOutputShapeFn(c, 0);
307 }
308
309 // Shape function for random operations.
310 Status RandomShape(shape_inference::InferenceContext* c);
311
// Shape function for Slice operations.
313 Status SliceShape(shape_inference::InferenceContext* c);
314
315 // Validates the 3 component tensors of a sparse tensor have the proper
316 // shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
317 Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
318 ShapeHandle values_shape, ShapeHandle shape_shape);
319
320 // Shape function for ScatterNd update/add/sub/... operations.
321 Status ScatterNdUpdateShape(InferenceContext* c);
322
323 // Shape function for ops with an explicit "shape" attribute.
324 Status ExplicitShape(InferenceContext* c);
325
326 // Shape function for multiple-output ops with an explicit "shapes" attribute.
327 Status ExplicitShapes(InferenceContext* c);
328
329 // Shape function for SparseReduceMax and SparseReduceSum.
330 Status SparseReduceShapeFn(InferenceContext* c);
331
332 } // namespace shape_inference
333
334 } // namespace tensorflow
335
336 #endif // TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_
337