/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #ifndef TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_
16 #define TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_
17 
#include <array>
#include <vector>

#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
23 
24 namespace tensorflow {
25 
26 // GetWindowedOutputSize(): Given an input tensor, kernel, stride and padding
27 // type, the function computes the output and padding dimensions.
28 //
29 // For example, ignoring batches or multiple features, a 1D convolution
30 // takes as input a 1D tensor of shape (H), and convolves it with a filter of
31 // shape (K).
32 //
33 // It also takes in a few additional parameters:
34 //
35 // Stride (S): the stride with which we apply the filters. This is the offset
36 // between locations where we apply the filters. A larger stride
37 // means that the output will be spatially smaller.
38 //
39 // Padding (P): the padding we apply to the input tensor along each
40 // dimension. This is usually used to make sure that the spatial dimensions
41 // do not shrink when we progress with convolutions. This function supports two
42 // types of padding.
43 //   SAME: the pad value is computed so that the output will have size H/S.
44 //   VALID: no padding is carried out.
45 // If you want to use EXPLICIT padding, GetWindowedOutputSizeVerbose must be
46 // called instead. Note the padded area is zero-filled.
47 //
48 // The output dimensions for convolution and many other operations, when given
49 // all the parameters above, are as follows:
50 // - When Padding = SAME: the output size is (H'), where
51 //     H' = ceil(float(H) / float(S))
52 //   where ceil is the ceiling function. The number of padded cells
53 //   is computed as:
54 //     Pc = ((H' - 1) * S + K - H) / 2
55 //   When the stride is 1, the expression simplifies to
56 //     H' = H, Pc = (K-1)/2.
57 //   This is where SAME comes from - the output has the same size as the input
58 //   has.
59 //
60 // - When Padding = VALID: the output size is computed as
61 //     H' = ceil(float(H - K + 1) / float(S))
62 //   and the number of padded cells is always zero.
63 //   When the stride is 1, the expression simplifies to
64 //     H' = H-K+1.
65 //
66 // For convolution, mathematically, the output value at location (r')
67 // is the inner product of two vectors: the chunk of input at
68 //    ((r'*S-Pr) : (r'*S-Pr+K)),
69 // and the filter.
70 //
71 // For 2D and 3D convolutions, the spatial dimensions are orthogonal, so the
72 // size and padding of each spatial dimension can be computed by calling
73 // GetWindowedOutputSize separately for each dimension.
74 //
75 Status GetWindowedOutputSize(int64 input_size, int64 filter_size, int64 stride,
76                              Padding padding_type, int64* output_size,
77                              int64* padding_size);
78 
79 // The V2 version computes the same outputs with arbitrary dilation_rate.
80 // The output dimensions are computed as follows:
81 // - When adding dilation_rate (D), we compute an effective filter size (K'):
82 //     K' = (K - 1) * D + 1
83 // - When Padding = SAME: the output size is (H'), where
84 //     H' = ceil(float(H) / float(S))
85 //   where ceil is the ceiling function. The number of padded cells
86 //   is computed as:
87 //     Pc = ((H' - 1) * S + K' - H) / 2
88 //   When the stride is 1, the expression simplifies to
89 //     H' = H, Pc = (K'-1)/2.
90 //   This is where SAME comes from - the output has the same size as the input
91 //   has.
92 //
93 // - When Padding = VALID: the output size is computed as
94 //     H' = ceil(float(H - K' + 1) / float(S))
95 //   and the number of padded cells is always zero.
96 //   When the stride is 1, the expression simplifies to
97 //     H' = H-K'+1.
98 //
99 // If you want to use EXPLICIT padding, GetWindowedOutputSizeVerboseV2 must be
100 // called instead
101 //
102 // TODO(b/67112639): Merge V2 versions and the original versions eventually.
103 Status GetWindowedOutputSizeV2(int64 input_size, int64 filter_size,
104                                int64 dilation_rate, int64 stride,
105                                Padding padding_type, int64* output_size,
106                                int64* padding_size);
107 
108 // Returns the same output dimensions as in GetWindowedOutputSize, but returns
109 // verbose padding dimensions (before/after), and EXPLICIT padding is supported.
110 // When padding_type is EXPLICIT, *padding_before and *padding_after must
111 // already point to initialized integers with the padding amounts. Otherwise,
112 // *padding_before and *padding_after are set by this function, and any
113 // excess padding (caused by an odd padding size value) is added to the
114 // 'padding_after' dimension.
115 Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size,
116                                     int64 stride, Padding padding_type,
117                                     int64* output_size, int64* padding_before,
118                                     int64* padding_after);
119 
120 // The V2 version computes the same outputs with arbitrary dilation_rate. For
121 // detailed equations, refer to the comments for GetWindowedOutputSizeV2().
122 Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size,
123                                       int64 dilation_rate, int64 stride,
124                                       Padding padding_type, int64* output_size,
125                                       int64* padding_before,
126                                       int64* padding_after);
127 
128 // Given an input tensor, kernel, stride and padding type, populates the 3D size
129 // of the output tensor and padding to be applied to the input tensor at the
130 // lower end of every dimension. Use for 3D convolutions, where the input data
131 // is padded with zeros, as well as for 3D avg/max pooling, where the input data
132 // is padded with invalid values that are not considered for pooling. EXPLICIT
133 // padding is not supported.
134 Status Get3dOutputSize(const std::array<int64, 3>& input,
135                        const std::array<int64, 3>& window,
136                        const std::array<int64, 3>& strides,
137                        Padding padding_type, std::array<int64, 3>* output_ptr,
138                        std::array<int64, 3>* padding_ptr);
139 
140 // The V2 version computes the same outputs with arbitrary dilation_rate. For
141 // detailed equations, refer to the comments for GetWindowedOutputSizeV2().
142 Status Get3dOutputSizeV2(const std::array<int64, 3>& input,
143                          const std::array<int64, 3>& window,
144                          const std::array<int64, 3>& dilations,
145                          const std::array<int64, 3>& strides,
146                          Padding padding_type, std::array<int64, 3>* output_ptr,
147                          std::array<int64, 3>* padding_ptr);
148 
149 namespace shape_inference {
150 
151 // Like GetWindowedOutputSize, but deals with DimensionHandles. Does not support
152 // EXPLICIT padding.
153 Status GetWindowedOutputSizeFromDims(InferenceContext* c,
154                                      DimensionHandle input_size,
155                                      DimensionOrConstant filter_size,
156                                      int64 stride, Padding padding_type,
157                                      DimensionHandle* output_size);
158 
159 // The V2 version computes the same outputs with arbitrary dilation_rate, and
160 // supports EXPLICIT padding. For detailed equations, refer to the comments
161 // for GetWindowedOutputSizeV2(). The 'padding_before' and 'padding_after'
162 // parameters are only used if padding_type == EXPLICIT.
163 Status GetWindowedOutputSizeFromDimsV2(
164     InferenceContext* c, DimensionHandle input_size,
165     DimensionOrConstant filter_size, int64 dilation_rate, int64 stride,
166     Padding padding_type, int64 padding_before, int64 padding_after,
167     DimensionHandle* output_size);
168 
169 // Transfers shape of input(0) to output(0).
170 Status UnchangedShape(shape_inference::InferenceContext* c);
171 
172 // Transfers shape of input(0) to output(0), after asserting its rank is <rank>.
UnchangedShapeWithRank(shape_inference::InferenceContext * c,int32 rank)173 inline Status UnchangedShapeWithRank(shape_inference::InferenceContext* c,
174                                      int32 rank) {
175   ShapeHandle out;
176   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &out));
177   c->set_output(0, out);
178   return Status::OK();
179 }
180 
181 // Transfers shape of input(0) to output(0), after asserting its rank >= <rank>.
UnchangedShapeWithRankAtLeast(shape_inference::InferenceContext * c,int32 rank)182 inline Status UnchangedShapeWithRankAtLeast(
183     shape_inference::InferenceContext* c, int32 rank) {
184   ShapeHandle out;
185   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), rank, &out));
186   c->set_output(0, out);
187   return Status::OK();
188 }
189 
190 // Transfers shape of input(0) to output(0), after asserting its rank <= <rank>.
UnchangedShapeWithRankAtMost(shape_inference::InferenceContext * c,int32 rank)191 inline Status UnchangedShapeWithRankAtMost(shape_inference::InferenceContext* c,
192                                            int32 rank) {
193   ShapeHandle out;
194   TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), rank, &out));
195   c->set_output(0, out);
196   return Status::OK();
197 }
198 
199 // Shape function for use with ops no outputs.
NoOutputs(shape_inference::InferenceContext * c)200 inline Status NoOutputs(shape_inference::InferenceContext* c) {
201   return Status::OK();
202 }
203 
204 // Shape function for ops that output a single scalar value.
ScalarShape(shape_inference::InferenceContext * c)205 inline Status ScalarShape(shape_inference::InferenceContext* c) {
206   c->set_output(0, c->Scalar());
207   return Status::OK();
208 }
209 
210 // Shape function for binary ops where both inputs and the output match.
MergeBothInputsShapeFn(InferenceContext * c)211 inline Status MergeBothInputsShapeFn(InferenceContext* c) {
212   ShapeHandle out;
213   TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &out));
214   c->set_output(0, out);
215   return Status::OK();
216 }
217 
218 // Returns a new shape with the specified dims arranged in the specified
219 // format. The returned value is owned by this context.
220 // Note: if format = "FORMAT_NCHW_VECT_C" then C represents the outer_depth.
221 Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
222                            const std::vector<DimensionOrConstant>& spatial,
223                            DimensionOrConstant C, ShapeHandle* out,
224                            shape_inference::InferenceContext* context);
225 
226 // Shape function for MatMul-like operations.
227 Status MatMulShape(shape_inference::InferenceContext* c);
228 
229 // Shape function for BiasAdd-like operations.
230 Status BiasAddShape(shape_inference::InferenceContext* c);
231 
232 // Shape function for BiasAddGrad-like operations.
233 Status BiasAddGradShape(shape_inference::InferenceContext* c);
234 
235 // Shape function for Conv2D-like operations that support explicit padding.
236 Status Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext* c);
237 
238 // Shape function for Conv2D-like operations that do not support explicit
239 // padding.
240 Status Conv2DShape(shape_inference::InferenceContext* c);
241 
242 // Shape function for Conv3D-like operations.
243 Status Conv3DShape(shape_inference::InferenceContext* c);
244 
245 // Shape function for DepthwiseConv2D-like operations.
246 Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c);
247 
248 // Shape function for AvgPool-like operations.
249 Status AvgPoolShape(shape_inference::InferenceContext* c);
250 
251 // Shape function for FusedBatchNorm and FusedBatchNormV2 operations.
252 Status FusedBatchNormShape(shape_inference::InferenceContext* c);
253 
254 // Shape function for FusedBatchNormGrad and FusedBatchNormGradV2 operations.
255 Status FusedBatchNormGradShape(shape_inference::InferenceContext* c);
256 
257 // Shape function for MaxPool-like operations.
258 Status MaxPoolShape(shape_inference::InferenceContext* c);
259 
260 // Shape function for MaxPoolV2-like operations.
261 Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs);
262 
263 // Shape function for 3D Pooling operations.
264 Status Pool3DShape(shape_inference::InferenceContext* c);
265 
266 // Shape function for use with ops whose output shapes are unknown.
267 Status UnknownShape(shape_inference::InferenceContext* c);
268 
269 // Shape function for reduction operations.
270 Status ReductionShape(shape_inference::InferenceContext* c);
271 
272 // Shape function for concat operations.
273 // <num_inputs_to_concat> is the number of inputs to concatenate and are taken
274 // from inputs
275 // [1,num_inputs_to_concat] of the op.  Input 0 is the concat_dim input.
276 Status ConcatShape(shape_inference::InferenceContext* c,
277                    int num_inputs_to_concat);
278 
279 // Shape function for concat operations.
280 Status ConcatV2Shape(shape_inference::InferenceContext* c);
281 
282 Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat);
283 
284 // Shape function for binary operators that broadcast their inputs
285 // and with output to output_index.
286 // Note: out cannot be NULL.
287 Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
288                                             ShapeHandle shape_x,
289                                             ShapeHandle shape_y,
290                                             ShapeHandle* out);
291 
292 // Shape function for binary operators that broadcast their inputs
293 // and with output to output_index.
BroadcastBinaryOpOutputShapeFn(InferenceContext * c,int output_index)294 inline Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c,
295                                              int output_index) {
296   ShapeHandle out;
297   TF_RETURN_IF_ERROR(
298       BroadcastBinaryOpOutputShapeFnHelper(c, c->input(0), c->input(1), &out));
299   c->set_output(output_index, out);
300   return Status::OK();
301 }
302 
303 // Shape function for binary operators that broadcast their inputs.
304 // Tested by ops/math_ops_test.cc.
BroadcastBinaryOpShapeFn(InferenceContext * c)305 inline Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
306   return BroadcastBinaryOpOutputShapeFn(c, 0);
307 }
308 
309 // Shape function for random operations.
310 Status RandomShape(shape_inference::InferenceContext* c);
311 
312 // Shape function for Slice opertaions.
313 Status SliceShape(shape_inference::InferenceContext* c);
314 
315 // Validates the 3 component tensors of a sparse tensor have the proper
316 // shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
317 Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
318                             ShapeHandle values_shape, ShapeHandle shape_shape);
319 
320 // Shape function for ScatterNd update/add/sub/... operations.
321 Status ScatterNdUpdateShape(InferenceContext* c);
322 
323 // Shape function for ops with an explicit "shape" attribute.
324 Status ExplicitShape(InferenceContext* c);
325 
326 // Shape function for multiple-output ops with an explicit "shapes" attribute.
327 Status ExplicitShapes(InferenceContext* c);
328 
329 // Shape function for SparseReduceMax and SparseReduceSum.
330 Status SparseReduceShapeFn(InferenceContext* c);
331 
332 }  // namespace shape_inference
333 
334 }  // namespace tensorflow
335 
336 #endif  // TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_
337