1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_
16 #define TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_
17
18 #include <array>
19
20 #include "tensorflow/core/framework/shape_inference.h"
21 #include "tensorflow/core/util/padding.h"
22 #include "tensorflow/core/util/tensor_format.h"
23
24 namespace tensorflow {
25
26 namespace shape_inference {
27
// Like GetWindowedOutputSize, but deals with DimensionHandles. Does not support
// EXPLICIT padding.
Status GetWindowedOutputSizeFromDims(InferenceContext* c,
                                     DimensionHandle input_size,
                                     DimensionOrConstant filter_size,
                                     int64_t stride, Padding padding_type,
                                     DimensionHandle* output_size);

// The V2 version computes the same outputs with arbitrary dilation_rate, and
// supports EXPLICIT padding. For detailed equations, refer to the comments
// for GetWindowedOutputSizeV2(). The 'padding_before' and 'padding_after'
// parameters are only used if padding_type == EXPLICIT.
Status GetWindowedOutputSizeFromDimsV2(
    InferenceContext* c, DimensionHandle input_size,
    DimensionOrConstant filter_size, int64_t dilation_rate, int64_t stride,
    Padding padding_type, int64_t padding_before, int64_t padding_after,
    DimensionHandle* output_size);

// Transfers shape of input(0) to output(0).
Status UnchangedShape(shape_inference::InferenceContext* c);
48
49 // Transfers shape of input(0) to output(0), after asserting its rank is <rank>.
UnchangedShapeWithRank(shape_inference::InferenceContext * c,int32_t rank)50 inline Status UnchangedShapeWithRank(shape_inference::InferenceContext* c,
51 int32_t rank) {
52 ShapeHandle out;
53 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &out));
54 c->set_output(0, out);
55 return Status::OK();
56 }
57
58 // Transfers shape of input(0) to output(0), after asserting its rank >= <rank>.
UnchangedShapeWithRankAtLeast(shape_inference::InferenceContext * c,int32_t rank)59 inline Status UnchangedShapeWithRankAtLeast(
60 shape_inference::InferenceContext* c, int32_t rank) {
61 ShapeHandle out;
62 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), rank, &out));
63 c->set_output(0, out);
64 return Status::OK();
65 }
66
67 // Transfers shape of input(0) to output(0), after asserting its rank <= <rank>.
UnchangedShapeWithRankAtMost(shape_inference::InferenceContext * c,int32_t rank)68 inline Status UnchangedShapeWithRankAtMost(shape_inference::InferenceContext* c,
69 int32_t rank) {
70 ShapeHandle out;
71 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), rank, &out));
72 c->set_output(0, out);
73 return Status::OK();
74 }
75
76 // Shape function for use with ops no outputs.
NoOutputs(shape_inference::InferenceContext * c)77 inline Status NoOutputs(shape_inference::InferenceContext* c) {
78 return Status::OK();
79 }
80
81 // Shape function for ops that output a single scalar value.
ScalarShape(shape_inference::InferenceContext * c)82 inline Status ScalarShape(shape_inference::InferenceContext* c) {
83 c->set_output(0, c->Scalar());
84 return Status::OK();
85 }
86
87 // Shape function for binary ops where both inputs and the output match.
MergeBothInputsShapeFn(InferenceContext * c)88 inline Status MergeBothInputsShapeFn(InferenceContext* c) {
89 ShapeHandle out;
90 TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &out));
91 c->set_output(0, out);
92 return Status::OK();
93 }
94
// Shape function for dataset iterators.
Status DatasetIteratorShape(shape_inference::InferenceContext* c);

// Returns a new shape with the specified dims arranged in the specified
// format. The returned value is owned by this context.
// Note: if format = "FORMAT_NCHW_VECT_C" then C represents the outer_depth.
Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
                           const std::vector<DimensionOrConstant>& spatial,
                           DimensionOrConstant C, ShapeHandle* out,
                           shape_inference::InferenceContext* context);

// Shape function for MatMul-like operations.
Status MatMulShape(shape_inference::InferenceContext* c);

// Shape function for Batched MatMul-like operations with broadcasting across
// batch dimensions.
Status BatchMatMulV2Shape(shape_inference::InferenceContext* c);

// Shape function for BatchMatMul-like operations.
Status BatchMatMulShape(shape_inference::InferenceContext* c);

// Shape function for Einsum.
Status EinsumShape(shape_inference::InferenceContext* c);

// Shape function for BiasAdd-like operations.
Status BiasAddShape(shape_inference::InferenceContext* c);

// Shape function for BiasAddGrad-like operations.
Status BiasAddGradShape(shape_inference::InferenceContext* c);

// Shape function for Conv2D-like operations that support explicit padding.
Status Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext* c);

// Shape function for Conv2D-like operations that do not support explicit
// padding.
Status Conv2DShape(shape_inference::InferenceContext* c);

// Shape function for Conv3D-like operations.
Status Conv3DShape(shape_inference::InferenceContext* c);

// Shape function for DepthwiseConv2D-like operations that support explicit
// padding.
Status DepthwiseConv2DNativeShapeWithExplicitPadding(
    shape_inference::InferenceContext* c);

// Shape function for DepthwiseConv2D-like operations that do not support
// explicit padding.
Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c);

// Shape function for Conv2DBackpropInput.
Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c);

// Shape function for Conv2DBackpropFilterWithBias.
Status Conv2DBackpropFilterWithBiasShape(shape_inference::InferenceContext* c);

// Shape function for AvgPool-like operations.
Status AvgPoolShape(shape_inference::InferenceContext* c);

// Shape function for AvgPoolGrad-like operations.
Status AvgPoolGradShape(shape_inference::InferenceContext* c);

// Shape function for FusedBatchNorm and FusedBatchNormV2 operations.
Status FusedBatchNormShape(shape_inference::InferenceContext* c);

// Shape function for FusedBatchNormV3 operations.
Status FusedBatchNormV3Shape(shape_inference::InferenceContext* c);

// Shape function for _FusedBatchNormEx operations.
Status FusedBatchNormExShape(shape_inference::InferenceContext* c);

// Shape function for FusedBatchNormGrad and FusedBatchNormGradV2 operations.
Status FusedBatchNormGradShape(shape_inference::InferenceContext* c);

// Shape function for MatrixDiagPartV2 and MatrixDiagPartV3 operations.
Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c);

// Shape function for MatrixDiagV2 and MatrixDiagV3 operations.
Status MatrixDiagV2Shape(shape_inference::InferenceContext* c);

// Shape function for MatrixSetDiagV2 and MatrixSetDiagV3 operations.
Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c);

// Shape function for MaxPool-like operations that support explicit padding.
Status MaxPoolShapeWithExplicitPadding(shape_inference::InferenceContext* c);

// Shape function for MaxPool-like operations that do not support explicit
// padding.
Status MaxPoolShape(shape_inference::InferenceContext* c);

// Shape function for MaxPoolV2-like operations.
Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs);

// Shape function for MaxPoolGrad-like operations.
Status MaxPoolGradShape(shape_inference::InferenceContext* c);

// Shape function for 3D Pooling operations.
Status Pool3DShape(shape_inference::InferenceContext* c);

// Shape function for MaxPool3DGrad-like operations.
Status MaxPool3DGradShape(shape_inference::InferenceContext* c);

// Shape function for AvgPool3DGrad-like operations.
Status AvgPool3DGradShape(shape_inference::InferenceContext* c);

// Shape function for use with ops whose output shapes are unknown.
Status UnknownShape(shape_inference::InferenceContext* c);

// Shape function for reduction operations.
Status ReductionShape(shape_inference::InferenceContext* c);

// Shape function for unsorted segment operations.
Status UnsortedSegmentReductionShapeFn(InferenceContext* c);

// Shape function for concat operations.
// <num_inputs_to_concat> is the number of inputs to concatenate and are taken
// from inputs [1, num_inputs_to_concat] of the op. Input 0 is the concat_dim
// input.
Status ConcatShape(shape_inference::InferenceContext* c,
                   int num_inputs_to_concat);

// Shape function for concat operations.
Status ConcatV2Shape(shape_inference::InferenceContext* c);

// Shape function for QuantizedConcatV2; <num_inputs_to_concat> inputs are
// concatenated (in addition to the quantization-range inputs).
Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat);

// Shape function helper for binary operators that broadcast their inputs.
// The broadcasted shape of <shape_x> and <shape_y> is returned in <out>.
// Note: <out> cannot be NULL.
Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
                                            ShapeHandle shape_x,
                                            ShapeHandle shape_y,
                                            bool incompatible_shape_error,
                                            ShapeHandle* out);
228
229 // Shape function for binary operators that broadcast their inputs
230 // and with output to output_index.
BroadcastBinaryOpOutputShapeFn(InferenceContext * c,int output_index)231 inline Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c,
232 int output_index) {
233 ShapeHandle out;
234 TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
235 c, c->input(0), c->input(1), true, &out));
236 c->set_output(output_index, out);
237 return Status::OK();
238 }
239
240 // Shape function for binary operators that broadcast their inputs.
241 // Tested by ops/math_ops_test.cc.
BroadcastBinaryOpShapeFn(InferenceContext * c)242 inline Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
243 return BroadcastBinaryOpOutputShapeFn(c, 0);
244 }
245
// Shape function for random operations.
Status RandomShape(shape_inference::InferenceContext* c);

// Shape function for Slice operations.
Status SliceShape(shape_inference::InferenceContext* c);

// Validates the 3 component tensors of a sparse tensor have the proper
// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
                            ShapeHandle values_shape, ShapeHandle shape_shape);

// Validates a variable's resource handle and returns its stored shape/type
// info in <shape_and_type>.
// NOTE(review): presumably reads the handle data from input(0) — confirm
// against the definition in common_shape_fns.cc.
Status ValidateVariableResourceHandle(
    InferenceContext* c, std::vector<ShapeAndType>* shape_and_type);

// Shape function for GatherNd operations.
Status GatherNdShape(InferenceContext* c);

// Helper shape function for ScatterNd.../TensorScatter... operations.
Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape,
                            ShapeHandle updates_shape, ShapeHandle input_shape);

// Shape function for ops with an explicit "shape" attribute.
Status ExplicitShape(InferenceContext* c);

// Shape function for multiple-output ops with an explicit "shapes" attribute.
Status ExplicitShapes(InferenceContext* c);

// Shape function for SparseReduceMax and SparseReduceSum.
Status SparseReduceShapeFn(InferenceContext* c);

// Shape function for QuantizedConv2D op.
Status QuantizedConv2DShape(InferenceContext* c);

// Shape function for QuantizedAvgPool op.
Status QuantizedAvgPoolShape(InferenceContext* c);

// Shape function for QuantizeV2 op.
Status QuantizeV2Shape(InferenceContext* c);
284
285 } // namespace shape_inference
286
287 } // namespace tensorflow
288
289 #endif // TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_
290