/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/attr_value.pb.h"

namespace tensorflow {
Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size,
                                      int64 dilation_rate, int64 stride,
                                      Padding padding_type, int64* output_size,
                                      int64* padding_before,
                                      int64* padding_after) {
  if (stride <= 0) {
    return errors::InvalidArgument("Stride must be > 0, but got ", stride);
  }
  if (dilation_rate < 1) {
    return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
                                   dilation_rate);
  }

  // See also the parallel implementation in GetWindowedOutputSizeFromDimsV2.
  int64 effective_filter_size = (filter_size - 1) * dilation_rate + 1;
  switch (padding_type) {
    case Padding::VALID:
      *output_size = (input_size - effective_filter_size + stride) / stride;
      *padding_before = *padding_after = 0;
      break;
    case Padding::SAME:
      *output_size = (input_size + stride - 1) / stride;
      const int64 padding_needed =
          std::max(int64{0}, (*output_size - 1) * stride +
                                 effective_filter_size - input_size);
      // For odd values of total padding, add more padding at the 'right'
      // side of the given dimension.
      *padding_before = padding_needed / 2;
      *padding_after = padding_needed - *padding_before;
      break;
  }
  if (*output_size < 0) {
    return errors::InvalidArgument(
        "Computed output size would be negative: ", *output_size,
        " [input_size: ", input_size,
        ", effective_filter_size: ", effective_filter_size,
        ", stride: ", stride, "]");
  }
  return Status::OK();
}
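
// Worked example (illustrative, not part of the original file): with
// input_size = 10, filter_size = 3, dilation_rate = 2 and stride = 2, the
// effective filter size is (3 - 1) * 2 + 1 = 5. VALID padding then gives
// output_size = (10 - 5 + 2) / 2 = 3 with zero padding, while SAME padding
// gives output_size = ceil(10 / 2) = 5 and
// padding_needed = max(0, (5 - 1) * 2 + 5 - 10) = 3, split as
// padding_before = 1 and padding_after = 2 (the extra unit goes right).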

Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size,
                                    int64 stride, Padding padding_type,
                                    int64* output_size, int64* padding_before,
                                    int64* padding_after) {
  return GetWindowedOutputSizeVerboseV2(input_size, filter_size,
                                        /*dilation_rate=*/1, stride,
                                        padding_type, output_size,
                                        padding_before, padding_after);
}

Status GetWindowedOutputSize(int64 input_size, int64 filter_size, int64 stride,
                             Padding padding_type, int64* output_size,
                             int64* padding_size) {
  int64 padding_after_unused;
  return GetWindowedOutputSizeVerbose(input_size, filter_size, stride,
                                      padding_type, output_size, padding_size,
                                      &padding_after_unused);
}

Status GetWindowedOutputSizeV2(int64 input_size, int64 filter_size,
                               int64 dilation_rate, int64 stride,
                               Padding padding_type, int64* output_size,
                               int64* padding_size) {
  int64 padding_after_unused;
  return GetWindowedOutputSizeVerboseV2(input_size, filter_size, dilation_rate,
                                        stride, padding_type, output_size,
                                        padding_size, &padding_after_unused);
}

Status Get3dOutputSize(const std::array<int64, 3>& input,
                       const std::array<int64, 3>& window,
                       const std::array<int64, 3>& strides,
                       Padding padding_type, std::array<int64, 3>* output_ptr,
                       std::array<int64, 3>* padding_ptr) {
  for (size_t i = 0; i < input.size(); ++i) {
    TF_RETURN_IF_ERROR(GetWindowedOutputSize(input[i], window[i], strides[i],
                                             padding_type, &(*output_ptr)[i],
                                             &(*padding_ptr)[i]));
  }
  return Status::OK();
}

Status Get3dOutputSizeV2(const std::array<int64, 3>& input,
                         const std::array<int64, 3>& window,
                         const std::array<int64, 3>& dilations,
                         const std::array<int64, 3>& strides,
                         Padding padding_type, std::array<int64, 3>* output_ptr,
                         std::array<int64, 3>* padding_ptr) {
  for (size_t i = 0; i < input.size(); ++i) {
    TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2(
        input[i], window[i], dilations[i], strides[i], padding_type,
        &(*output_ptr)[i], &(*padding_ptr)[i]));
  }
  return Status::OK();
}
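
// Usage sketch (hypothetical caller, not part of this file): computing the
// output volume of a 3-D window over an 8x10x12 input with a 2x3x3 window,
// unit dilation and stride 2 in every dimension under VALID padding.
//
//   std::array<int64, 3> out, pad;
//   TF_CHECK_OK(Get3dOutputSizeV2({{8, 10, 12}}, {{2, 3, 3}}, {{1, 1, 1}},
//                                 {{2, 2, 2}}, Padding::VALID, &out, &pad));
//   // out == {{4, 4, 5}}, pad == {{0, 0, 0}}.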

namespace shape_inference {

// The V2 version computes the windowed output size with an arbitrary
// dilation_rate, while the original version only handles the case where the
// dilation rate equals 1.
Status GetWindowedOutputSizeFromDimsV2(
    shape_inference::InferenceContext* c,
    shape_inference::DimensionHandle input_size,
    shape_inference::DimensionOrConstant filter_size, int64 dilation_rate,
    int64 stride, Padding padding_type,
    shape_inference::DimensionHandle* output_size) {
  if (stride <= 0) {
    return errors::InvalidArgument("Stride must be > 0, but got ", stride);
  }

  if (dilation_rate < 1) {
    return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
                                   dilation_rate);
  }

  // See also the parallel implementation in GetWindowedOutputSizeVerbose.
  switch (padding_type) {
    case Padding::VALID:
      if (dilation_rate > 1) {
        DimensionHandle window_size;
        TF_RETURN_IF_ERROR(
            c->Subtract(c->MakeDim(filter_size), 1, &window_size));
        TF_RETURN_IF_ERROR(
            c->Multiply(window_size, dilation_rate, &window_size));
        TF_RETURN_IF_ERROR(c->Add(window_size, 1, &window_size));
        TF_RETURN_IF_ERROR(c->Subtract(input_size, window_size, output_size));
      } else {
        TF_RETURN_IF_ERROR(c->Subtract(input_size, filter_size, output_size));
      }
      TF_RETURN_IF_ERROR(c->Add(*output_size, stride, output_size));
      TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
                                   /*evenly_divisible=*/false, output_size));
      break;
    case Padding::SAME:
      TF_RETURN_IF_ERROR(c->Add(input_size, stride - 1, output_size));
      TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
                                   /*evenly_divisible=*/false, output_size));
      break;
  }
  return Status::OK();
}
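
// Note (added commentary, not in the original file): unlike the integer
// version above, this variant works on symbolic DimensionHandles, so an
// unknown input dimension simply propagates as unknown through Subtract /
// Multiply / Add / Divide, and no SAME-padding amounts are produced here.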

Status GetWindowedOutputSizeFromDims(
    shape_inference::InferenceContext* c,
    shape_inference::DimensionHandle input_size,
    shape_inference::DimensionOrConstant filter_size, int64 stride,
    Padding padding_type, shape_inference::DimensionHandle* output_size) {
  return GetWindowedOutputSizeFromDimsV2(c, input_size, filter_size,
                                         /*dilation_rate=*/1, stride,
                                         padding_type, output_size);
}

Status UnchangedShape(shape_inference::InferenceContext* c) {
  c->set_output(0, c->input(0));
  return Status::OK();
}

Status MatMulShape(shape_inference::InferenceContext* c) {
  ShapeHandle a;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a));

  ShapeHandle b;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b));

  bool transpose_a, transpose_b;
  TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a));
  TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b));
  DimensionHandle output_rows = transpose_a ? c->Dim(a, 1) : c->Dim(a, 0);
  DimensionHandle output_cols = transpose_b ? c->Dim(b, 0) : c->Dim(b, 1);

  // Validate that the inner shapes are compatible.
  DimensionHandle inner_a = transpose_a ? c->Dim(a, 0) : c->Dim(a, 1);
  DimensionHandle inner_b = transpose_b ? c->Dim(b, 1) : c->Dim(b, 0);
  DimensionHandle merged;
  TF_RETURN_IF_ERROR(c->Merge(inner_a, inner_b, &merged));

  c->set_output(0, c->Matrix(output_rows, output_cols));
  return Status::OK();
}
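
// Worked example (illustrative, not part of the original file): for a of
// shape [m, k] with transpose_a = false and b of shape [n, k] with
// transpose_b = true, the inner dimensions k and k are merged (and must
// agree if both are known), and the output shape is [m, n]. If either inner
// dimension is unknown, Merge simply adopts the known value.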

Status BiasAddShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;

  // Fetch the data_format attribute, which may not exist.
  string data_format;
  Status s = c->GetAttr("data_format", &data_format);

  if (s.ok() && data_format == "NCHW") {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
  } else {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
  }

  ShapeHandle bias_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &bias_shape));
  DimensionHandle bias_dim = c->Dim(bias_shape, 0);

  // If rank unknown, return unknown shape.
  if (!c->RankKnown(input_shape)) {
    c->set_output(0, c->UnknownShape());
    return Status::OK();
  }

  // Output has the same shape as the input, and matches the length of
  // the bias in its bias dimension.
  ShapeHandle output_shape;
  if (s.ok() && data_format == "NCHW") {
    // Merge the length of bias_shape into the third-to-last dimension.
    ShapeHandle first;
    TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -3, &first));

    ShapeHandle last;
    TF_RETURN_IF_ERROR(c->Subshape(input_shape, -2, &last));

    DimensionHandle input_bias_dim = c->Dim(input_shape, -3);
    DimensionHandle merged_bias_dim;
    TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));
    ShapeHandle merged_bias = c->Vector(merged_bias_dim);

    ShapeHandle temp;
    TF_RETURN_IF_ERROR(c->Concatenate(first, merged_bias, &temp));
    TF_RETURN_IF_ERROR(c->Concatenate(temp, last, &output_shape));
  } else {
    ShapeHandle all_but_bias;
    TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -1, &all_but_bias));

    DimensionHandle input_bias_dim = c->Dim(input_shape, -1);
    DimensionHandle merged_bias_dim;
    TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));

    ShapeHandle merged_bias = c->Vector(merged_bias_dim);
    TF_RETURN_IF_ERROR(
        c->Concatenate(all_but_bias, merged_bias, &output_shape));
  }

  c->set_output(0, output_shape);
  return Status::OK();
}
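
// Worked example (illustrative, not part of the original file): for an NCHW
// input of shape [n, c, h, w] and a bias of shape [k], the shape is split
// into first = [n], the channel dimension c (merged with k), and
// last = [h, w], then reassembled as [n, merge(c, k), h, w]. In the default
// NHWC-style branch the bias is instead merged into the trailing dimension,
// giving [n, h, w, merge(c, k)].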

Status BiasAddGradShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;
  // Fetch the data_format attribute, which may not exist.
  string data_format;
  Status s = c->GetAttr("data_format", &data_format);

  if (s.ok() && data_format == "NCHW") {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
    c->set_output(0, c->Vector(c->Dim(input_shape, -3)));
  } else {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
    c->set_output(0, c->Vector(c->Dim(input_shape, -1)));
  }

  return Status::OK();
}

Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format,
                                     const ShapeHandle shape_handle,
                                     const string& tensor_name,
                                     shape_inference::InferenceContext* c) {
  if (tensor_format == FORMAT_NCHW_VECT_C) {
    // Check that the vect dim has size 4.
    const int num_dims = c->Rank(shape_handle);
    DimensionHandle vect_dim = c->Dim(
        shape_handle, GetTensorInnerFeatureDimIndex(num_dims, tensor_format));
    DimensionHandle unused_vect_dim;
    TF_RETURN_IF_ERROR(c->WithValue(vect_dim, 4, &unused_vect_dim));
  }

  return Status::OK();
}

Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
                           const std::vector<DimensionOrConstant>& spatial,
                           DimensionOrConstant C, ShapeHandle* out,
                           shape_inference::InferenceContext* context) {
  const int num_dims = GetTensorDimsFromSpatialDims(spatial.size(), format);
  std::vector<DimensionHandle> dims_actual(num_dims);
  dims_actual[GetTensorBatchDimIndex(num_dims, format)] = context->MakeDim(N);
  int outer_c_index = GetTensorFeatureDimIndex(num_dims, format);
  dims_actual[outer_c_index] = context->MakeDim(C);
  if (format == FORMAT_NCHW_VECT_C) {
    dims_actual[GetTensorInnerFeatureDimIndex(num_dims, format)] =
        context->MakeDim(4);
  }
  for (int spatial_dim = 0; spatial_dim < spatial.size(); spatial_dim++) {
    dims_actual[GetTensorSpatialDimIndex(num_dims, format, spatial_dim)] =
        context->MakeDim(spatial[spatial_dim]);
  }
  *out = context->MakeShape(dims_actual);
  return Status::OK();
}
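
// Worked example (illustrative, not part of the original file): with
// format = FORMAT_NHWC, N = 32, spatial = {28, 28} and C = 64 this produces
// the shape [32, 28, 28, 64]; with FORMAT_NCHW it produces [32, 64, 28, 28];
// and with FORMAT_NCHW_VECT_C it produces the 5-D shape [32, 64, 28, 28, 4],
// where C is taken as the outer channel count and the trailing 4 is the
// fixed inner feature ("vect") dimension.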

Status DimensionsFromShape(ShapeHandle shape, TensorFormat format,
                           DimensionHandle* batch_dim,
                           gtl::MutableArraySlice<DimensionHandle> spatial_dims,
                           DimensionHandle* filter_dim,
                           InferenceContext* context) {
  const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
  // Batch.
  *batch_dim = context->Dim(shape, GetTensorBatchDimIndex(rank, format));
  // Spatial.
  for (int spatial_dim_index = 0; spatial_dim_index < spatial_dims.size();
       ++spatial_dim_index) {
    spatial_dims[spatial_dim_index] = context->Dim(
        shape, GetTensorSpatialDimIndex(rank, format, spatial_dim_index));
  }
  // Channel.
  *filter_dim = context->Dim(shape, GetTensorFeatureDimIndex(rank, format));
  if (format == FORMAT_NCHW_VECT_C) {
    TF_RETURN_IF_ERROR(context->Multiply(
        *filter_dim,
        context->Dim(shape, GetTensorInnerFeatureDimIndex(rank, format)),
        filter_dim));
  }
  return Status::OK();
}

Status ShapeFromDimensions(DimensionHandle batch_dim,
                           gtl::ArraySlice<DimensionHandle> spatial_dims,
                           DimensionHandle filter_dim, TensorFormat format,
                           InferenceContext* context, ShapeHandle* shape) {
  const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
  std::vector<DimensionHandle> out_dims(rank);

  // Batch.
  out_dims[tensorflow::GetTensorBatchDimIndex(rank, format)] = batch_dim;
  // Spatial.
  for (int spatial_dim_index = 0; spatial_dim_index < spatial_dims.size();
       ++spatial_dim_index) {
    out_dims[tensorflow::GetTensorSpatialDimIndex(
        rank, format, spatial_dim_index)] = spatial_dims[spatial_dim_index];
  }
  // Channel.
  if (format == tensorflow::FORMAT_NCHW_VECT_C) {
    // When format is NCHW_VECT_C, factor the feature map count
    // into the outer feature count and the inner feature count (=4).
    TF_RETURN_IF_ERROR(context->Divide(
        filter_dim, 4, /*evenly_divisible=*/true,
        &out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)]));
    out_dims[GetTensorInnerFeatureDimIndex(rank, format)] = context->MakeDim(4);
  } else {
    out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)] = filter_dim;
  }

  *shape = context->MakeShape(out_dims);
  return tensorflow::Status::OK();
}
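
// Worked example (illustrative, not part of the original file):
// DimensionsFromShape and ShapeFromDimensions are inverses of each other.
// Given an NCHW_VECT_C shape [n, 8, h, w, 4], DimensionsFromShape reports a
// flattened channel count of 8 * 4 = 32; feeding that 32 back through
// ShapeFromDimensions divides it by 4 (evenly_divisible = true, so it must
// be a multiple of 4) and reconstructs [n, 8, h, w, 4].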

Status Conv2DShape(shape_inference::InferenceContext* c) {
  string data_format_str, filter_format_str;
  if (!c->GetAttr("data_format", &data_format_str).ok()) {
    data_format_str = "NHWC";
  }
  if (!c->GetAttr("filter_format", &filter_format_str).ok()) {
    filter_format_str = "HWIO";
  }

  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }
  FilterTensorFormat filter_format;
  if (!FilterFormatFromString(filter_format_str, &filter_format)) {
    return errors::InvalidArgument("Invalid filter format string: ",
                                   filter_format_str);
  }

  constexpr int num_spatial_dims = 2;
  const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
  ShapeHandle conv_input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &conv_input_shape));
  TF_RETURN_IF_ERROR(CheckFormatConstraintsOnShape(
      data_format, conv_input_shape, "conv_input", c));

  // The filter rank should match the input (4 for NCHW, 5 for NCHW_VECT_C).
  ShapeHandle filter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape));
  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, filter_shape, "filter", c));

  std::vector<int32> dilations;
  TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations));

  if (dilations.size() != 4) {
    return errors::InvalidArgument(
        "Conv2D requires the dilation attribute to contain 4 values, but got: ",
        dilations.size());
  }

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));

  // strides.size() should be 4 (NCHW) even if the input is 5 (NCHW_VECT_C).
  if (strides.size() != 4) {
    return errors::InvalidArgument("Conv2D on data format ", data_format_str,
                                   " requires the stride attribute to contain"
                                   " 4 values, but got: ",
                                   strides.size());
  }

  const int32 stride_rows = GetTensorDim(strides, data_format, 'H');
  const int32 stride_cols = GetTensorDim(strides, data_format, 'W');
  const int32 dilation_rows = GetTensorDim(dilations, data_format, 'H');
  const int32 dilation_cols = GetTensorDim(dilations, data_format, 'W');

  DimensionHandle batch_size_dim;
  DimensionHandle input_depth_dim;
  gtl::InlinedVector<DimensionHandle, 2> input_spatial_dims(2);
  TF_RETURN_IF_ERROR(DimensionsFromShape(conv_input_shape, data_format,
                                         &batch_size_dim, &input_spatial_dims,
                                         &input_depth_dim, c));

  DimensionHandle output_depth_dim = c->Dim(
      filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
  DimensionHandle filter_rows_dim = c->Dim(
      filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'H'));
  DimensionHandle filter_cols_dim = c->Dim(
      filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'W'));
  DimensionHandle filter_input_depth_dim;
  if (filter_format == FORMAT_OIHW_VECT_I) {
    TF_RETURN_IF_ERROR(c->Multiply(
        c->Dim(filter_shape,
               GetFilterDimIndex<num_spatial_dims>(filter_format, 'I')),
        c->Dim(filter_shape,
               GetFilterTensorInnerInputChannelsDimIndex(rank, filter_format)),
        &filter_input_depth_dim));
  } else {
    filter_input_depth_dim = c->Dim(
        filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'I'));
  }

  // Check that the input tensor and the filter tensor agree on the input
  // channel count.
  DimensionHandle unused;
  TF_RETURN_IF_ERROR(
      c->Merge(input_depth_dim, filter_input_depth_dim, &unused));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  DimensionHandle output_rows, output_cols;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, input_spatial_dims[0], filter_rows_dim, dilation_rows, stride_rows,
      padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, input_spatial_dims[1], filter_cols_dim, dilation_cols, stride_cols,
      padding, &output_cols));

  ShapeHandle output_shape;
  TF_RETURN_IF_ERROR(
      ShapeFromDimensions(batch_size_dim, {output_rows, output_cols},
                          output_depth_dim, data_format, c, &output_shape));
  c->set_output(0, output_shape);
  return Status::OK();
}
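
// Worked example (illustrative, not part of the original file): an NHWC
// input [1, 32, 32, 3] convolved with an HWIO filter [5, 5, 3, 16], strides
// [1, 1, 1, 1], dilations [1, 1, 1, 1] and VALID padding yields
// output rows = cols = (32 - 5 + 1) / 1 = 28, so the inferred output shape
// is [1, 28, 28, 16]. The Merge of the two input-depth dimensions (3 from
// the input, 3 from the filter) is what rejects a mismatched filter at
// graph-construction time.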

// TODO(mjanusz): Unify all conv/pooling shape functions.
Status Conv3DShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
  ShapeHandle filter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &filter_shape));

  string data_format;
  Status s = c->GetAttr("data_format", &data_format);

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 5) {
    return errors::InvalidArgument(
        "Conv3D requires the stride attribute to contain 5 values, but got: ",
        strides.size());
  }

  int32 stride_planes, stride_rows, stride_cols;
  if (s.ok() && data_format == "NCDHW") {
    // Convert input_shape to NDHWC.
    auto dim = [&](char dimension) {
      return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
    };
    input_shape =
        c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
    stride_planes = strides[2];
    stride_rows = strides[3];
    stride_cols = strides[4];
  } else {
    stride_planes = strides[1];
    stride_rows = strides[2];
    stride_cols = strides[3];
  }

  DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
  DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
  DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
  DimensionHandle in_cols_dim = c->Dim(input_shape, 3);

  DimensionHandle filter_planes_dim = c->Dim(filter_shape, 0);
  DimensionHandle filter_rows_dim = c->Dim(filter_shape, 1);
  DimensionHandle filter_cols_dim = c->Dim(filter_shape, 2);
  DimensionHandle output_depth_dim = c->Dim(filter_shape, 4);

  DimensionHandle unused;
  TF_RETURN_IF_ERROR(
      c->Merge(c->Dim(input_shape, 4), c->Dim(filter_shape, 3), &unused));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
  DimensionHandle output_planes, output_rows, output_cols;

  TF_RETURN_IF_ERROR(
      GetWindowedOutputSizeFromDims(c, in_planes_dim, filter_planes_dim,
                                    stride_planes, padding, &output_planes));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols));

  ShapeHandle output_shape;
  if (data_format == "NCDHW") {
    output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
                                 output_planes, output_rows, output_cols});
  } else {
    output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
                                 output_cols, output_depth_dim});
  }
  c->set_output(0, output_shape);
  return Status::OK();
}
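
// Note (added commentary, not in the original file): in the NCDHW branch
// above, strides follow the [N, C, D, H, W] layout, so strides[2] is planes
// (D), strides[3] is rows (H) and strides[4] is cols (W); this matches the
// mapping used by Pool3DShape below.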

Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
  ShapeHandle filter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape));

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));

  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "DepthwiseConv2D requires the stride attribute to contain 4 values, "
        "but got: ",
        strides.size());
  }

  string data_format;
  Status s = c->GetAttr("data_format", &data_format);
  int32 stride_rows;
  int32 stride_cols;
  if (s.ok() && data_format == "NCHW") {
    // Canonicalize input shape to NHWC so the shape inference code below can
    // process it.
    input_shape =
        c->MakeShape({{c->Dim(input_shape, 0), c->Dim(input_shape, 2),
                       c->Dim(input_shape, 3), c->Dim(input_shape, 1)}});
    stride_rows = strides[2];
    stride_cols = strides[3];
  } else {
    stride_rows = strides[1];
    stride_cols = strides[2];
  }

  DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
  DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
  DimensionHandle in_cols_dim = c->Dim(input_shape, 2);

  DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
  DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
  DimensionHandle input_depth = c->Dim(filter_shape, 2);
  DimensionHandle depth_multiplier = c->Dim(filter_shape, 3);

  // Check that the input depths are compatible.
  TF_RETURN_IF_ERROR(
      c->Merge(c->Dim(input_shape, 3), input_depth, &input_depth));

  DimensionHandle output_depth;
  TF_RETURN_IF_ERROR(c->Multiply(input_depth, depth_multiplier, &output_depth));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  // TODO(mrry,shlens): Raise an error if the stride would cause
  // information in the input to be ignored. This will require a change
  // in the kernel implementation.
  DimensionHandle output_rows, output_cols;

  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols));

  ShapeHandle output_shape;
  if (data_format == "NCHW") {
    output_shape =
        c->MakeShape({batch_size_dim, output_depth, output_rows, output_cols});
  } else {
    output_shape =
        c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth});
  }
  c->set_output(0, output_shape);
  return Status::OK();
}
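
// Worked example (illustrative, not part of the original file): an NHWC
// input [1, 28, 28, 8] with a depthwise filter [3, 3, 8, 2] (rows, cols,
// in_channels, depth_multiplier) applies 2 filters per input channel, so
// output_depth = 8 * 2 = 16 and, with stride 1 and SAME padding, the
// inferred output shape is [1, 28, 28, 16].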

Status AvgPoolShape(shape_inference::InferenceContext* c) {
  string data_format_str;
  TensorFormat data_format;
  Status s = c->GetAttr("data_format", &data_format_str);
  if (s.ok()) {
    FormatFromString(data_format_str, &data_format);
  } else {
    data_format = FORMAT_NHWC;
  }

  const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));

  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "AvgPool requires the stride attribute to contain 4 values, but got: ",
        strides.size());
  }

  std::vector<int32> kernel_sizes;
  TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
  if (kernel_sizes.size() != 4) {
    return errors::InvalidArgument(
        "AvgPool requires the ksize attribute to contain 4 values, but got: ",
        kernel_sizes.size());
  }

  int32 stride_rows = GetTensorDim(strides, data_format, 'H');
  int32 stride_cols = GetTensorDim(strides, data_format, 'W');
  int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
  int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');

  constexpr int num_spatial_dims = 2;
  DimensionHandle batch_size_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
  DimensionHandle in_rows_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
  DimensionHandle in_cols_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
  DimensionHandle depth_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  // TODO(mrry,shlens): Raise an error if the stride would cause
  // information in the input to be ignored. This will require a change
  // in the kernel implementation.

  DimensionHandle output_rows, output_cols;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));

  ShapeHandle output_shape;
  TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
                                         {output_rows, output_cols}, depth_dim,
                                         &output_shape, c));
  c->set_output(0, output_shape);
  return Status::OK();
}
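
// Worked example (illustrative, not part of the original file): an NHWC
// input [1, 8, 8, 16] pooled with ksize [1, 2, 2, 1], strides [1, 2, 2, 1]
// and VALID padding gives output rows = cols = (8 - 2 + 2) / 2 = 4, so the
// inferred output shape is [1, 4, 4, 16]; the batch and depth dimensions
// pass through unchanged.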

Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
  ShapeHandle x;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));

  bool is_training;
  TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
  int number_inputs = (is_training) ? 3 : 5;
  string data_format;
  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format));
  DimensionHandle channel_dim =
      (data_format == "NHWC") ? c->Dim(x, 3) : c->Dim(x, 1);

  // Covers scale and offset, and, if is_training is false, mean and variance.
  for (int i = 1; i < number_inputs; ++i) {
    ShapeHandle vec;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
    TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
  }

  ShapeHandle y;
  if (data_format == "NHWC") {
    TF_RETURN_IF_ERROR(c->ReplaceDim(x, 3, channel_dim, &y));
  } else {
    TF_RETURN_IF_ERROR(c->ReplaceDim(x, 1, channel_dim, &y));
  }
  c->set_output(0, y);
  ShapeHandle vector_shape = c->Vector(channel_dim);
  c->set_output(1, vector_shape);
  c->set_output(2, vector_shape);
  c->set_output(3, vector_shape);
  c->set_output(4, vector_shape);
  return Status::OK();
}
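
// Note (added commentary, not in the original file): output 0 is y with the
// merged channel dimension; outputs 1-4 (batch_mean, batch_variance and the
// two reserve spaces) are all channel-sized vectors regardless of
// is_training.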

Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
  ShapeHandle y_backprop;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop));
  ShapeHandle x;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x));

  bool is_training;
  string data_format;
  TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format));
  DimensionHandle channel_dim =
      (data_format == "NHWC") ? c->Dim(y_backprop, 3) : c->Dim(y_backprop, 1);
  if (data_format == "NHWC") {
    TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 3), &channel_dim));
  } else {
    TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 1), &channel_dim));
  }

  // Covers scale, mean (reserve_space_1), and variance (reserve_space_2).
  for (int i = 2; i < 5; ++i) {
    ShapeHandle vec;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
    TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
  }

  ShapeHandle x_backprop;
  if (data_format == "NHWC") {
    TF_RETURN_IF_ERROR(c->ReplaceDim(y_backprop, 3, channel_dim, &x_backprop));
  } else {
    TF_RETURN_IF_ERROR(c->ReplaceDim(y_backprop, 1, channel_dim, &x_backprop));
  }
  c->set_output(0, x_backprop);
  c->set_output(1, c->Vector(channel_dim));
  c->set_output(2, c->Vector(channel_dim));
  // Set the correct shapes for the reserve spaces so that gradients can be
  // computed when the op appears inside a symbolic conditional.
  if (is_training) {
    c->set_output(3, c->Vector(0));
    c->set_output(4, c->Vector(0));
  } else {
    c->set_output(3, c->Vector(channel_dim));
    c->set_output(4, c->Vector(channel_dim));
  }
  return Status::OK();
}

Status MaxPoolShape(shape_inference::InferenceContext* c) {
  string data_format_str;
  TensorFormat data_format;
  Status s = c->GetAttr("data_format", &data_format_str);
  if (s.ok()) {
    FormatFromString(data_format_str, &data_format);
  } else {
    data_format = FORMAT_NHWC;
  }

  const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));

  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the stride attribute to contain 4 values, but got: ",
        strides.size());
  }

  std::vector<int32> kernel_sizes;
  TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
  if (kernel_sizes.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the ksize attribute to contain 4 values, but got: ",
        kernel_sizes.size());
  }

  int32 stride_depth = GetTensorDim(strides, data_format, 'C');
  int32 stride_rows = GetTensorDim(strides, data_format, 'H');
  int32 stride_cols = GetTensorDim(strides, data_format, 'W');
  int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
  int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
  int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');

  constexpr int num_spatial_dims = 2;
  DimensionHandle batch_size_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
  DimensionHandle in_rows_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
  DimensionHandle in_cols_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
  DimensionHandle in_depth_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  ShapeHandle output_shape;
  DimensionHandle output_rows, output_cols, output_depth;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth));

  TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
                                         {output_rows, output_cols},
                                         output_depth, &output_shape, c));

  c->set_output(0, output_shape);
  return Status::OK();
}

Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) {
  string data_format_str;
  TensorFormat data_format;
  Status s = c->GetAttr("data_format", &data_format_str);
  if (s.ok()) {
    FormatFromString(data_format_str, &data_format);
  } else {
    data_format = FORMAT_NHWC;
  }

  const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));

  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));

  std::vector<int32> kernel_sizes;
  std::vector<int32> strides;

  if (c->num_inputs() + 2 == num_inputs) {
    TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));

    TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  } else {
    // Verify shape of ksize and strides input.
    ShapeHandle size;
    DimensionHandle unused;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 2), 1, &size));
    TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));
    TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &size));
    TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));

    const Tensor* kernel_sizes_tensor = c->input_tensor(c->num_inputs() - 2);
    if (kernel_sizes_tensor == nullptr) {
      c->set_output(0, c->UnknownShape());
      return Status::OK();
    }
    kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements());
    auto kernel_sizes_vec = kernel_sizes_tensor->flat<int32>();
    std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(),
                kernel_sizes.begin());

    const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1);
    if (strides_tensor == nullptr) {
      c->set_output(0, c->UnknownShape());
      return Status::OK();
    }
    strides.resize(strides_tensor->shape().num_elements());
    auto strides_vec = strides_tensor->flat<int32>();
    std::copy_n(&strides_vec(0), strides.size(), strides.begin());
  }

  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the stride attribute to contain 4 values, but "
        "got: ",
        strides.size());
  }
  if (kernel_sizes.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the ksize attribute to contain 4 values, but got: ",
        kernel_sizes.size());
  }

  int32 stride_depth = GetTensorDim(strides, data_format, 'C');
  int32 stride_rows = GetTensorDim(strides, data_format, 'H');
  int32 stride_cols = GetTensorDim(strides, data_format, 'W');
  int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
  int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
  int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');

  constexpr int num_spatial_dims = 2;
  DimensionHandle batch_size_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
  DimensionHandle in_rows_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
  DimensionHandle in_cols_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
  DimensionHandle in_depth_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  ShapeHandle output_shape;
  DimensionHandle output_rows, output_cols, output_depth;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth));

  TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
                                         {output_rows, output_cols},
                                         output_depth, &output_shape, c));

  c->set_output(0, output_shape);
  return Status::OK();
}

Status Pool3DShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));

  string data_format;
  Status s = c->GetAttr("data_format", &data_format);

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 5) {
    return errors::InvalidArgument(
        "Pool3D ops require the stride attribute to contain 5 values, but "
        "got: ",
        strides.size());
  }

  std::vector<int32> kernel_sizes;
  TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
  if (kernel_sizes.size() != 5) {
    return errors::InvalidArgument(
        "Pool3D requires the ksize attribute to contain 5 values, but got: ",
        kernel_sizes.size());
  }

  int32 stride_planes, stride_rows, stride_cols;
  int32 kernel_planes, kernel_rows, kernel_cols;

  if (s.ok() && data_format == "NCDHW") {
    // Convert input_shape to NDHWC.
    auto dim = [&](char dimension) {
      return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
    };
    input_shape =
        c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
    stride_planes = strides[2];
    stride_rows = strides[3];
    stride_cols = strides[4];
    kernel_planes = kernel_sizes[2];
    kernel_rows = kernel_sizes[3];
    kernel_cols = kernel_sizes[4];
  } else {
    stride_planes = strides[1];
    stride_rows = strides[2];
    stride_cols = strides[3];
    kernel_planes = kernel_sizes[1];
    kernel_rows = kernel_sizes[2];
    kernel_cols = kernel_sizes[3];
  }

  DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
  DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
  DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
  DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
  DimensionHandle output_depth_dim = c->Dim(input_shape, 4);

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  // TODO(mrry,shlens): Raise an error if the stride would cause
  // information in the input to be ignored. This will require a change
  // in the kernel implementation.
  DimensionHandle output_planes, output_rows, output_cols;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_planes_dim, kernel_planes, stride_planes, padding, &output_planes));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));

  ShapeHandle output_shape;
  if (data_format == "NCDHW") {
    output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
                                 output_planes, output_rows, output_cols});
  } else {
    output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
                                 output_cols, output_depth_dim});
  }

  c->set_output(0, output_shape);
  return Status::OK();
}

Status UnknownShape(shape_inference::InferenceContext* c) {
  for (int i = 0; i < c->num_outputs(); ++i) {
    c->set_output(i, c->UnknownShape());
  }
  return Status::OK();
}

template <typename T>
Status ReductionShapeHelper(const Tensor* reduction_indices_t,
                            const int32 input_rank,
                            std::set<int64>& true_indices) {
  auto reduction_indices = reduction_indices_t->flat<T>();
  for (int i = 0; i < reduction_indices_t->NumElements(); ++i) {
    const T reduction_index = reduction_indices(i);
    if (reduction_index < -input_rank || reduction_index >= input_rank) {
      return errors::InvalidArgument("Invalid reduction dimension ",
                                     reduction_index, " for input with ",
                                     input_rank, " dimensions.");
    }

    auto wrapped_index = reduction_index;
    if (wrapped_index < 0) {
      wrapped_index += input_rank;
    }

    true_indices.insert(wrapped_index);
  }
  return Status::OK();
}

Status ReductionShape(InferenceContext* c) {
  ShapeHandle input = c->input(0);

  ShapeHandle indices;
  // Older versions of TensorFlow accidentally allowed higher rank tensors like
  // [[1,2]] or [[1],[2]] to represent axis=[1,2].
  if (c->graph_def_version() < 21) {
    indices = c->input(1);
  } else {
    TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices));
  }

  bool keep_dims;
  TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));

  const Tensor* reduction_indices_t = c->input_tensor(1);
  if (reduction_indices_t == nullptr || !c->RankKnown(input)) {
    // If we do not have the reduction values at runtime, or the
    // rank of the input, we don't know the output shape.

    if (keep_dims && c->RankKnown(input)) {
      // The output rank matches the input rank if <keep_dims> is set.
      c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
      return Status::OK();
    } else {
      return shape_inference::UnknownShape(c);
    }
  }

  const int32 input_rank = c->Rank(input);
  std::set<int64> true_indices;
  if (reduction_indices_t->dtype() == DataType::DT_INT32) {
    TF_RETURN_IF_ERROR(ReductionShapeHelper<int32>(reduction_indices_t,
                                                   input_rank, true_indices));
  } else if (reduction_indices_t->dtype() == DataType::DT_INT64) {
    TF_RETURN_IF_ERROR(ReductionShapeHelper<int64>(reduction_indices_t,
                                                   input_rank, true_indices));
  } else {
    return errors::InvalidArgument(
        "reduction_indices can only be int32 or int64");
  }

  std::vector<DimensionHandle> dims;
  for (int i = 0; i < input_rank; ++i) {
    if (true_indices.count(i) > 0) {
      if (keep_dims) {
        dims.emplace_back(c->MakeDim(1));
      }
    } else {
      dims.emplace_back(c->Dim(input, i));
    }
  }

  c->set_output(0, c->MakeShape(dims));
  return Status::OK();
}
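
// Worked example (illustrative, not part of the original file): reducing an
// input of shape [2, 3, 5] over reduction_indices [0, -1] wraps -1 to 2, so
// true_indices = {0, 2}. With keep_dims = false the inferred output shape is
// [3]; with keep_dims = true the reduced dimensions are kept as size 1,
// giving [1, 3, 1].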

Status ConcatShapeHelper(InferenceContext* c, int start_value_index,
                         int end_value_index, int dim_index) {
  ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(dim_index), 0, &unused));
  const Tensor* concat_dim_t = c->input_tensor(dim_index);
  if (concat_dim_t == nullptr) {
    // Return an unknown shape with same rank as inputs, or an unknown rank
    // if no input's rank is known.

    // Find rank.
    int32 rank = InferenceContext::kUnknownRank;
    for (int i = start_value_index; i < end_value_index; ++i) {
      if (rank == InferenceContext::kUnknownRank) rank = c->Rank(c->input(i));
      if (rank != InferenceContext::kUnknownRank) {
        break;
      }
    }
    if (rank == InferenceContext::kUnknownRank) {
      c->set_output(0, c->UnknownShape());
      return Status::OK();
    } else if (rank == 0) {
      return errors::InvalidArgument(
          "Can't concatenate scalars (use tf.stack instead)");
    } else {
      for (int i = start_value_index; i < end_value_index; ++i) {
        // Check that all the inputs are of the correct rank.
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), rank, &unused));
      }
    }
    // Build result of <rank> different unknown dims.
    std::vector<DimensionHandle> dims;
    dims.reserve(rank);
    for (int i = 0; i < rank; ++i) dims.push_back(c->UnknownDim());
    c->set_output(0, c->MakeShape(dims));
    return Status::OK();
  }

  // Merge all the non-concat dims, and sum the concat dim to make an output
  // shape.
  const int32 concat_dim = concat_dim_t->scalar<int32>()();

  // Minimum required number of dimensions.
  const int min_rank = concat_dim < 0 ? -concat_dim : concat_dim + 1;

  ShapeHandle output_before;
  ShapeHandle output_after;

  ShapeHandle input = c->input(end_value_index - 1);
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
  TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &output_before));
  DimensionHandle output_middle = c->Dim(input, concat_dim);
  if (concat_dim == -1) {
    output_after = c->Scalar();  // no dimensions.
  } else {
    TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &output_after));
  }

  for (int i = end_value_index - 2; i >= start_value_index; --i) {
    ShapeHandle before;
    ShapeHandle after;
    input = c->input(i);
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
    TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &before));
    DimensionHandle middle = c->Dim(input, concat_dim);
    if (concat_dim == -1) {
      after = c->Scalar();
    } else {
      TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &after));
    }

    TF_RETURN_IF_ERROR(c->Merge(before, output_before, &output_before));
    TF_RETURN_IF_ERROR(c->Add(output_middle, middle, &output_middle));
    TF_RETURN_IF_ERROR(c->Merge(after, output_after, &output_after));
  }

  ShapeHandle s;
  TF_RETURN_IF_ERROR(
      c->Concatenate(output_before, c->Vector(output_middle), &s));
  TF_RETURN_IF_ERROR(c->Concatenate(s, output_after, &s));
  c->set_output(0, s);
  return Status::OK();
}
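
// Worked example (illustrative, not part of the original file):
// concatenating inputs of shapes [2, 3] and [2, 5] along concat_dim = 1
// splits each shape into before = [2], middle = 3 or 5, and after = [] (a
// scalar subshape). The before and after parts are merged across inputs, the
// middles are summed, and the result is reassembled as [2, 8].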

Status ConcatShape(InferenceContext* c, int num_inputs_to_concat) {
  return ConcatShapeHelper(c, 1 /* start_value_index */,
                           1 + num_inputs_to_concat /* end_value_index */,
                           0 /* dim_index */);
}

Status ConcatV2Shape(InferenceContext* c) {
  return ConcatShapeHelper(c, 0 /* start_value_index */,
                           c->num_inputs() - 1 /* end_value_index */,
                           c->num_inputs() - 1 /* dim_index */);
}

Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
  ShapeHandle shape_x = c->input(0);
  ShapeHandle shape_y = c->input(1);
  if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
    c->set_output(0, c->UnknownShape());
    return Status::OK();
  }
  const int32 rank_x = c->Rank(shape_x);
  const int32 rank_y = c->Rank(shape_y);
  const int32 rank_out = std::max(rank_x, rank_y);

  // To compute the broadcast dimensions, we zip together shape_x and shape_y
  // and pad with 1 to make them the same length.
  std::vector<DimensionHandle> dims;
  DimensionHandle dim_one;
  if (rank_x != rank_y) dim_one = c->MakeDim(1);
  for (int i = 0; i < rank_out; ++i) {
    const auto dim_x = i < (rank_out - rank_x)
                           ? dim_one
                           : c->Dim(shape_x, i - (rank_out - rank_x));
    const bool dim_y_is_one = (i < (rank_out - rank_y));
    const auto dim_y =
        dim_y_is_one ? dim_one : c->Dim(shape_y, i - (rank_out - rank_y));
    if (!c->ValueKnown(dim_x) || !c->ValueKnown(dim_y)) {
      // One or both dimensions is unknown.
      //
      // - If either dimension is greater than 1, we assume that the program is
      //   correct, and the other dimension will be broadcast to match it.
      //   TODO(cwhipkey): For shape inference, if we eliminate the shape checks
      //   in C++ op code, we must still assert that the unknown dim is either 1
      //   or the same as the known dim.
      // - If either dimension is 1, the other dimension is the output.
      if (c->Value(dim_x) > 1) {
        dims.push_back(dim_x);
      } else if (c->Value(dim_y) > 1) {
        dims.push_back(dim_y);
      } else if (c->Value(dim_x) == 1) {
        dims.push_back(dim_y);
      } else if (c->Value(dim_y) == 1) {
        dims.push_back(dim_x);
      } else if (dim_y.SameHandle(dim_x)) {
        dims.push_back(dim_x);
      } else {
        dims.push_back(c->UnknownDim());
      }
    } else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) {
      if (c->Value(dim_x) == 1 && !dim_y_is_one) {
        // We will broadcast dim_x to dim_y.
        dims.push_back(dim_y);
      } else {
        DCHECK_EQ(c->Value(dim_y), 1);
        // We will broadcast dim_y to dim_x.
        dims.push_back(dim_x);
      }
    } else {
      DimensionHandle dim;
      TF_RETURN_IF_ERROR(c->Merge(dim_x, dim_y, &dim));
      dims.push_back(dim);
    }
  }

  c->set_output(0, c->MakeShape(dims));
  return Status::OK();
}
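
// Worked example (illustrative, not part of the original file): broadcasting
// shapes [2, 3, 1] and [3, 5] right-aligns them as [2, 3, 1] vs [1, 3, 5]
// (the shorter shape is padded with a leading 1). Position by position this
// yields [2, 3, 5]: a 1 always defers to the other dimension, equal known
// dimensions are merged, and anything else is a shape error.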

Status RandomShape(shape_inference::InferenceContext* c) {
  shape_inference::ShapeHandle out;
  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
  c->set_output(0, out);
  return Status::OK();
}

Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
                            ShapeHandle values_shape, ShapeHandle shape_shape) {
  // Validate ranks.
  ShapeHandle unused_shape;
  TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &unused_shape));
  TF_RETURN_IF_ERROR(c->WithRank(values_shape, 1, &unused_shape));
  TF_RETURN_IF_ERROR(c->WithRank(shape_shape, 1, &unused_shape));

  // Number of elements in indices and values must match.
  DimensionHandle num_index_elements_dim = c->Dim(indices_shape, 0);
  if (c->ValueKnown(num_index_elements_dim)) {
    DimensionHandle num_values_elements_dim = c->Dim(values_shape, 0);
    if (c->ValueKnown(num_values_elements_dim)) {
      int64 num_index_elements = c->Value(num_index_elements_dim);
      int64 num_values_elements = c->Value(num_values_elements_dim);
      if (num_index_elements != num_values_elements) {
        return errors::InvalidArgument("Number of elements in index (",
                                       num_index_elements, ") and values (",
                                       num_values_elements, ") do not match.");
      }
    }
  }

  // Rank embedded in indices must match shape.
  DimensionHandle index_rank_dim = c->Dim(indices_shape, 1);
  if (c->ValueKnown(index_rank_dim)) {
    DimensionHandle shape_rank_dim = c->Dim(shape_shape, 0);
    if (c->ValueKnown(shape_rank_dim)) {
      int64 index_rank = c->Value(index_rank_dim);
      int64 shape_rank = c->Value(shape_rank_dim);
      if (index_rank != shape_rank) {
        return errors::InvalidArgument("Index rank (", index_rank,
                                       ") and shape rank (", shape_rank,
                                       ") do not match.");
      }
    }
  }

  return Status::OK();
}

Status ScatterNdUpdateShape(InferenceContext* c) {
  ShapeHandle input_shape = c->input(0);
  if (c->input_handle_shapes_and_types(0) != nullptr) {
    input_shape = (*c->input_handle_shapes_and_types(0))[0].shape;
  }
  ShapeHandle indices_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
  ShapeHandle updates_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape));

  if (c->Value(c->NumElements(input_shape)) == 0 &&
      (c->Value(c->NumElements(indices_shape)) > 0 ||
       c->Value(c->NumElements(updates_shape)) > 0)) {
    return errors::InvalidArgument(
        "Indices and updates specified for empty output shape");
  }

  if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) {
    const int64 num_outer_dims = c->Rank(indices_shape) - 1;
    const DimensionHandle index_size = c->Dim(indices_shape, -1);

    // We can only do more validation if the last dimension of indices
    // is a known value.
    if (c->ValueKnown(index_size)) {
      const int64 ix = c->Value(index_size);
      ShapeHandle unused;
      ShapeHandle prefix_indices;
      TF_RETURN_IF_ERROR(
          c->Subshape(indices_shape, 0, num_outer_dims, &prefix_indices));
      ShapeHandle prefix_updates;
      TF_RETURN_IF_ERROR(
          c->Subshape(updates_shape, 0, num_outer_dims, &prefix_updates));

      Status s = c->Merge(prefix_indices, prefix_updates, &unused);
      if (!s.ok()) {
        return errors::InvalidArgument(
            "The outer ", num_outer_dims,
            " dimensions of indices.shape=", c->DebugString(indices_shape),
            " must match the outer ", num_outer_dims,
            " dimensions of updates.shape=", c->DebugString(updates_shape),
            ": ", s.error_message());
      }

      ShapeHandle input_suffix;
      TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &input_suffix));
      ShapeHandle suffix_updates;
      TF_RETURN_IF_ERROR(
          c->Subshape(updates_shape, num_outer_dims, &suffix_updates));
      s = c->Merge(input_suffix, suffix_updates, &unused);
      if (!s.ok()) {
        return errors::InvalidArgument(
            "The inner ", c->Rank(input_shape) - ix,
            " dimensions of input.shape=", c->DebugString(input_shape),
            " must match the inner ", c->Rank(updates_shape) - num_outer_dims,
            " dimensions of updates.shape=", c->DebugString(updates_shape),
            ": ", s.error_message());
      }
    }
  }

  if (c->input_handle_shapes_and_types(0) == nullptr) {
    c->set_output(0, input_shape);
  }
  return Status::OK();
}
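
// Worked example (illustrative, not part of the original file): for
// input.shape = [10, 4], indices.shape = [5, 1] and updates.shape = [5, 4],
// the last dimension of indices gives ix = 1. The outer prefix [5] of
// indices and updates must merge, and the input suffix starting at dimension
// ix, namely [4], must merge with the updates suffix [4]; both hold here, so
// the output keeps shape [10, 4].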

Status ExplicitShape(InferenceContext* c) {
  PartialTensorShape shape;
  TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
  ShapeHandle output_shape;
  TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &output_shape));
  c->set_output(0, output_shape);
  return Status::OK();
}

}  // namespace shape_inference

}  // namespace tensorflow