/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size,
                                      int64 dilation_rate, int64 stride,
                                      Padding padding_type, int64* output_size,
                                      int64* padding_before,
                                      int64* padding_after) {
  if (stride <= 0) {
    return errors::InvalidArgument("Stride must be > 0, but got ", stride);
  }
  if (dilation_rate < 1) {
    return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
                                   dilation_rate);
  }

  // See also the parallel implementation in GetWindowedOutputSizeFromDimsV2.
  int64 effective_filter_size = (filter_size - 1) * dilation_rate + 1;
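  // Worked example (values chosen for illustration only): input_size = 7,
  // filter_size = 3, dilation_rate = 2, so effective_filter_size = 5.
  //   VALID, stride 1: output = (7 - 5 + 1) / 1 = 3, no padding.
  //   SAME,  stride 2: output = ceil(7 / 2) = 4; padding_needed =
  //     max(0, (4 - 1) * 2 + 5 - 7) = 4, split as before = 2, after = 2.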
  switch (padding_type) {
    case Padding::VALID:
      *output_size = (input_size - effective_filter_size + stride) / stride;
      *padding_before = *padding_after = 0;
      break;
    case Padding::EXPLICIT:
      *output_size = (input_size + *padding_before + *padding_after -
                      effective_filter_size + stride) /
                     stride;
      break;
    case Padding::SAME:
      *output_size = (input_size + stride - 1) / stride;
      const int64 padding_needed =
          std::max(int64{0}, (*output_size - 1) * stride +
                                 effective_filter_size - input_size);
      // For odd values of total padding, add more padding at the 'right'
      // side of the given dimension.
      *padding_before = padding_needed / 2;
      *padding_after = padding_needed - *padding_before;
      break;
  }
  if (*output_size < 0) {
    return errors::InvalidArgument(
        "Computed output size would be negative: ", *output_size,
        " [input_size: ", input_size,
        ", effective_filter_size: ", effective_filter_size,
        ", stride: ", stride, "]");
  }
  return Status::OK();
}

Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size,
                                    int64 stride, Padding padding_type,
                                    int64* output_size, int64* padding_before,
                                    int64* padding_after) {
  return GetWindowedOutputSizeVerboseV2(input_size, filter_size,
                                        /*dilation_rate=*/1, stride,
                                        padding_type, output_size,
                                        padding_before, padding_after);
}

Status GetWindowedOutputSize(int64 input_size, int64 filter_size, int64 stride,
                             Padding padding_type, int64* output_size,
                             int64* padding_size) {
  if (padding_type == Padding::EXPLICIT) {
    return errors::Internal(
        "GetWindowedOutputSize does not handle EXPLICIT padding; call "
        "GetWindowedOutputSizeVerbose instead");
  }
  int64 padding_after_unused;
  return GetWindowedOutputSizeVerbose(input_size, filter_size, stride,
                                      padding_type, output_size, padding_size,
                                      &padding_after_unused);
}

Status GetWindowedOutputSizeV2(int64 input_size, int64 filter_size,
                               int64 dilation_rate, int64 stride,
                               Padding padding_type, int64* output_size,
                               int64* padding_size) {
  if (padding_type == Padding::EXPLICIT) {
    return errors::Internal(
        "GetWindowedOutputSizeV2 does not handle EXPLICIT padding; call "
        "GetWindowedOutputSizeVerboseV2 instead");
  }
  int64 padding_after_unused;
  return GetWindowedOutputSizeVerboseV2(input_size, filter_size, dilation_rate,
                                        stride, padding_type, output_size,
                                        padding_size, &padding_after_unused);
}

Status Get3dOutputSize(const std::array<int64, 3>& input,
                       const std::array<int64, 3>& window,
                       const std::array<int64, 3>& strides,
                       Padding padding_type, std::array<int64, 3>* output_ptr,
                       std::array<int64, 3>* padding_ptr) {
  for (size_t i = 0; i < input.size(); ++i) {
    TF_RETURN_IF_ERROR(GetWindowedOutputSize(input[i], window[i], strides[i],
                                             padding_type, &(*output_ptr)[i],
                                             &(*padding_ptr)[i]));
  }
  return Status::OK();
}

Status Get3dOutputSizeV2(const std::array<int64, 3>& input,
                         const std::array<int64, 3>& window,
                         const std::array<int64, 3>& dilations,
                         const std::array<int64, 3>& strides,
                         Padding padding_type, std::array<int64, 3>* output_ptr,
                         std::array<int64, 3>* padding_ptr) {
  for (size_t i = 0; i < input.size(); ++i) {
    TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2(
        input[i], window[i], dilations[i], strides[i], padding_type,
        &(*output_ptr)[i], &(*padding_ptr)[i]));
  }
  return Status::OK();
}

namespace shape_inference {

// The V2 version computes the windowed output size with an arbitrary
// dilation_rate, while the original version only handles the case where the
// dilation rate is equal to 1.
Status GetWindowedOutputSizeFromDimsV2(
    shape_inference::InferenceContext* c,
    shape_inference::DimensionHandle input_size,
    shape_inference::DimensionOrConstant filter_size, int64 dilation_rate,
    int64 stride, Padding padding_type, int64 padding_before,
    int64 padding_after, shape_inference::DimensionHandle* output_size) {
  if (stride <= 0) {
    return errors::InvalidArgument("Stride must be > 0, but got ", stride);
  }

  if (dilation_rate < 1) {
    return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
                                   dilation_rate);
  }

  // See also the parallel implementation in GetWindowedOutputSizeVerbose.
  switch (padding_type) {
    case Padding::VALID:
      padding_before = padding_after = 0;
      TF_FALLTHROUGH_INTENDED;
    case Padding::EXPLICIT:
      TF_RETURN_IF_ERROR(
          c->Add(input_size, padding_before + padding_after, &input_size));
      if (dilation_rate > 1) {
        DimensionHandle window_size;
        TF_RETURN_IF_ERROR(
            c->Subtract(c->MakeDim(filter_size), 1, &window_size));
        TF_RETURN_IF_ERROR(
            c->Multiply(window_size, dilation_rate, &window_size));
        TF_RETURN_IF_ERROR(c->Add(window_size, 1, &window_size));
        TF_RETURN_IF_ERROR(c->Subtract(input_size, window_size, output_size));
      } else {
        TF_RETURN_IF_ERROR(c->Subtract(input_size, filter_size, output_size));
      }
      TF_RETURN_IF_ERROR(c->Add(*output_size, stride, output_size));
      TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
                                   /*evenly_divisible=*/false, output_size));
      break;
    case Padding::SAME:
      TF_RETURN_IF_ERROR(c->Add(input_size, stride - 1, output_size));
      TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
                                   /*evenly_divisible=*/false, output_size));
      break;
  }
  return Status::OK();
}

Status GetWindowedOutputSizeFromDims(
    shape_inference::InferenceContext* c,
    shape_inference::DimensionHandle input_size,
    shape_inference::DimensionOrConstant filter_size, int64 stride,
    Padding padding_type, shape_inference::DimensionHandle* output_size) {
  if (padding_type == Padding::EXPLICIT) {
    return errors::Internal(
        "GetWindowedOutputSizeFromDims does not handle EXPLICIT padding; call "
        "GetWindowedOutputSizeFromDimsV2 instead");
  }
  return GetWindowedOutputSizeFromDimsV2(c, input_size, filter_size,
                                         /*dilation_rate=*/1, stride,
                                         padding_type,
                                         // Give dummy values of -1 to
                                         // padding_before and padding_after,
                                         // since explicit padding is not used.
                                         -1, -1, output_size);
}

Status UnchangedShape(shape_inference::InferenceContext* c) {
  c->set_output(0, c->input(0));
  auto* handle_data = c->input_handle_shapes_and_types(0);
  if (handle_data != nullptr) {
    c->set_output_handle_shapes_and_types(0, *handle_data);
  }
  return Status::OK();
}

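// Shape function for MatMul-like ops. For example (shapes for illustration
// only): a = [m, k] and b = [k, n] yield output [m, n]; with transpose_a the
// inner dimension of a is dim 0 instead of dim 1, and likewise for
// transpose_b.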
Status MatMulShape(shape_inference::InferenceContext* c) {
  ShapeHandle a;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a));

  ShapeHandle b;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b));

  bool transpose_a, transpose_b;
  TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a));
  TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b));
  DimensionHandle output_rows = transpose_a ? c->Dim(a, 1) : c->Dim(a, 0);
  DimensionHandle output_cols = transpose_b ? c->Dim(b, 0) : c->Dim(b, 1);

  // Validate that the inner shapes are compatible.
  DimensionHandle inner_a = transpose_a ? c->Dim(a, 0) : c->Dim(a, 1);
  DimensionHandle inner_b = transpose_b ? c->Dim(b, 1) : c->Dim(b, 0);
  DimensionHandle merged;
  TF_RETURN_IF_ERROR(c->Merge(inner_a, inner_b, &merged));

  c->set_output(0, c->Matrix(output_rows, output_cols));
  return Status::OK();
}

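// Shape function for BiasAdd. As an illustration, with the default NHWC
// layout an input of [b, h, w, c] and a bias of [c] produce an output of
// [b, h, w, c]; with NCHW the bias length is merged into dimension 1 instead.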
Status BiasAddShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;

  // Fetch the data_format attribute, which may not exist.
  string data_format;
  Status s = c->GetAttr("data_format", &data_format);

  if (s.ok() && data_format == "NCHW") {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
  } else {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
  }

  ShapeHandle bias_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &bias_shape));
  DimensionHandle bias_dim = c->Dim(bias_shape, 0);

  // If rank unknown, return unknown shape.
  if (!c->RankKnown(input_shape)) {
    c->set_output(0, c->UnknownShape());
    return Status::OK();
  }

  // Output has the same shape as the input, and matches the length of
  // the bias in its bias dimension.
  ShapeHandle output_shape;
  if (s.ok() && data_format == "NCHW") {
    // Merge the length of bias_shape into the third to last dimension.
    ShapeHandle first;
    TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, 1, &first));

    ShapeHandle last;
    TF_RETURN_IF_ERROR(c->Subshape(input_shape, 2, &last));

    DimensionHandle input_bias_dim = c->Dim(input_shape, 1);
    DimensionHandle merged_bias_dim;
    TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));
    ShapeHandle merged_bias = c->Vector(merged_bias_dim);

    ShapeHandle temp;
    TF_RETURN_IF_ERROR(c->Concatenate(first, merged_bias, &temp));
    TF_RETURN_IF_ERROR(c->Concatenate(temp, last, &output_shape));
  } else {
    ShapeHandle all_but_bias;
    TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -1, &all_but_bias));

    DimensionHandle input_bias_dim = c->Dim(input_shape, -1);
    DimensionHandle merged_bias_dim;
    TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));

    ShapeHandle merged_bias = c->Vector(merged_bias_dim);
    TF_RETURN_IF_ERROR(
        c->Concatenate(all_but_bias, merged_bias, &output_shape));
  }

  c->set_output(0, output_shape);
  return Status::OK();
}

Status BiasAddGradShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;
  // Fetch the data_format attribute, which may not exist.
  string data_format;
  Status s = c->GetAttr("data_format", &data_format);

  if (s.ok() && data_format == "NCHW") {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
    c->set_output(0, c->Vector(c->Dim(input_shape, 1)));
  } else {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
    c->set_output(0, c->Vector(c->Dim(input_shape, -1)));
  }

  return Status::OK();
}

Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format,
                                     const ShapeHandle shape_handle,
                                     const string& tensor_name,
                                     shape_inference::InferenceContext* c) {
  if (tensor_format == FORMAT_NCHW_VECT_C) {
    // Check that the vect dim has size 4.
    const int num_dims = c->Rank(shape_handle);
    DimensionHandle vect_dim = c->Dim(
        shape_handle, GetTensorInnerFeatureDimIndex(num_dims, tensor_format));
    DimensionHandle unused_vect_dim;
    TF_RETURN_IF_ERROR(c->WithValue(vect_dim, 4, &unused_vect_dim));
  }

  return Status::OK();
}

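// Builds a shape in the given format from batch, spatial, and channel
// dimensions. For example, format = NCHW with N = n, spatial = {h, w},
// and C = c yields [n, c, h, w]; NCHW_VECT_C yields [n, c, h, w, 4], with
// the inner vector dimension fixed at 4.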
Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
                           const std::vector<DimensionOrConstant>& spatial,
                           DimensionOrConstant C, ShapeHandle* out,
                           shape_inference::InferenceContext* context) {
  const int num_dims = GetTensorDimsFromSpatialDims(spatial.size(), format);
  std::vector<DimensionHandle> dims_actual(num_dims);
  dims_actual[GetTensorBatchDimIndex(num_dims, format)] = context->MakeDim(N);
  int outer_c_index = GetTensorFeatureDimIndex(num_dims, format);
  dims_actual[outer_c_index] = context->MakeDim(C);
  if (format == FORMAT_NCHW_VECT_C) {
    dims_actual[GetTensorInnerFeatureDimIndex(num_dims, format)] =
        context->MakeDim(4);
  } else if (format == FORMAT_NHWC_VECT_W) {
    dims_actual[GetTensorInnerWidthDimIndex(num_dims, format)] =
        context->MakeDim(4);
  }
  for (int spatial_dim = 0; spatial_dim < spatial.size(); spatial_dim++) {
    dims_actual[GetTensorSpatialDimIndex(num_dims, format, spatial_dim)] =
        context->MakeDim(spatial[spatial_dim]);
  }
  *out = context->MakeShape(dims_actual);
  return Status::OK();
}

Status DimensionsFromShape(ShapeHandle shape, TensorFormat format,
                           DimensionHandle* batch_dim,
                           gtl::MutableArraySlice<DimensionHandle> spatial_dims,
                           DimensionHandle* filter_dim,
                           InferenceContext* context) {
  const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
  // Batch.
  *batch_dim = context->Dim(shape, GetTensorBatchDimIndex(rank, format));
  // Spatial.
  for (int spatial_dim_index = 0; spatial_dim_index < spatial_dims.size();
       ++spatial_dim_index) {
    spatial_dims[spatial_dim_index] = context->Dim(
        shape, GetTensorSpatialDimIndex(rank, format, spatial_dim_index));
  }
  // Channel.
  *filter_dim = context->Dim(shape, GetTensorFeatureDimIndex(rank, format));
  if (format == FORMAT_NCHW_VECT_C) {
    TF_RETURN_IF_ERROR(context->Multiply(
        *filter_dim,
        context->Dim(shape, GetTensorInnerFeatureDimIndex(rank, format)),
        filter_dim));
  }
  return Status::OK();
}

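// Inverse of DimensionsFromShape: reassembles a shape from batch, spatial,
// and channel dimensions. For instance, with NCHW_VECT_C a channel count of
// 32 is factored into an outer feature dimension of 8 and an inner vector
// dimension of 4.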
Status ShapeFromDimensions(DimensionHandle batch_dim,
                           gtl::ArraySlice<DimensionHandle> spatial_dims,
                           DimensionHandle filter_dim, TensorFormat format,
                           InferenceContext* context, ShapeHandle* shape) {
  const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
  std::vector<DimensionHandle> out_dims(rank);

  // Batch.
  out_dims[tensorflow::GetTensorBatchDimIndex(rank, format)] = batch_dim;
  // Spatial.
  for (int spatial_dim_index = 0; spatial_dim_index < spatial_dims.size();
       ++spatial_dim_index) {
    out_dims[tensorflow::GetTensorSpatialDimIndex(
        rank, format, spatial_dim_index)] = spatial_dims[spatial_dim_index];
  }
  // Channel.
  if (format == tensorflow::FORMAT_NCHW_VECT_C) {
    // When format is NCHW_VECT_C, factor the feature map count
    // into the outer feature count and the inner feature count (=4).
    TF_RETURN_IF_ERROR(context->Divide(
        filter_dim, 4, /*evenly_divisible=*/true,
        &out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)]));
    out_dims[GetTensorInnerFeatureDimIndex(rank, format)] = context->MakeDim(4);
  } else {
    out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)] = filter_dim;
  }

  *shape = context->MakeShape(out_dims);
  return tensorflow::Status::OK();
}

namespace {

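// Shared shape function for Conv2D-like ops. As an illustration, with NHWC
// data and HWIO filters, an input of [b, h, w, c_in] and a filter of
// [fh, fw, c_in, c_out] produce an output of [b, oh, ow, c_out], where oh
// and ow follow the windowed-output-size formulas above.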
Status Conv2DShapeImpl(shape_inference::InferenceContext* c,
                       bool supports_explicit_padding) {
  string data_format_str, filter_format_str;
  if (!c->GetAttr("data_format", &data_format_str).ok()) {
    data_format_str = "NHWC";
  }
  if (!c->GetAttr("filter_format", &filter_format_str).ok()) {
    filter_format_str = "HWIO";
  }

  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }
  FilterTensorFormat filter_format;
  if (!FilterFormatFromString(filter_format_str, &filter_format)) {
    return errors::InvalidArgument("Invalid filter format string: ",
                                   filter_format_str);
  }

  constexpr int num_spatial_dims = 2;
  const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
  ShapeHandle conv_input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &conv_input_shape));
  TF_RETURN_IF_ERROR(CheckFormatConstraintsOnShape(
      data_format, conv_input_shape, "conv_input", c));

  // The filter rank should match the input (4 for NCHW, 5 for NCHW_VECT_C).
  ShapeHandle filter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape));
  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, filter_shape, "filter", c));

  std::vector<int32> dilations;
  TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations));

  if (dilations.size() != 4) {
    return errors::InvalidArgument(
        "Conv2D requires the dilation attribute to contain 4 values, but got: ",
        dilations.size());
  }

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));

  // strides.size() should be 4 (NCHW) even if the input is 5 (NCHW_VECT_C).
  if (strides.size() != 4) {
    return errors::InvalidArgument("Conv2D on data format ", data_format_str,
                                   " requires the stride attribute to contain"
                                   " 4 values, but got: ",
                                   strides.size());
  }

  const int32 stride_rows = GetTensorDim(strides, data_format, 'H');
  const int32 stride_cols = GetTensorDim(strides, data_format, 'W');
  const int32 dilation_rows = GetTensorDim(dilations, data_format, 'H');
  const int32 dilation_cols = GetTensorDim(dilations, data_format, 'W');

  DimensionHandle batch_size_dim;
  DimensionHandle input_depth_dim;
  gtl::InlinedVector<DimensionHandle, 2> input_spatial_dims(2);
  TF_RETURN_IF_ERROR(DimensionsFromShape(
      conv_input_shape, data_format, &batch_size_dim,
      absl::MakeSpan(input_spatial_dims), &input_depth_dim, c));

  DimensionHandle output_depth_dim = c->Dim(
      filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
  DimensionHandle filter_rows_dim = c->Dim(
      filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'H'));
  DimensionHandle filter_cols_dim = c->Dim(
      filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'W'));
  DimensionHandle filter_input_depth_dim;
  if (filter_format == FORMAT_OIHW_VECT_I) {
    TF_RETURN_IF_ERROR(c->Multiply(
        c->Dim(filter_shape,
               GetFilterDimIndex<num_spatial_dims>(filter_format, 'I')),
        c->Dim(filter_shape,
               GetFilterTensorInnerInputChannelsDimIndex(rank, filter_format)),
        &filter_input_depth_dim));
  } else {
    filter_input_depth_dim = c->Dim(
        filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'I'));
  }

  // Check that the input tensor and the filter tensor agree on the input
  // channel count.
  DimensionHandle unused;
  TF_RETURN_IF_ERROR(
      c->Merge(input_depth_dim, filter_input_depth_dim, &unused));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  std::vector<int64> explicit_paddings;
  if (supports_explicit_padding) {
    Status s = c->GetAttr("explicit_paddings", &explicit_paddings);
    // Use the default value, which is an empty list, if the attribute is not
    // found. Otherwise return the error to the caller.
    if (!s.ok() && !errors::IsNotFound(s)) {
      return s;
    }
    TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
                                         /*num_dims=*/4, data_format));
  } else {
    DCHECK(padding != Padding::EXPLICIT);
  }

  DimensionHandle output_rows, output_cols;
  int64 pad_rows_before = -1, pad_rows_after = -1;
  int64 pad_cols_before = -1, pad_cols_after = -1;
  if (padding == Padding::EXPLICIT) {
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
                             &pad_rows_before, &pad_rows_after);
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
                             &pad_cols_before, &pad_cols_after);
  }
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, input_spatial_dims[0], filter_rows_dim, dilation_rows, stride_rows,
      padding, pad_rows_before, pad_rows_after, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, input_spatial_dims[1], filter_cols_dim, dilation_cols, stride_cols,
      padding, pad_cols_before, pad_cols_after, &output_cols));

  ShapeHandle output_shape;
  TF_RETURN_IF_ERROR(
      ShapeFromDimensions(batch_size_dim, {output_rows, output_cols},
                          output_depth_dim, data_format, c, &output_shape));
  c->set_output(0, output_shape);
  return Status::OK();
}

}  // namespace

// Shape function for Conv2D-like operations that support explicit padding.
Status Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext* c) {
  return Conv2DShapeImpl(c, true);
}

// Shape function for Conv2D-like operations that do not support explicit
// padding.
Status Conv2DShape(shape_inference::InferenceContext* c) {
  return Conv2DShapeImpl(c, false);
}

// TODO(mjanusz): Unify all conv/pooling shape functions.
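// Shape function for Conv3D. With NDHWC data, the input is
// [batch, planes, rows, cols, in_channels] and the filter is
// [f_planes, f_rows, f_cols, in_channels, out_channels].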
Status Conv3DShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
  ShapeHandle filter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &filter_shape));

  string data_format;
  Status s = c->GetAttr("data_format", &data_format);

  std::vector<int32> dilations;
  TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations));

  if (dilations.size() != 5) {
    return errors::InvalidArgument(
        "Conv3D requires the dilation attribute to contain 5 values, but got: ",
        dilations.size());
  }

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 5) {
    return errors::InvalidArgument(
        "Conv3D requires the stride attribute to contain 5 values, but got: ",
        strides.size());
  }

  int32 stride_planes, stride_rows, stride_cols;
  int32 dilation_planes, dilation_rows, dilation_cols;
  if (s.ok() && data_format == "NCDHW") {
    // Convert input_shape to NDHWC.
    auto dim = [&](char dimension) {
      return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
    };
    input_shape =
        c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
    stride_planes = strides[2];
    stride_rows = strides[3];
    stride_cols = strides[4];
    dilation_planes = dilations[2];
    dilation_rows = dilations[3];
    dilation_cols = dilations[4];
  } else {
    stride_planes = strides[1];
    stride_rows = strides[2];
    stride_cols = strides[3];
    dilation_planes = dilations[1];
    dilation_rows = dilations[2];
    dilation_cols = dilations[3];
  }

  DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
  DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
  DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
  DimensionHandle in_cols_dim = c->Dim(input_shape, 3);

  DimensionHandle filter_planes_dim = c->Dim(filter_shape, 0);
  DimensionHandle filter_rows_dim = c->Dim(filter_shape, 1);
  DimensionHandle filter_cols_dim = c->Dim(filter_shape, 2);
  DimensionHandle output_depth_dim = c->Dim(filter_shape, 4);

  DimensionHandle unused;
  TF_RETURN_IF_ERROR(
      c->Merge(c->Dim(input_shape, 4), c->Dim(filter_shape, 3), &unused));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
  DimensionHandle output_planes, output_rows, output_cols;

  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_planes_dim, filter_planes_dim, dilation_planes, stride_planes,
      padding, -1, -1, &output_planes));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_rows_dim, filter_rows_dim, dilation_rows, stride_rows, padding, -1,
      -1, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_cols_dim, filter_cols_dim, dilation_cols, stride_cols, padding, -1,
      -1, &output_cols));

  ShapeHandle output_shape;
  if (data_format == "NCDHW") {
    output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
                                 output_planes, output_rows, output_cols});
  } else {
    output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
                                 output_cols, output_depth_dim});
  }
  c->set_output(0, output_shape);
  return Status::OK();
}

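// Shape function for DepthwiseConv2dNative. The filter is laid out as
// [f_rows, f_cols, in_channels, depth_multiplier], so, for example, an NHWC
// input of [b, h, w, 8] with depth_multiplier 3 yields an output channel
// count of 8 * 3 = 24.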
Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
  ShapeHandle filter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape));

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));

  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "DepthwiseConv2D requires the stride attribute to contain 4 values, "
        "but got: ",
        strides.size());
  }

  string data_format;
  Status s = c->GetAttr("data_format", &data_format);
  int32 stride_rows;
  int32 stride_cols;
  if (s.ok() && data_format == "NCHW") {
    // Canonicalize input shape to NHWC so the shape inference code below can
    // process it.
    input_shape =
        c->MakeShape({{c->Dim(input_shape, 0), c->Dim(input_shape, 2),
                       c->Dim(input_shape, 3), c->Dim(input_shape, 1)}});
    stride_rows = strides[2];
    stride_cols = strides[3];
  } else {
    stride_rows = strides[1];
    stride_cols = strides[2];
  }

  DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
  DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
  DimensionHandle in_cols_dim = c->Dim(input_shape, 2);

  DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
  DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
  DimensionHandle input_depth = c->Dim(filter_shape, 2);
  DimensionHandle depth_multiplier = c->Dim(filter_shape, 3);

  // Check that the input depths are compatible.
  TF_RETURN_IF_ERROR(
      c->Merge(c->Dim(input_shape, 3), input_depth, &input_depth));

  DimensionHandle output_depth;
  TF_RETURN_IF_ERROR(c->Multiply(input_depth, depth_multiplier, &output_depth));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  // TODO(mrry,shlens): Raise an error if the stride would cause
  // information in the input to be ignored. This will require a change
  // in the kernel implementation.
  DimensionHandle output_rows, output_cols;

  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols));

  ShapeHandle output_shape;
  if (data_format == "NCHW") {
    output_shape =
        c->MakeShape({batch_size_dim, output_depth, output_rows, output_cols});
  } else {
    output_shape =
        c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth});
  }
  c->set_output(0, output_shape);
  return Status::OK();
}

Status AvgPoolShape(shape_inference::InferenceContext* c) {
  string data_format_str;
  TensorFormat data_format;
  Status s = c->GetAttr("data_format", &data_format_str);
  if (s.ok()) {
    FormatFromString(data_format_str, &data_format);
  } else {
    data_format = FORMAT_NHWC;
  }

  const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));

  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "AvgPool requires the stride attribute to contain 4 values, but got: ",
        strides.size());
  }

  std::vector<int32> kernel_sizes;
  TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
  if (kernel_sizes.size() != 4) {
    return errors::InvalidArgument(
        "AvgPool requires the ksize attribute to contain 4 values, but got: ",
        kernel_sizes.size());
  }

  int32 stride_rows = GetTensorDim(strides, data_format, 'H');
  int32 stride_cols = GetTensorDim(strides, data_format, 'W');
  int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
  int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');

  constexpr int num_spatial_dims = 2;
  DimensionHandle batch_size_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
  DimensionHandle in_rows_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
  DimensionHandle in_cols_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
  DimensionHandle depth_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  // TODO(mrry,shlens): Raise an error if the stride would cause
  // information in the input to be ignored. This will require a change
  // in the kernel implementation.

  DimensionHandle output_rows, output_cols;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));

  ShapeHandle output_shape;
  TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
                                         {output_rows, output_cols}, depth_dim,
                                         &output_shape, c));
  c->set_output(0, output_shape);
  return Status::OK();
}

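// Shape function for FusedBatchNorm. Output 0 has the same shape as x;
// outputs 1 through 4 are vectors whose length is the channel dimension
// of x.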
Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
  ShapeHandle x;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));

  bool is_training;
  TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
  int number_inputs = (is_training) ? 3 : 5;
  string data_format_str;
  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }
  int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
  DimensionHandle channel_dim = c->Dim(x, channel_dim_index);

  // Covers scale and offset, plus mean and variance when is_training is
  // false.
  for (int i = 1; i < number_inputs; ++i) {
    ShapeHandle vec;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
    TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
  }

  ShapeHandle y;
  TF_RETURN_IF_ERROR(c->ReplaceDim(x, channel_dim_index, channel_dim, &y));
  c->set_output(0, y);
  ShapeHandle vector_shape = c->Vector(channel_dim);
  c->set_output(1, vector_shape);
  c->set_output(2, vector_shape);
  c->set_output(3, vector_shape);
  c->set_output(4, vector_shape);
  return Status::OK();
}

Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
  ShapeHandle y_backprop;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop));
  ShapeHandle x;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x));

  bool is_training;
  TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
  string data_format_str;
  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }
  int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
  DimensionHandle channel_dim = c->Dim(y_backprop, channel_dim_index);
  TF_RETURN_IF_ERROR(
      c->Merge(channel_dim, c->Dim(x, channel_dim_index), &channel_dim));

  // Covers scale, mean (reserve_space_1), and variance (reserve_space_2).
  for (int i = 2; i < 5; ++i) {
    ShapeHandle vec;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
    TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
  }

  ShapeHandle x_backprop;
  TF_RETURN_IF_ERROR(
      c->ReplaceDim(y_backprop, channel_dim_index, channel_dim, &x_backprop));
  c->set_output(0, x_backprop);
  c->set_output(1, c->Vector(channel_dim));
  c->set_output(2, c->Vector(channel_dim));
  // Set the correct shapes for the reserve spaces so that gradients can
  // still be computed when the op is used inside a symbolic condition.
  if (is_training) {
    c->set_output(3, c->Vector(0));
    c->set_output(4, c->Vector(0));
  } else {
    c->set_output(3, c->Vector(channel_dim));
    c->set_output(4, c->Vector(channel_dim));
  }
  return Status::OK();
}

Status MaxPoolShape(shape_inference::InferenceContext* c) {
  string data_format_str;
  TensorFormat data_format;
  Status s = c->GetAttr("data_format", &data_format_str);
  if (s.ok()) {
    FormatFromString(data_format_str, &data_format);
  } else {
    data_format = FORMAT_NHWC;
  }

  const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));

  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the stride attribute to contain 4 values, but got: ",
        strides.size());
  }

  std::vector<int32> kernel_sizes;
  TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
  if (kernel_sizes.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the ksize attribute to contain 4 values, but got: ",
        kernel_sizes.size());
  }

  int32 stride_depth = GetTensorDim(strides, data_format, 'C');
  int32 stride_rows = GetTensorDim(strides, data_format, 'H');
  int32 stride_cols = GetTensorDim(strides, data_format, 'W');
  int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
  int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
  int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');

  constexpr int num_spatial_dims = 2;
  DimensionHandle batch_size_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
  DimensionHandle in_rows_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
  DimensionHandle in_cols_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
  DimensionHandle in_depth_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  ShapeHandle output_shape;
  DimensionHandle output_rows, output_cols, output_depth;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth));

  TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
                                         {output_rows, output_cols},
                                         output_depth, &output_shape, c));

  c->set_output(0, output_shape);
  return Status::OK();
}

Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) {
  string data_format_str;
  TensorFormat data_format;
  Status s = c->GetAttr("data_format", &data_format_str);
  if (s.ok()) {
    FormatFromString(data_format_str, &data_format);
  } else {
    data_format = FORMAT_NHWC;
  }

  const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));

  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));

  std::vector<int32> kernel_sizes;
  std::vector<int32> strides;

  if (c->num_inputs() + 2 == num_inputs) {
    TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));

    TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  } else {
    // Verify shape of ksize and strides input.
    ShapeHandle size;
    DimensionHandle unused;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 2), 1, &size));
    TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));
    TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &size));
    TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));

    const Tensor* kernel_sizes_tensor = c->input_tensor(c->num_inputs() - 2);
    if (kernel_sizes_tensor == nullptr) {
      c->set_output(0, c->UnknownShape());
      return Status::OK();
    }
    kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements());
    auto kernel_sizes_vec = kernel_sizes_tensor->flat<int32>();
    std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(),
                kernel_sizes.begin());

    const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1);
    if (strides_tensor == nullptr) {
      c->set_output(0, c->UnknownShape());
      return Status::OK();
    }
    strides.resize(strides_tensor->shape().num_elements());
    auto strides_vec = strides_tensor->flat<int32>();
    std::copy_n(&strides_vec(0), strides.size(), strides.begin());
  }

  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the stride attribute to contain 4 values, but "
        "got: ",
        strides.size());
  }
  if (kernel_sizes.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the ksize attribute to contain 4 values, but got: ",
        kernel_sizes.size());
  }

  int32 stride_depth = GetTensorDim(strides, data_format, 'C');
  int32 stride_rows = GetTensorDim(strides, data_format, 'H');
  int32 stride_cols = GetTensorDim(strides, data_format, 'W');
  int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
  int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
  int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');

  constexpr int num_spatial_dims = 2;
  DimensionHandle batch_size_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
  DimensionHandle in_rows_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
  DimensionHandle in_cols_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
  DimensionHandle in_depth_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  ShapeHandle output_shape;
  DimensionHandle output_rows, output_cols, output_depth;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth));

  TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
                                         {output_rows, output_cols},
                                         output_depth, &output_shape, c));

  c->set_output(0, output_shape);
  return Status::OK();
}

Status Pool3DShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));

  string data_format;
  Status s = c->GetAttr("data_format", &data_format);

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 5) {
    return errors::InvalidArgument(
        "Pool3D ops require the stride attribute to contain 5 values, but "
        "got: ",
        strides.size());
  }

  std::vector<int32> kernel_sizes;
  TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
  if (kernel_sizes.size() != 5) {
    return errors::InvalidArgument(
        "Pool3D requires the ksize attribute to contain 5 values, but got: ",
        kernel_sizes.size());
  }

  int32 stride_planes, stride_rows, stride_cols;
  int32 kernel_planes, kernel_rows, kernel_cols;

  if (s.ok() && data_format == "NCDHW") {
    // Convert input_shape to NDHWC.
    auto dim = [&](char dimension) {
      return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
    };
    input_shape =
        c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
    stride_planes = strides[2];
    stride_rows = strides[3];
    stride_cols = strides[4];
    kernel_planes = kernel_sizes[2];
    kernel_rows = kernel_sizes[3];
    kernel_cols = kernel_sizes[4];
  } else {
    stride_planes = strides[1];
    stride_rows = strides[2];
    stride_cols = strides[3];
    kernel_planes = kernel_sizes[1];
    kernel_rows = kernel_sizes[2];
    kernel_cols = kernel_sizes[3];
  }

  DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
  DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
  DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
  DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
  DimensionHandle output_depth_dim = c->Dim(input_shape, 4);

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  // TODO(mrry,shlens): Raise an error if the stride would cause
  // information in the input to be ignored. This will require a change
  // in the kernel implementation.
  DimensionHandle output_planes, output_rows, output_cols;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_planes_dim, kernel_planes, stride_planes, padding, &output_planes));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));

  ShapeHandle output_shape;
  if (data_format == "NCDHW") {
    output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
                                 output_planes, output_rows, output_cols});
  } else {
    output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
                                 output_cols, output_depth_dim});
  }

  c->set_output(0, output_shape);
  return Status::OK();
}

Status UnknownShape(shape_inference::InferenceContext* c) {
  for (int i = 0; i < c->num_outputs(); ++i) {
    c->set_output(i, c->UnknownShape());
  }
  return Status::OK();
}

template <typename T>
Status ReductionShapeHelper(const Tensor* reduction_indices_t,
                            const int32 input_rank,
                            std::set<int64>* true_indices) {
  auto reduction_indices = reduction_indices_t->flat<T>();
  for (int i = 0; i < reduction_indices_t->NumElements(); ++i) {
    const T reduction_index = reduction_indices(i);
    if (reduction_index < -input_rank || reduction_index >= input_rank) {
      return errors::InvalidArgument("Invalid reduction dimension ",
                                     reduction_index, " for input with ",
                                     input_rank, " dimensions.");
    }

    auto wrapped_index = reduction_index;
    if (wrapped_index < 0) {
      wrapped_index += input_rank;
    }

    true_indices->insert(wrapped_index);
  }
  return Status::OK();
}

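// Shape function for reduction ops. For example, an input of [2, 3, 5]
// reduced over axis 1 yields [2, 5] with keep_dims = false, or [2, 1, 5]
// with keep_dims = true.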
Status ReductionShape(InferenceContext* c) {
  ShapeHandle input = c->input(0);

  ShapeHandle indices;
  // Older versions of TensorFlow accidentally allowed higher rank tensors like
  // [[1,2]] or [[1],[2]] to represent axis=[1,2].
  if (c->graph_def_version() < 21) {
    indices = c->input(1);
  } else {
    TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices));
  }

  bool keep_dims;
  TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));

  const Tensor* reduction_indices_t = c->input_tensor(1);
  if (reduction_indices_t == nullptr || !c->RankKnown(input)) {
    // If we do not have the reduction values at runtime, or the
    // rank of the input, we don't know the output shape.

    if (keep_dims && c->RankKnown(input)) {
      // The output rank matches the input rank if <keep_dims>.
      c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
      return Status::OK();
    } else {
      return shape_inference::UnknownShape(c);
    }
  }

  const int32 input_rank = c->Rank(input);
  std::set<int64> true_indices;
  if (reduction_indices_t->dtype() == DataType::DT_INT32) {
    TF_RETURN_IF_ERROR(ReductionShapeHelper<int32>(reduction_indices_t,
                                                   input_rank, &true_indices));
  } else if (reduction_indices_t->dtype() == DataType::DT_INT64) {
    TF_RETURN_IF_ERROR(ReductionShapeHelper<int64>(reduction_indices_t,
                                                   input_rank, &true_indices));
  } else {
    return errors::InvalidArgument(
        "reduction_indices can only be int32 or int64");
  }

  std::vector<DimensionHandle> dims;
  for (int i = 0; i < input_rank; ++i) {
    if (true_indices.count(i) > 0) {
      if (keep_dims) {
        dims.emplace_back(c->MakeDim(1));
      }
    } else {
      dims.emplace_back(c->Dim(input, i));
    }
  }

  c->set_output(0, c->MakeShape(dims));
  return Status::OK();
}

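// Shared shape function for Concat-like ops. For example, concatenating
// [2, 3] and [2, 5] along dim 1 merges the non-concat dims and sums the
// concat dim, giving [2, 8].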
Status ConcatShapeHelper(InferenceContext* c, int start_value_index,
                         int end_value_index, int dim_index) {
  ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(dim_index), 0, &unused));
  const Tensor* concat_dim_t = c->input_tensor(dim_index);
  if (concat_dim_t == nullptr) {
    // Return an unknown shape with same rank as inputs, or an unknown rank
    // if no input's rank is known.

    // Find rank.
    int32 rank = InferenceContext::kUnknownRank;
    for (int i = start_value_index; i < end_value_index; ++i) {
      if (rank == InferenceContext::kUnknownRank) rank = c->Rank(c->input(i));
      if (rank != InferenceContext::kUnknownRank) {
        break;
      }
    }
    if (rank == InferenceContext::kUnknownRank) {
      c->set_output(0, c->UnknownShape());
      return Status::OK();
    } else if (rank == 0) {
      return errors::InvalidArgument(
          "Can't concatenate scalars (use tf.stack instead)");
    } else {
      for (int i = start_value_index; i < end_value_index; ++i) {
        // Check that all the inputs are of the correct rank.
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), rank, &unused));
      }
    }
    // Build result of <rank> different unknown dims.
    std::vector<DimensionHandle> dims;
    dims.reserve(rank);
    for (int i = 0; i < rank; ++i) dims.push_back(c->UnknownDim());
    c->set_output(0, c->MakeShape(dims));
    return Status::OK();
  }

  // Merge all the non-concat dims, and sum the concat dim to make an output
  // shape.
  const int32 concat_dim = concat_dim_t->scalar<int32>()();

  // Minimum required number of dimensions.
  const int min_rank = concat_dim < 0 ? -concat_dim : concat_dim + 1;

  ShapeHandle output_before;
  ShapeHandle output_after;

  ShapeHandle input = c->input(end_value_index - 1);
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
  TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &output_before));
  DimensionHandle output_middle = c->Dim(input, concat_dim);
  if (concat_dim == -1) {
    output_after = c->Scalar();  // no dimensions.
  } else {
    TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &output_after));
  }

  for (int i = end_value_index - 2; i >= start_value_index; --i) {
    ShapeHandle before;
    ShapeHandle after;
    input = c->input(i);
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
    TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &before));
    DimensionHandle middle = c->Dim(input, concat_dim);
    if (concat_dim == -1) {
      after = c->Scalar();
    } else {
      TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &after));
    }

    TF_RETURN_IF_ERROR(c->Merge(before, output_before, &output_before));
    TF_RETURN_IF_ERROR(c->Add(output_middle, middle, &output_middle));
    TF_RETURN_IF_ERROR(c->Merge(after, output_after, &output_after));
  }

  ShapeHandle s;
  TF_RETURN_IF_ERROR(
      c->Concatenate(output_before, c->Vector(output_middle), &s));
  TF_RETURN_IF_ERROR(c->Concatenate(s, output_after, &s));
  c->set_output(0, s);
  return Status::OK();
}

Status ConcatShape(InferenceContext* c, int num_inputs_to_concat) {
  return ConcatShapeHelper(c, 1 /* start_value_index */,
                           1 + num_inputs_to_concat /* end_value_index */,
                           0 /* dim_index */);
}

Status ConcatV2Shape(InferenceContext* c) {
  return ConcatShapeHelper(c, 0 /* start_value_index */,
                           c->num_inputs() - 1 /* end_value_index */,
                           c->num_inputs() - 1 /* dim_index */);
}

Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat) {
  return ConcatShapeHelper(c, 0 /* start_value_index */,
                           num_inputs_to_concat /* end_value_index */,
                           num_inputs_to_concat /* dim_index */);
}

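// Computes the output shape of a broadcasting binary op. For example,
// shape_x = [2, 1, 5] and shape_y = [3, 1] broadcast to [2, 3, 5]: the
// shorter shape is left-padded with 1s, and each pair of dimensions must be
// equal or contain a 1.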
Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
                                            ShapeHandle shape_x,
                                            ShapeHandle shape_y,
                                            ShapeHandle* out) {
  CHECK_NOTNULL(out);
  if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
    *out = c->UnknownShape();
    return Status::OK();
  }
  const int32 rank_x = c->Rank(shape_x);
  const int32 rank_y = c->Rank(shape_y);
  const int32 rank_out = std::max(rank_x, rank_y);

  // To compute the broadcast dimensions, we zip together shape_x and shape_y
  // and pad with 1 to make them the same length.
  std::vector<DimensionHandle> dims;
  DimensionHandle dim_one;
  if (rank_x != rank_y) dim_one = c->MakeDim(1);
  for (int i = 0; i < rank_out; ++i) {
    const auto dim_x = i < (rank_out - rank_x)
                           ? dim_one
                           : c->Dim(shape_x, i - (rank_out - rank_x));
    const bool dim_y_is_one = (i < (rank_out - rank_y));
    const auto dim_y =
        dim_y_is_one ? dim_one : c->Dim(shape_y, i - (rank_out - rank_y));
    if (!c->ValueKnown(dim_x) || !c->ValueKnown(dim_y)) {
      // One or both dimensions is unknown.
      //
      // - If either dimension is greater than 1, we assume that the program is
      // correct, and the other dimension will be broadcast to match it.
      // TODO(cwhipkey): For shape inference, if we eliminate the shape checks
      // in C++ op code, we must still assert that the unknown dim is either 1
      // or the same as the known dim.
      // - If either dimension is 1, the other dimension is the output.
      if (c->Value(dim_x) > 1) {
        dims.push_back(dim_x);
      } else if (c->Value(dim_y) > 1) {
        dims.push_back(dim_y);
      } else if (c->Value(dim_x) == 1) {
        dims.push_back(dim_y);
      } else if (c->Value(dim_y) == 1) {
        dims.push_back(dim_x);
      } else if (dim_y.SameHandle(dim_x)) {
        dims.push_back(dim_x);
      } else {
        dims.push_back(c->UnknownDim());
      }
    } else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) {
      if (c->Value(dim_x) == 1 && !dim_y_is_one) {
        // We will broadcast dim_x to dim_y.
        dims.push_back(dim_y);
      } else {
        DCHECK_EQ(c->Value(dim_y), 1);
        // We will broadcast dim_y to dim_x.
        dims.push_back(dim_x);
      }
    } else {
      DimensionHandle dim;
      TF_RETURN_IF_ERROR(c->Merge(dim_x, dim_y, &dim));
      dims.push_back(dim);
    }
  }

  *out = c->MakeShape(dims);
  return Status::OK();
}

Status RandomShape(shape_inference::InferenceContext* c) {
  shape_inference::ShapeHandle out;
  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
  c->set_output(0, out);
  return Status::OK();
}

namespace {

// This SliceHelper computes the output shape of `slice` when the `sizes`
// tensor is available; a size of -1 means "to the end of the dimension".
1391 template <typename T>
SliceHelper(InferenceContext * c,ShapeHandle begin_value,const Tensor * sizes_value,std::vector<DimensionHandle> * dims)1392 Status SliceHelper(InferenceContext* c, ShapeHandle begin_value,
1393 const Tensor* sizes_value,
1394 std::vector<DimensionHandle>* dims) {
1395 auto sizes_vec = sizes_value->vec<T>();
1396 for (int i = 0; i < sizes_value->NumElements(); ++i) {
1397 DimensionHandle dim = c->Dim(c->input(0), i);
1398 if (sizes_vec(i) != -1) {
1399 auto dim_val = c->Value(dim);
1400 if (sizes_vec(i) < 0) {
1401 return errors::InvalidArgument(
1402 "Out of bounds slicing on dimension ", i, " of length ", dim_val,
1403 ": sizes vector cannot be < -1, but was ", sizes_vec(i));
1404 }
1405
1406 dims->emplace_back(c->MakeDim(sizes_vec(i)));
1407 } else {
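      // A size of -1 means "everything from `begin` to the end of this
      // dimension", i.e. dim - begin.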
      DimensionHandle result;
      TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result));
      dims->emplace_back(result);
    }
  }

  return Status::OK();
}
}  // namespace

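// Shape function for Slice: input 0 is the tensor to slice, input 1 is
// `begin`, and input 2 is `sizes`. Each output dimension i has size
// sizes[i], or dim(i) - begin[i] when sizes[i] is -1.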
Status SliceShape(InferenceContext* c) {
  ShapeHandle input = c->input(0);
  ShapeHandle begin_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
  ShapeHandle sizes_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape));

  // Merge to check compatibility of begin and sizes tensors.
  TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape));

  DimensionHandle ndims = c->Dim(begin_shape, 0);
  if (c->ValueKnown(ndims)) {
    TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input));
  }

  // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known
  // values, even though the `begin` value does not represent a shape.
  ShapeHandle begin_value;
  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value));

  // We check the tensor value here and will only use
  // `MakeShapeFromShapeTensor` when `sizes_value` is null.
  // The reason is that `sizes` might contain -1, which can't
  // be represented (-1 in the ShapeHandle would mean "unknown").
  const Tensor* sizes_value = c->input_tensor(2);

  if (sizes_value != nullptr) {
    TF_RETURN_IF_ERROR(
        c->WithRank(begin_value, sizes_value->NumElements(), &begin_value));
    std::vector<DimensionHandle> dims;
    // If the begin and sizes tensors are available, then
    // we can be precise about the shape of the output.
    if (sizes_value->dtype() == DT_INT64) {
      TF_RETURN_IF_ERROR(
          SliceHelper<int64>(c, begin_value, sizes_value, &dims));
    } else {
      TF_RETURN_IF_ERROR(
          SliceHelper<int32>(c, begin_value, sizes_value, &dims));
    }
    c->set_output(0, c->MakeShape(dims));
    return Status::OK();
  } else {
    // If `sizes` is not available (`sizes_value` is null), we fall back to
    // `MakeShapeFromShapeTensor`. Any -1 in `sizes` is simply treated as an
    // unknown dimension. This is less than ideal but still an improvement
    // to shape inference. The following is an example that returns
    // [None, 1, None] with this code path:
    //   z = tf.zeros((1, 2, 3))
    //   m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1])
    //   m.get_shape().as_list()
    ShapeHandle sizes_value;
    TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value));
    if (c->RankKnown(sizes_value)) {
      TF_RETURN_IF_ERROR(
          c->WithRank(begin_value, c->Rank(sizes_value), &begin_value));
      std::vector<DimensionHandle> dims;
      dims.reserve(c->Rank(sizes_value));
      for (int i = 0; i < c->Rank(sizes_value); ++i) {
        dims.emplace_back(c->Dim(sizes_value, i));
      }
      c->set_output(0, c->MakeShape(dims));
      return Status::OK();
    }
    // We might know the rank of the input.
    if (c->RankKnown(input)) {
      c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
      return Status::OK();
    } else {
      return shape_inference::UnknownShape(c);
    }
  }

  return Status::OK();
}

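// Checks that the three component tensors of a SparseTensor have consistent
// shapes: `indices` must be an [N, ndims] matrix, `values` a length-N
// vector, and `shape` a length-ndims vector. For example, a sparse rank-2
// matrix with two nonzero entries has indices of shape [2, 2], values of
// shape [2], and shape of shape [2].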
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
                            ShapeHandle values_shape,
                            ShapeHandle shape_shape) {
  // Validate ranks.
  ShapeHandle unused_shape;
  TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &unused_shape));
  TF_RETURN_IF_ERROR(c->WithRank(values_shape, 1, &unused_shape));
  TF_RETURN_IF_ERROR(c->WithRank(shape_shape, 1, &unused_shape));

  // Number of elements in indices and values must match.
  DimensionHandle num_index_elements_dim = c->Dim(indices_shape, 0);
  if (c->ValueKnown(num_index_elements_dim)) {
    DimensionHandle num_values_elements_dim = c->Dim(values_shape, 0);
    if (c->ValueKnown(num_values_elements_dim)) {
      int64 num_index_elements = c->Value(num_index_elements_dim);
      int64 num_values_elements = c->Value(num_values_elements_dim);
      if (num_index_elements != num_values_elements) {
        return errors::InvalidArgument("Number of elements in index (",
                                       num_index_elements, ") and values (",
                                       num_values_elements, ") do not match.");
      }
    }
  }

  // Rank embedded in indices must match shape.
  DimensionHandle index_rank_dim = c->Dim(indices_shape, 1);
  if (c->ValueKnown(index_rank_dim)) {
    DimensionHandle shape_rank_dim = c->Dim(shape_shape, 0);
    if (c->ValueKnown(shape_rank_dim)) {
      int64 index_rank = c->Value(index_rank_dim);
      int64 shape_rank = c->Value(shape_rank_dim);
      if (index_rank != shape_rank) {
        return errors::InvalidArgument("Index rank (", index_rank,
                                       ") and shape rank (", shape_rank,
                                       ") do not match.");
      }
    }
  }

  return Status::OK();
}

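// Shape function shared by ScatterNd and the scatter_nd update ops. With
// indices of shape outer_dims + [ix], updates must have shape
// outer_dims + input_shape[ix:]. For example, input [10, 4] with indices
// [5, 1] (so ix == 1) requires updates of shape [5, 4].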
Status ScatterNdUpdateShape(InferenceContext* c) {
  ShapeHandle input_shape = c->input(0);
  if (c->input_handle_shapes_and_types(0) != nullptr) {
    // This is called for tf.scatter_nd_update; input is a Variable handle.
    const auto& shape_and_type = *(c->input_handle_shapes_and_types(0));
    if (shape_and_type.size() == 1) {
      input_shape = shape_and_type[0].shape;
    }
  }
  ShapeHandle indices_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
  ShapeHandle updates_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape));

  if (c->Value(c->NumElements(input_shape)) == 0 &&
      (c->Value(c->NumElements(indices_shape)) > 0 ||
       c->Value(c->NumElements(updates_shape)) > 0)) {
    return errors::InvalidArgument(
        "Indices and updates specified for empty output shape");
  }

  if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) {
    const int64 num_outer_dims = c->Rank(indices_shape) - 1;
    const DimensionHandle index_size = c->Dim(indices_shape, -1);

    // We can only do more validation if the last dimension of indices
    // is a known value.
    if (c->ValueKnown(index_size)) {
      const int64 ix = c->Value(index_size);
      ShapeHandle unused;
      ShapeHandle prefix_indices;
      TF_RETURN_IF_ERROR(
          c->Subshape(indices_shape, 0, num_outer_dims, &prefix_indices));
      ShapeHandle prefix_updates;
      TF_RETURN_IF_ERROR(
          c->Subshape(updates_shape, 0, num_outer_dims, &prefix_updates));

      Status s = c->Merge(prefix_indices, prefix_updates, &unused);
      if (!s.ok()) {
        return errors::InvalidArgument(
            "The outer ", num_outer_dims,
            " dimensions of indices.shape=", c->DebugString(indices_shape),
            " must match the outer ", num_outer_dims,
            " dimensions of updates.shape=", c->DebugString(updates_shape),
            ": ", s.error_message());
      }

      ShapeHandle input_suffix;
      TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &input_suffix));
      ShapeHandle suffix_updates;
      TF_RETURN_IF_ERROR(
          c->Subshape(updates_shape, num_outer_dims, &suffix_updates));
      s = c->Merge(input_suffix, suffix_updates, &unused);
      if (!s.ok()) {
        return errors::InvalidArgument(
            "The inner ", c->Rank(input_shape) - ix,
            " dimensions of input.shape=", c->DebugString(input_shape),
            " must match the inner ", c->Rank(updates_shape) - num_outer_dims,
            " dimensions of updates.shape=", c->DebugString(updates_shape),
            ": ", s.error_message());
      }
    }
  }

  if (c->input_handle_shapes_and_types(0) == nullptr && c->num_outputs() > 0) {
    // This is called for tf.scatter_nd; output is a tensor with this shape.
    c->set_output(0, input_shape);
  }
  return Status::OK();
}

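// Shape function for ops whose single output shape is given verbatim by a
// `shape` attr.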
Status ExplicitShape(InferenceContext* c) {
  PartialTensorShape shape;
  TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
  ShapeHandle output_shape;
  TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &output_shape));
  c->set_output(0, output_shape);
  return Status::OK();
}

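// Like ExplicitShape, but for ops with a `shapes` list attr that fixes one
// output shape per list entry.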
Status ExplicitShapes(InferenceContext* c) {
  std::vector<PartialTensorShape> shapes;
  TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
  if (shapes.empty()) {
    return errors::Internal("shapes attribute is empty");
  }
  for (int i = 0; i < shapes.size(); ++i) {
    ShapeHandle output_shape;
    TF_RETURN_IF_ERROR(
        c->MakeShapeFromPartialTensorShape(shapes[i], &output_shape));
    c->set_output(i, output_shape);
  }
  return Status::OK();
}

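// Shape function for sparse reduction ops. When both the dense shape and the
// reduction axes are available as constant tensors, the output shape can be
// computed exactly; otherwise it is unknown. For example, dense shape
// [2, 3, 4] reduced over axis 1 yields [2, 4] with keep_dims=false and
// [2, 1, 4] with keep_dims=true.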
Status SparseReduceShapeFn(InferenceContext* c) {
  // Input 0: input_indices
  // Input 1: input_values
  // Input 2: input_shape
  // Input 3: reduction_axes
  // Attr: keep_dims
  bool keep_dims = false;
  TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));

  const Tensor* shape_tensor = c->input_tensor(2);
  const Tensor* axes_tensor = c->input_tensor(3);
  if (shape_tensor != nullptr && axes_tensor != nullptr) {
    auto shape_vec = shape_tensor->flat<int64>();
    auto axes_vec = axes_tensor->flat<int32>();

    int64 ndims = shape_vec.size();
    std::unordered_set<int64> axes;
    for (int i = 0; i < axes_vec.size(); i++) {
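      // Normalize negative axes, e.g. -1 refers to the last dimension.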
      axes.insert((axes_vec(i) + ndims) % ndims);
    }

    std::vector<DimensionHandle> dims;
    if (keep_dims) {
      dims.reserve(ndims);
      for (int d = 0; d < ndims; ++d) {
        if (axes.find(d) == axes.end()) {
          dims.push_back(c->MakeDim(shape_vec(d)));
        } else {
          dims.push_back(c->MakeDim(1));
        }
      }
    } else {
      for (int d = 0; d < ndims; ++d) {
        if (axes.find(d) == axes.end()) {
          dims.push_back(c->MakeDim(shape_vec(d)));
        }
      }
    }

    c->set_output(0, c->MakeShape(dims));
    return Status::OK();
  }
  return UnknownShape(c);
}

}  // namespace shape_inference

}  // namespace tensorflow