/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"

#include "absl/algorithm/container.h"  // for absl::c_count below
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/einsum_op_util.h"

namespace tensorflow {

namespace shape_inference {

// The V2 version computes the windowed output size with an arbitrary
// dilation_rate and explicit padding, while the original version only handles
// the case where the dilation rate equals 1 and the padding is SAME or VALID.
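//
// A worked example of the arithmetic below (an informal sketch, not normative
// documentation): with input_size = 10, filter_size = 3, dilation_rate = 2,
// stride = 2 and VALID padding, the effective window is (3 - 1) * 2 + 1 = 5,
// so output_size = (10 - 5 + 2) / 2 = 3 using floor division. With SAME
// padding the result is simply ceil(input_size / stride) = (10 + 1) / 2 = 5.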
Status GetWindowedOutputSizeFromDimsV2(
    shape_inference::InferenceContext* c,
    shape_inference::DimensionHandle input_size,
    shape_inference::DimensionOrConstant filter_size, int64 dilation_rate,
    int64 stride, Padding padding_type, int64 padding_before,
    int64 padding_after, shape_inference::DimensionHandle* output_size) {
  if (stride <= 0) {
    return errors::InvalidArgument("Stride must be > 0, but got ", stride);
  }

  if (dilation_rate < 1) {
    return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
                                   dilation_rate);
  }

  // See also the parallel implementation in GetWindowedOutputSizeVerbose.
  switch (padding_type) {
    case Padding::VALID:
      padding_before = padding_after = 0;
      TF_FALLTHROUGH_INTENDED;
    case Padding::EXPLICIT:
      TF_RETURN_IF_ERROR(
          c->Add(input_size, padding_before + padding_after, &input_size));
      if (dilation_rate > 1) {
        DimensionHandle window_size;
        TF_RETURN_IF_ERROR(
            c->Subtract(c->MakeDim(filter_size), 1, &window_size));
        TF_RETURN_IF_ERROR(
            c->Multiply(window_size, dilation_rate, &window_size));
        TF_RETURN_IF_ERROR(c->Add(window_size, 1, &window_size));
        TF_RETURN_IF_ERROR(c->Subtract(input_size, window_size, output_size));
      } else {
        TF_RETURN_IF_ERROR(c->Subtract(input_size, filter_size, output_size));
      }
      TF_RETURN_IF_ERROR(c->Add(*output_size, stride, output_size));
      TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
                                   /*evenly_divisible=*/false, output_size));
      break;
    case Padding::SAME:
      TF_RETURN_IF_ERROR(c->Add(input_size, stride - 1, output_size));
      TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
                                   /*evenly_divisible=*/false, output_size));
      break;
  }
  return Status::OK();
}

Status GetWindowedOutputSizeFromDims(
    shape_inference::InferenceContext* c,
    shape_inference::DimensionHandle input_size,
    shape_inference::DimensionOrConstant filter_size, int64 stride,
    Padding padding_type, shape_inference::DimensionHandle* output_size) {
  if (padding_type == Padding::EXPLICIT) {
    return errors::Internal(
        "GetWindowedOutputSizeFromDims does not handle EXPLICIT padding; call "
        "GetWindowedOutputSizeFromDimsV2 instead");
  }
  return GetWindowedOutputSizeFromDimsV2(c, input_size, filter_size,
                                         /*dilation_rate=*/1, stride,
                                         padding_type,
                                         // Give dummy values of -1 to
                                         // padding_before and padding_after,
                                         // since explicit padding is not used.
                                         -1, -1, output_size);
}

Status UnchangedShape(shape_inference::InferenceContext* c) {
  c->set_output(0, c->input(0));
  auto* handle_data = c->input_handle_shapes_and_types(0);
  if (handle_data != nullptr) {
    c->set_output_handle_shapes_and_types(0, *handle_data);
  }
  return Status::OK();
}

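// A quick example of the rule implemented below (a sketch, with made-up
// shapes): for a = [3, 4] and b = [4, 5] with transpose_a = transpose_b =
// false, the inner dimensions 4 and 4 merge and the output is [3, 5]; with
// transpose_a = true, output_rows becomes Dim(a, 1) and the inner dimension
// becomes Dim(a, 0) instead.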
Status MatMulShape(shape_inference::InferenceContext* c) {
  ShapeHandle a;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a));

  ShapeHandle b;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b));

  bool transpose_a, transpose_b;
  TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a));
  TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b));
  DimensionHandle output_rows = transpose_a ? c->Dim(a, 1) : c->Dim(a, 0);
  DimensionHandle output_cols = transpose_b ? c->Dim(b, 0) : c->Dim(b, 1);

  // Validate that the inner shapes are compatible.
  DimensionHandle inner_a = transpose_a ? c->Dim(a, 0) : c->Dim(a, 1);
  DimensionHandle inner_b = transpose_b ? c->Dim(b, 1) : c->Dim(b, 0);
  DimensionHandle merged;
  TF_RETURN_IF_ERROR(c->Merge(inner_a, inner_b, &merged));

  c->set_output(0, c->Matrix(output_rows, output_cols));
  return Status::OK();
}

namespace {

// Validate that an Einsum subscript contains at most one ellipsis and that
// periods (.) occur only within an ellipsis (...).
Status ValidateEinsumEllipsis(absl::string_view subscript,
                              bool* found_ellipsis) {
  const int num_periods = absl::c_count(subscript, '.');
  if (num_periods != 0 && num_periods != 3) {
    return errors::InvalidArgument(
        "Expected at most one ellipsis (...), but found ", num_periods,
        " periods (.) in the input subscript: ", subscript);
  }
  if (num_periods == 3 && !absl::StrContains(subscript, "...")) {
    return errors::InvalidArgument(
        "Periods found outside of ellipsis in subscript: ", subscript);
  }
  *found_ellipsis = num_periods > 0;
  return Status::OK();
}

}  // namespace

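// Example equations handled below (illustrative values only): for
// "ab,bc->ac" with inputs [2, 3] and [3, 5] the output is [2, 5]; for
// "...ij,...jk->...ik" with inputs [7, 2, 3] and [3, 5] the ellipsis maps to
// the broadcast batch shape [7] and the output is [7, 2, 5].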
Status EinsumShape(shape_inference::InferenceContext* c) {
  // We assume that the equation has a valid format. Either (x),(y)->(z)
  // or (x)->(z), where each of (x), (y) and (z) is a concatenation of zero or
  // more Latin letters and contains at most one ellipsis ('...').
  string equation;
  TF_RETURN_IF_ERROR(c->GetAttr("equation", &equation));
  gtl::InlinedVector<string, 2> input_labels;
  string output_labels;
  TF_RETURN_IF_ERROR(
      ParseEinsumEquation(equation, &input_labels, &output_labels));

  if (c->num_inputs() == 0 || c->num_inputs() > 2) {
    return errors::InvalidArgument("Expected either 1 or 2 inputs but got: ",
                                   c->num_inputs());
  }
  const int input_labels_size = input_labels.size();
  if (c->num_inputs() != input_labels_size) {
    return errors::InvalidArgument("Expected ", input_labels.size(),
                                   " inputs for equation ", equation,
                                   " but got: ", c->num_inputs());
  }

  // Validate input subscripts, build the label to dimension mapping and obtain
  // the broadcast shapes that map to ellipsis.
  absl::flat_hash_map<char, DimensionHandle> label_to_dimension;
  gtl::InlinedVector<ShapeHandle, 2> input_bcast_shapes(c->num_inputs());
  for (int i = 0, end = c->num_inputs(); i < end; ++i) {
    bool has_ellipsis = false;
    TF_RETURN_IF_ERROR(ValidateEinsumEllipsis(input_labels[i], &has_ellipsis));
    ShapeHandle input_shape = c->input(i);
    // Validate that the input rank is sufficient for the given number of named
    // labels.
    if (c->RankKnown(input_shape)) {
      if (has_ellipsis) {
        const int num_named_labels =
            static_cast<int>(input_labels[i].size()) - 3;
        TF_RETURN_WITH_CONTEXT_IF_ERROR(
            c->WithRankAtLeast(input_shape, num_named_labels, &input_shape),
            " for ", i, "th input and equation: ", equation);
      } else {
        const int num_named_labels = static_cast<int>(input_labels[i].size());
        TF_RETURN_WITH_CONTEXT_IF_ERROR(
            c->WithRank(input_shape, num_named_labels, &input_shape), " for ",
            i, "th input and equation: ", equation);
      }
    }

    bool seen_ellipsis = false;
    input_bcast_shapes[i] = c->Scalar();
    // Run through the input labels; populate label_to_dimension mapping and
    // compute the broadcast shapes corresponding to the ellipsis (if present).
    for (int label_idx = 0, end = input_labels[i].size(); label_idx < end;
         ++label_idx) {
      const char label = input_labels[i][label_idx];
      // Calculate the input axis that the current label is referring to. After
      // the ellipsis, the axis may be found by using negative indices; i.e.,
      // the (rank - k)th dimension corresponds to the (num_labels - k)th label.
      const int64 axis_before_ellipsis = label_idx;
      const int64 axis_after_ellipsis =
          c->RankKnown(input_shape)
              ? label_idx + c->Rank(input_shape) - input_labels[i].size()
              : -1;

      // Populate the input broadcast shape when we encounter an ellipsis (...).
      if (label == '.') {
        if (!c->RankKnown(input_shape)) {
          input_bcast_shapes[i] = c->UnknownShape();
        } else {
          // The broadcast shape runs till the named label right after the
          // ellipsis, the label with index (label_idx + 3).
          TF_RETURN_IF_ERROR(c->Subshape(input_shape, axis_before_ellipsis,
                                         axis_after_ellipsis + 3,
                                         &input_bcast_shapes[i]));
        }
        label_idx += 2;  // Skip the rest of the ellipsis.
        seen_ellipsis = true;
        continue;
      }
      // Obtain the dimension that the current label corresponds to.
      int64 axis = seen_ellipsis ? axis_after_ellipsis : axis_before_ellipsis;
      DimensionHandle new_dim = c->RankKnown(input_shape)
                                    ? c->Dim(input_shape, axis)
                                    : c->UnknownDim();
      // If we've seen this label before, make sure previous and current
      // dimensions are compatible.
      if (label_to_dimension.contains(label)) {
        DimensionHandle merged;
        TF_RETURN_IF_ERROR(
            c->Merge(label_to_dimension[label], new_dim, &merged));
        label_to_dimension[label] = merged;
      } else {
        label_to_dimension[label] = new_dim;
      }
    }
  }

  // For two inputs, broadcast the two input broadcast shapes to create the
  // output broadcast shape. For one input, just copy the single broadcast
  // shape.
  ShapeHandle output_bcast_shape;
  if (input_bcast_shapes.size() == 1) {
    output_bcast_shape = input_bcast_shapes[0];
  } else if (input_bcast_shapes.size() == 2) {
    TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
        c, input_bcast_shapes[0], input_bcast_shapes[1], true,
        &output_bcast_shape));
  }

  bool output_has_ellipsis = false;
  TF_RETURN_IF_ERROR(
      ValidateEinsumEllipsis(output_labels, &output_has_ellipsis));
  if (output_has_ellipsis) {
    // If the output subscript has an ellipsis and the output broadcast rank is
    // unknown, then the output shape should have unknown rank.
    if (!c->RankKnown(output_bcast_shape)) {
      c->set_output(0, c->UnknownShape());
      return Status::OK();
    }
  } else {
    // If the output subscripts don't have an ellipsis then make sure the
    // output broadcasting shape is empty.
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        c->WithRankAtMost(output_bcast_shape, 0, &output_bcast_shape),
        " for einsum equation '", equation,
        "' without ellipsis (...) in the output subscripts where input(s) have "
        "non-empty broadcasting shape");
    output_bcast_shape = c->Scalar();
  }

  // Create the output shape from output labels and label_to_dimension mapping.
  std::vector<DimensionHandle> output_dims;
  for (int label_idx = 0, end = output_labels.size(); label_idx < end;
       ++label_idx) {
    const char label = output_labels[label_idx];
    // Append the output_bcast_shape when the ellipsis is encountered.
    if (label == '.') {
      for (int k = 0; k < c->Rank(output_bcast_shape); ++k) {
        output_dims.push_back(c->Dim(output_bcast_shape, k));
      }
      label_idx += 2;  // Skip the rest of the ellipsis.
      continue;
    }
    auto dimension_it = label_to_dimension.find(label);
    if (dimension_it == label_to_dimension.end()) {
      return errors::InvalidArgument(
          "Einsum output subscripts for equation '", equation, "' has label '",
          label, "' which is not present in the input subscripts");
    }
    output_dims.push_back(dimension_it->second);
  }
  c->set_output(0, c->MakeShape(output_dims));
  return Status::OK();
}

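// Worked example for the V2 (broadcasting) rule below, using made-up shapes:
// a = [2, 1, 3, 4] and b = [5, 4, 6] with adj_x = adj_y = false have batch
// shapes [2, 1] and [5], which broadcast to [2, 5]; the output is therefore
// [2, 5, 3, 6]. The non-V2 version further down instead requires the batch
// dimensions to merge exactly.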
Status BatchMatMulV2Shape(shape_inference::InferenceContext* c) {
  ShapeHandle a_shape;
  ShapeHandle b_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &a_shape));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape));

  // Determine output rows and columns.
  bool adj_x;
  bool adj_y;
  TF_RETURN_IF_ERROR(c->GetAttr("adj_x", &adj_x));
  TF_RETURN_IF_ERROR(c->GetAttr("adj_y", &adj_y));
  DimensionHandle output_rows = c->Dim(a_shape, adj_x ? -1 : -2);
  DimensionHandle output_cols = c->Dim(b_shape, adj_y ? -2 : -1);

  // Inner dimensions should be compatible.
  DimensionHandle inner_merged;
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1),
                              c->Dim(b_shape, adj_y ? -1 : -2), &inner_merged));

  // Batch dimensions should broadcast with each other.
  ShapeHandle a_batch_shape;
  ShapeHandle b_batch_shape;
  ShapeHandle output_batch_shape;
  TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_shape));
  TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_shape));

  TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
      c, a_batch_shape, b_batch_shape, true, &output_batch_shape));

  ShapeHandle output_shape;
  TF_RETURN_IF_ERROR(c->Concatenate(
      output_batch_shape, c->Matrix(output_rows, output_cols), &output_shape));

  c->set_output(0, output_shape);
  return Status::OK();
}

Status BatchMatMulShape(shape_inference::InferenceContext* c) {
  ShapeHandle a_shape;
  ShapeHandle b_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &a_shape));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape));

  // Determine output rows and cols.
  bool adj_x;
  bool adj_y;
  TF_RETURN_IF_ERROR(c->GetAttr("adj_x", &adj_x));
  TF_RETURN_IF_ERROR(c->GetAttr("adj_y", &adj_y));
  DimensionHandle output_rows = c->Dim(a_shape, adj_x ? -1 : -2);
  DimensionHandle output_cols = c->Dim(b_shape, adj_y ? -2 : -1);

  // Batch dims match between inputs.
  ShapeHandle a_batch_dims;
  ShapeHandle b_batch_dims;
  ShapeHandle batch_dims;
  TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims));
  TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims));
  TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims));

  // Assert inner dims match.
  DimensionHandle unused;
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1),
                              c->Dim(b_shape, adj_y ? -1 : -2), &unused));

  ShapeHandle out;
  TF_RETURN_IF_ERROR(
      c->Concatenate(batch_dims, c->Matrix(output_rows, output_cols), &out));
  c->set_output(0, out);
  return Status::OK();
}

// --------------------------------------------------------------------------

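// Example of the shape rule below (a sketch with arbitrary sizes): for an
// NHWC input of shape [2, 5, 5, 8] and a bias of shape [8], the bias length
// merges with the last dimension and the output is [2, 5, 5, 8]; for NCHW
// the bias merges with dimension 1 instead.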
Status BiasAddShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;

  // Fetch the data_format attribute, which may not exist.
  string data_format;
  Status s = c->GetAttr("data_format", &data_format);

  if (s.ok() && data_format == "NCHW") {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
  } else {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
  }

  ShapeHandle bias_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &bias_shape));
  DimensionHandle bias_dim = c->Dim(bias_shape, 0);

  // If rank unknown, return unknown shape.
  if (!c->RankKnown(input_shape)) {
    c->set_output(0, c->UnknownShape());
    return Status::OK();
  }

  // Output has the same shape as the input, and matches the length of
  // the bias in its bias dimension.
  ShapeHandle output_shape;
  if (s.ok() && data_format == "NCHW") {
    // Merge the length of bias_shape into the third-to-last dimension.
    ShapeHandle first;
    TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, 1, &first));

    ShapeHandle last;
    TF_RETURN_IF_ERROR(c->Subshape(input_shape, 2, &last));

    DimensionHandle input_bias_dim = c->Dim(input_shape, 1);
    DimensionHandle merged_bias_dim;
    TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));
    ShapeHandle merged_bias = c->Vector(merged_bias_dim);

    ShapeHandle temp;
    TF_RETURN_IF_ERROR(c->Concatenate(first, merged_bias, &temp));
    TF_RETURN_IF_ERROR(c->Concatenate(temp, last, &output_shape));
  } else {
    ShapeHandle all_but_bias;
    TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -1, &all_but_bias));

    DimensionHandle input_bias_dim = c->Dim(input_shape, -1);
    DimensionHandle merged_bias_dim;
    TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));

    ShapeHandle merged_bias = c->Vector(merged_bias_dim);
    TF_RETURN_IF_ERROR(
        c->Concatenate(all_but_bias, merged_bias, &output_shape));
  }

  c->set_output(0, output_shape);
  return Status::OK();
}

Status BiasAddGradShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;
  // Fetch the data_format attribute, which may not exist.
  string data_format;
  Status s = c->GetAttr("data_format", &data_format);

  if (s.ok() && data_format == "NCHW") {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
    c->set_output(0, c->Vector(c->Dim(input_shape, 1)));
  } else {
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
    c->set_output(0, c->Vector(c->Dim(input_shape, -1)));
  }

  return Status::OK();
}

Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format,
                                     const ShapeHandle shape_handle,
                                     const string& tensor_name,
                                     shape_inference::InferenceContext* c) {
  if (tensor_format == FORMAT_NCHW_VECT_C) {
    // Check that the vect dim has size 4.
    const int num_dims = c->Rank(shape_handle);
    DimensionHandle vect_dim = c->Dim(
        shape_handle, GetTensorInnerFeatureDimIndex(num_dims, tensor_format));
    DimensionHandle unused_vect_dim;
    TF_RETURN_IF_ERROR(c->WithValue(vect_dim, 4, &unused_vect_dim));
  }

  return Status::OK();
}

Status DatasetIteratorShape(shape_inference::InferenceContext* c) {
  shape_inference::ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
  std::vector<PartialTensorShape> output_shapes;
  TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
  const int output_shapes_size = output_shapes.size();
  if (output_shapes_size != c->num_outputs()) {
    return errors::InvalidArgument(
        "`output_shapes` must be the same length as `output_types` (",
        output_shapes.size(), " vs. ", c->num_outputs(), ")");
  }
  for (size_t i = 0; i < output_shapes.size(); ++i) {
    shape_inference::ShapeHandle output_shape_handle;
    TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
        output_shapes[i], &output_shape_handle));
    c->set_output(static_cast<int>(i), output_shape_handle);
  }
  return Status::OK();
}

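// For instance (values chosen purely for illustration), calling the function
// below with FORMAT_NHWC, N = 2, spatial = {5, 7} and C = 3 produces the
// shape [2, 5, 7, 3], while FORMAT_NCHW yields [2, 3, 5, 7].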
Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
                           const std::vector<DimensionOrConstant>& spatial,
                           DimensionOrConstant C, ShapeHandle* out,
                           shape_inference::InferenceContext* context) {
  const int num_dims = GetTensorDimsFromSpatialDims(spatial.size(), format);
  std::vector<DimensionHandle> dims_actual(num_dims);
  dims_actual[GetTensorBatchDimIndex(num_dims, format)] = context->MakeDim(N);
  int outer_c_index = GetTensorFeatureDimIndex(num_dims, format);
  dims_actual[outer_c_index] = context->MakeDim(C);
  if (format == FORMAT_NCHW_VECT_C) {
    dims_actual[GetTensorInnerFeatureDimIndex(num_dims, format)] =
        context->MakeDim(4);
  } else if (format == FORMAT_NHWC_VECT_W) {
    dims_actual[GetTensorInnerWidthDimIndex(num_dims, format)] =
        context->MakeDim(4);
  }
  for (int spatial_dim = 0, end = spatial.size(); spatial_dim < end;
       spatial_dim++) {
    dims_actual[GetTensorSpatialDimIndex(num_dims, format, spatial_dim)] =
        context->MakeDim(spatial[spatial_dim]);
  }
  *out = context->MakeShape(dims_actual);
  return Status::OK();
}

Status DimensionsFromShape(ShapeHandle shape, TensorFormat format,
                           DimensionHandle* batch_dim,
                           gtl::MutableArraySlice<DimensionHandle> spatial_dims,
                           DimensionHandle* filter_dim,
                           InferenceContext* context) {
  const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
  // Batch.
  *batch_dim = context->Dim(shape, GetTensorBatchDimIndex(rank, format));
  // Spatial.
  for (int spatial_dim_index = 0, end = spatial_dims.size();
       spatial_dim_index < end; ++spatial_dim_index) {
    spatial_dims[spatial_dim_index] = context->Dim(
        shape, GetTensorSpatialDimIndex(rank, format, spatial_dim_index));
  }
  // Channel.
  *filter_dim = context->Dim(shape, GetTensorFeatureDimIndex(rank, format));
  if (format == FORMAT_NCHW_VECT_C) {
    TF_RETURN_IF_ERROR(context->Multiply(
        *filter_dim,
        context->Dim(shape, GetTensorInnerFeatureDimIndex(rank, format)),
        filter_dim));
  }
  return Status::OK();
}

Status ShapeFromDimensions(DimensionHandle batch_dim,
                           gtl::ArraySlice<DimensionHandle> spatial_dims,
                           DimensionHandle filter_dim, TensorFormat format,
                           InferenceContext* context, ShapeHandle* shape) {
  const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
  std::vector<DimensionHandle> out_dims(rank);

  // Batch.
  out_dims[tensorflow::GetTensorBatchDimIndex(rank, format)] = batch_dim;
  // Spatial.
  for (int spatial_dim_index = 0, end = spatial_dims.size();
       spatial_dim_index < end; ++spatial_dim_index) {
    out_dims[tensorflow::GetTensorSpatialDimIndex(
        rank, format, spatial_dim_index)] = spatial_dims[spatial_dim_index];
  }
  // Channel.
  if (format == tensorflow::FORMAT_NCHW_VECT_C) {
    // When format is NCHW_VECT_C, factor the feature map count
    // into the outer feature count and the inner feature count (=4).
    TF_RETURN_IF_ERROR(context->Divide(
        filter_dim, 4, /*evenly_divisible=*/true,
        &out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)]));
    out_dims[GetTensorInnerFeatureDimIndex(rank, format)] = context->MakeDim(4);
  } else {
    out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)] = filter_dim;
  }

  *shape = context->MakeShape(out_dims);
  return tensorflow::Status::OK();
}

namespace {

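// A concrete instance of the computation below (illustrative only): an NHWC
// input [1, 5, 5, 3] convolved with an HWIO filter [2, 2, 3, 8], strides
// [1, 1, 1, 1], dilations [1, 1, 1, 1] and VALID padding produces
// [1, 4, 4, 8], since each spatial size is (5 - 2) / 1 + 1 = 4.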
Status Conv2DShapeImpl(shape_inference::InferenceContext* c,
                       bool supports_explicit_padding) {
  string data_format_str, filter_format_str;
  if (!c->GetAttr("data_format", &data_format_str).ok()) {
    data_format_str = "NHWC";
  }
  if (!c->GetAttr("filter_format", &filter_format_str).ok()) {
    filter_format_str = "HWIO";
  }

  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }
  FilterTensorFormat filter_format;
  if (!FilterFormatFromString(filter_format_str, &filter_format)) {
    return errors::InvalidArgument("Invalid filter format string: ",
                                   filter_format_str);
  }

  constexpr int num_spatial_dims = 2;
  const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
  ShapeHandle conv_input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &conv_input_shape));
  TF_RETURN_IF_ERROR(CheckFormatConstraintsOnShape(
      data_format, conv_input_shape, "conv_input", c));

  // The filter rank should match the input (4 for NCHW, 5 for NCHW_VECT_C).
  ShapeHandle filter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape));
  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, filter_shape, "filter", c));

  std::vector<int32> dilations;
  TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations));

  if (dilations.size() != 4) {
    return errors::InvalidArgument(
        "Conv2D requires the dilation attribute to contain 4 values, but got: ",
        dilations.size());
  }

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));

  // strides.size() should be 4 (NCHW) even if the input is 5 (NCHW_VECT_C).
  if (strides.size() != 4) {
    return errors::InvalidArgument("Conv2D on data format ", data_format_str,
                                   " requires the stride attribute to contain"
                                   " 4 values, but got: ",
                                   strides.size());
  }

  const int32 stride_rows = GetTensorDim(strides, data_format, 'H');
  const int32 stride_cols = GetTensorDim(strides, data_format, 'W');
  const int32 dilation_rows = GetTensorDim(dilations, data_format, 'H');
  const int32 dilation_cols = GetTensorDim(dilations, data_format, 'W');

  DimensionHandle batch_size_dim;
  DimensionHandle input_depth_dim;
  gtl::InlinedVector<DimensionHandle, 2> input_spatial_dims(2);
  TF_RETURN_IF_ERROR(DimensionsFromShape(
      conv_input_shape, data_format, &batch_size_dim,
      absl::MakeSpan(input_spatial_dims), &input_depth_dim, c));

  DimensionHandle output_depth_dim = c->Dim(
      filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
  DimensionHandle filter_rows_dim = c->Dim(
      filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'H'));
  DimensionHandle filter_cols_dim = c->Dim(
      filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'W'));
  DimensionHandle filter_input_depth_dim;
  if (filter_format == FORMAT_OIHW_VECT_I) {
    TF_RETURN_IF_ERROR(c->Multiply(
        c->Dim(filter_shape,
               GetFilterDimIndex<num_spatial_dims>(filter_format, 'I')),
        c->Dim(filter_shape,
               GetFilterTensorInnerInputChannelsDimIndex(rank, filter_format)),
        &filter_input_depth_dim));
  } else {
    filter_input_depth_dim = c->Dim(
        filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'I'));
  }

  // Check that the input tensor and the filter tensor agree on the channel
  // count.
  if (c->ValueKnown(input_depth_dim) &&
      c->ValueKnown(filter_input_depth_dim)) {
    int64 input_depth_value = c->Value(input_depth_dim),
          filter_input_depth_value = c->Value(filter_input_depth_dim);
    if (input_depth_value % filter_input_depth_value != 0)
      return errors::InvalidArgument(
          "Depth of input (", input_depth_value,
          ") is not a multiple of input depth of filter (",
          filter_input_depth_value, ")");
    if (input_depth_value != filter_input_depth_value) {
      int64 num_groups = input_depth_value / filter_input_depth_value;
      if (c->ValueKnown(output_depth_dim)) {
        int64 output_depth_value = c->Value(output_depth_dim);
        if (output_depth_value % num_groups != 0)
          return errors::InvalidArgument(
              "Depth of output (", output_depth_value,
              ") is not a multiple of the number of groups (", num_groups, ")");
      }
    }
  }

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  std::vector<int64> explicit_paddings;
  if (supports_explicit_padding) {
    Status s = c->GetAttr("explicit_paddings", &explicit_paddings);
    // Use the default value, which is an empty list, if the attribute is not
    // found. Otherwise return the error to the caller.
    if (!s.ok() && !errors::IsNotFound(s)) {
      return s;
    }
    TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
                                         /*num_dims=*/4, data_format));
  } else {
    CHECK(padding != Padding::EXPLICIT);  // Crash ok.
  }

  DimensionHandle output_rows, output_cols;
  int64 pad_rows_before = -1, pad_rows_after = -1;
  int64 pad_cols_before = -1, pad_cols_after = -1;
  if (padding == Padding::EXPLICIT) {
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
                             &pad_rows_before, &pad_rows_after);
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
                             &pad_cols_before, &pad_cols_after);
  }
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, input_spatial_dims[0], filter_rows_dim, dilation_rows, stride_rows,
      padding, pad_rows_before, pad_rows_after, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, input_spatial_dims[1], filter_cols_dim, dilation_cols, stride_cols,
      padding, pad_cols_before, pad_cols_after, &output_cols));

  ShapeHandle output_shape;
  TF_RETURN_IF_ERROR(
      ShapeFromDimensions(batch_size_dim, {output_rows, output_cols},
                          output_depth_dim, data_format, c, &output_shape));
  c->set_output(0, output_shape);
  return Status::OK();
}

}  // namespace

// Shape function for Conv2D-like operations that support explicit padding.
Status Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext* c) {
  return Conv2DShapeImpl(c, true);
}

// Shape function for Conv2D-like operations that do not support explicit
// padding.
Status Conv2DShape(shape_inference::InferenceContext* c) {
  return Conv2DShapeImpl(c, false);
}

// TODO(mjanusz): Unify all conv/pooling shape functions.
Status Conv3DShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
  ShapeHandle filter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &filter_shape));

  string data_format;
  Status s = c->GetAttr("data_format", &data_format);

  std::vector<int32> dilations;
  TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations));

  if (dilations.size() != 5) {
    return errors::InvalidArgument(
        "Conv3D requires the dilation attribute to contain 5 values, but got: ",
        dilations.size());
  }

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 5) {
    return errors::InvalidArgument(
        "Conv3D requires the stride attribute to contain 5 values, but got: ",
        strides.size());
  }

  int32 stride_planes, stride_rows, stride_cols;
  int32 dilation_planes, dilation_rows, dilation_cols;
  if (s.ok() && data_format == "NCDHW") {
    // Convert input_shape to NDHWC.
    auto dim = [&](char dimension) {
      return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
    };
    input_shape =
        c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
    // Read H before W so rows/cols line up with the stride assignments above.
    stride_planes = strides[2];
    stride_rows = strides[3];
    stride_cols = strides[4];
    dilation_planes = dilations[2];
    dilation_rows = dilations[3];
    dilation_cols = dilations[4];
  } else {
    stride_planes = strides[1];
    stride_rows = strides[2];
    stride_cols = strides[3];
    dilation_planes = dilations[1];
    dilation_rows = dilations[2];
    dilation_cols = dilations[3];
  }

  DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
  DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
  DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
  DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
  DimensionHandle input_depth_dim = c->Dim(input_shape, 4);

  DimensionHandle filter_planes_dim = c->Dim(filter_shape, 0);
  DimensionHandle filter_rows_dim = c->Dim(filter_shape, 1);
  DimensionHandle filter_cols_dim = c->Dim(filter_shape, 2);
  DimensionHandle filter_input_depth_dim = c->Dim(filter_shape, 3);
  DimensionHandle output_depth_dim = c->Dim(filter_shape, 4);

  // Check that the input tensor and the filter tensor agree on the channel
  // count.
  if (c->ValueKnown(input_depth_dim) &&
      c->ValueKnown(filter_input_depth_dim)) {
    int64 input_depth_value = c->Value(input_depth_dim),
          filter_input_depth_value = c->Value(filter_input_depth_dim);
    if (input_depth_value % filter_input_depth_value != 0)
      return errors::InvalidArgument(
          "Depth of input (", input_depth_value,
          ") is not a multiple of input depth of filter (",
          filter_input_depth_value, ")");
    if (input_depth_value != filter_input_depth_value) {
      int64 num_groups = input_depth_value / filter_input_depth_value;
      if (c->ValueKnown(output_depth_dim)) {
        int64 output_depth_value = c->Value(output_depth_dim);
        if (output_depth_value % num_groups != 0)
          return errors::InvalidArgument(
              "Depth of output (", output_depth_value,
              ") is not a multiple of the number of groups (", num_groups, ")");
      }
    }
  }

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
  DimensionHandle output_planes, output_rows, output_cols;

  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_planes_dim, filter_planes_dim, dilation_planes, stride_planes,
      padding, -1, -1, &output_planes));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_rows_dim, filter_rows_dim, dilation_rows, stride_rows, padding, -1,
      -1, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_cols_dim, filter_cols_dim, dilation_cols, stride_cols, padding, -1,
      -1, &output_cols));

  ShapeHandle output_shape;
  if (data_format == "NCDHW") {
    output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
                                 output_planes, output_rows, output_cols});
  } else {
    output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
                                 output_cols, output_depth_dim});
  }
  c->set_output(0, output_shape);
  return Status::OK();
}

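// Note on the function below: input_sizes (input 0) may hold either the full
// 4-D input shape or just the 2 spatial dimensions, e.g. (made-up numbers)
// either [1, 5, 5, 3] or [5, 5]. In the 2-D case the depth is recovered from
// the filter shape.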
Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c) {
  string data_format_str;
  if (!c->GetAttr("data_format", &data_format_str).ok()) {
    data_format_str = "NHWC";
  }
  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }

  // For the rest of this function, output_grad_* describes out_backprop and
  // input_grad_* describes in_backprop.
  ShapeHandle output_grad_shape = c->input(2);
  TF_RETURN_IF_ERROR(c->WithRank(output_grad_shape, 4, &output_grad_shape));
  ShapeHandle filter_shape = c->input(1);
  TF_RETURN_IF_ERROR(c->WithRank(filter_shape, 4, &filter_shape));

  DimensionHandle batch_size_dim;
  DimensionHandle output_grad_depth_dim;
  gtl::InlinedVector<DimensionHandle, 2> output_grad_spatial_dims(2);
  TF_RETURN_IF_ERROR(DimensionsFromShape(
      output_grad_shape, data_format, &batch_size_dim,
      absl::MakeSpan(output_grad_spatial_dims), &output_grad_depth_dim, c));
  DimensionHandle unused;
  TF_RETURN_IF_ERROR(
      c->Merge(output_grad_depth_dim, c->Dim(filter_shape, 3), &unused));

  ShapeHandle specified_input_grad_shape;
  TF_RETURN_IF_ERROR(
      c->MakeShapeFromShapeTensor(0, &specified_input_grad_shape));
  if (c->Rank(specified_input_grad_shape) == InferenceContext::kUnknownRank) {
    TF_RETURN_IF_ERROR(c->WithRank(specified_input_grad_shape, 4,
                                   &specified_input_grad_shape));
  }

  // input_grad_depth_dim doesn't equal c->Dim(filter_shape, 2) when the number
  // of groups is larger than 1. If input_sizes is a 4D shape, we collect
  // input_grad_depth_dim from input_sizes; otherwise we compute it as
  // c->Dim(filter_shape, 2).
  DimensionHandle input_grad_depth_dim;
  gtl::InlinedVector<DimensionHandle, 2> specified_input_grad_spatial_dims(2);
  int specified_input_grad_rank = c->Rank(specified_input_grad_shape);
  if (specified_input_grad_rank == 4) {
    DimensionHandle specified_batch_size_dim;
    TF_RETURN_IF_ERROR(DimensionsFromShape(
        specified_input_grad_shape, data_format, &specified_batch_size_dim,
        absl::MakeSpan(specified_input_grad_spatial_dims),
        &input_grad_depth_dim, c));
    TF_RETURN_IF_ERROR(
        c->Merge(specified_batch_size_dim, batch_size_dim, &unused));
  } else if (specified_input_grad_rank == 2) {
    specified_input_grad_spatial_dims[0] =
        c->Dim(specified_input_grad_shape, 0);
    specified_input_grad_spatial_dims[1] =
        c->Dim(specified_input_grad_shape, 1);
    input_grad_depth_dim = c->Dim(filter_shape, 2);
  } else {
    return errors::InvalidArgument(
        "Conv2DBackpropInput requires input_sizes to contain 4 values or 2 "
        "values, but got: ",
        specified_input_grad_rank);
  }

  ShapeHandle input_grad_shape;
  TF_RETURN_IF_ERROR(ShapeFromDimensions(
      batch_size_dim, specified_input_grad_spatial_dims, input_grad_depth_dim,
      data_format, c, &input_grad_shape));
  c->set_output(0, input_grad_shape);
  return Status::OK();
}

namespace {

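// Example of the depthwise rule implemented below (a sketch): an NHWC input
// [1, 5, 5, 3] with a filter [2, 2, 3, 4] (in_channels = 3,
// depth_multiplier = 4) yields output depth 3 * 4 = 12, so with VALID
// padding and unit strides the output is [1, 4, 4, 12].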
Status DepthwiseConv2DNativeShapeImpl(shape_inference::InferenceContext* c,
                                      bool supports_explicit_padding) {
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
  ShapeHandle filter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape));

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));

  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "DepthwiseConv2D requires the stride attribute to contain 4 values, "
        "but got: ",
        strides.size());
  }

  std::vector<int32> dilations;
  if (!c->GetAttr("dilations", &dilations).ok()) {
    dilations.resize(4, 1);
  }

  if (dilations.size() != 4) {
    return errors::InvalidArgument(
        "DepthwiseConv2D requires the dilations attribute to contain 4 "
        "values, but got: ",
        dilations.size());
  }

  string data_format_str;
  Status s = c->GetAttr("data_format", &data_format_str);
  TensorFormat data_format;
  if (!s.ok() || !FormatFromString(data_format_str, &data_format)) {
    data_format = FORMAT_NHWC;
  }
  int32 stride_rows;
  int32 stride_cols;
  int32 dilation_rows;
  int32 dilation_cols;
  if (data_format == FORMAT_NCHW) {
    // Canonicalize input shape to NHWC so the shape inference code below can
    // process it.
    input_shape =
        c->MakeShape({{c->Dim(input_shape, 0), c->Dim(input_shape, 2),
                       c->Dim(input_shape, 3), c->Dim(input_shape, 1)}});
    stride_rows = strides[2];
    stride_cols = strides[3];
    dilation_rows = dilations[2];
    dilation_cols = dilations[3];
  } else {
    stride_rows = strides[1];
    stride_cols = strides[2];
    dilation_rows = dilations[1];
    dilation_cols = dilations[2];
  }

  DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
  DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
  DimensionHandle in_cols_dim = c->Dim(input_shape, 2);

  DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
  DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
  DimensionHandle input_depth = c->Dim(filter_shape, 2);
  DimensionHandle depth_multiplier = c->Dim(filter_shape, 3);

  // Check that the input depths are compatible.
  TF_RETURN_IF_ERROR(
      c->Merge(c->Dim(input_shape, 3), input_depth, &input_depth));

  DimensionHandle output_depth;
  TF_RETURN_IF_ERROR(c->Multiply(input_depth, depth_multiplier, &output_depth));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  std::vector<int64> explicit_paddings;
  if (supports_explicit_padding) {
    Status status = c->GetAttr("explicit_paddings", &explicit_paddings);
    // Use the default value, which is an empty list, if the attribute is not
    // found. Otherwise return the error to the caller.
    if (!status.ok() && !errors::IsNotFound(status)) {
      return status;
    }
    TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
                                         /*num_dims=*/4, data_format));
  } else {
    DCHECK(padding != Padding::EXPLICIT);
  }

  // TODO(mrry,shlens): Raise an error if the stride would cause
  // information in the input to be ignored. This will require a change
  // in the kernel implementation.
  DimensionHandle output_rows, output_cols;
  int64 pad_rows_before = -1, pad_rows_after = -1;
  int64 pad_cols_before = -1, pad_cols_after = -1;
  if (padding == Padding::EXPLICIT) {
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
                             &pad_rows_before, &pad_rows_after);
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
                             &pad_cols_before, &pad_cols_after);
  }
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_rows_dim, filter_rows_dim, dilation_rows, stride_rows, padding,
      pad_rows_before, pad_rows_after, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_cols_dim, filter_cols_dim, dilation_cols, stride_cols, padding,
      pad_cols_before, pad_cols_after, &output_cols));

  ShapeHandle output_shape;
  if (data_format == FORMAT_NCHW) {
    output_shape =
        c->MakeShape({batch_size_dim, output_depth, output_rows, output_cols});
  } else {
    output_shape =
        c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth});
  }
  c->set_output(0, output_shape);
  return Status::OK();
}

}  // namespace

Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
  return DepthwiseConv2DNativeShapeImpl(c, false);
}

Status DepthwiseConv2DNativeShapeWithExplicitPadding(
    shape_inference::InferenceContext* c) {
  return DepthwiseConv2DNativeShapeImpl(c, true);
}

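// Example for the pooling rule below (arbitrary sizes): an NHWC input
// [1, 4, 4, 3] with ksize [1, 2, 2, 1], strides [1, 2, 2, 1] and VALID
// padding gives (4 - 2) / 2 + 1 = 2 per spatial dimension, i.e. [1, 2, 2, 3].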
Status AvgPoolShape(shape_inference::InferenceContext* c) {
  string data_format_str;
  TensorFormat data_format;
  Status s = c->GetAttr("data_format", &data_format_str);
  // Fall back to NHWC if the attribute is missing or fails to parse, so that
  // data_format is never left uninitialized.
  if (!s.ok() || !FormatFromString(data_format_str, &data_format)) {
    data_format = FORMAT_NHWC;
  }

  const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));

  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "AvgPool requires the stride attribute to contain 4 values, but got: ",
        strides.size());
  }

  std::vector<int32> kernel_sizes;
  TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
  if (kernel_sizes.size() != 4) {
    return errors::InvalidArgument(
        "AvgPool requires the ksize attribute to contain 4 values, but got: ",
        kernel_sizes.size());
  }

  int32 stride_rows = GetTensorDim(strides, data_format, 'H');
  int32 stride_cols = GetTensorDim(strides, data_format, 'W');
  int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
  int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');

  constexpr int num_spatial_dims = 2;
  DimensionHandle batch_size_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
  DimensionHandle in_rows_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
  DimensionHandle in_cols_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
  DimensionHandle depth_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  // TODO(mrry,shlens): Raise an error if the stride would cause
  // information in the input to be ignored. This will require a change
  // in the kernel implementation.

  DimensionHandle output_rows, output_cols;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));

  ShapeHandle output_shape;
  TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
                                         {output_rows, output_cols}, depth_dim,
                                         &output_shape, c));
  c->set_output(0, output_shape);
  return Status::OK();
}

Status AvgPoolGradShape(shape_inference::InferenceContext* c) {
  ShapeHandle s;
  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
  TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
  c->set_output(0, s);
  return Status::OK();
}

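// The function below infers five outputs: y (same shape as x, with the
// channel dimension merged against the scale/offset/mean/variance vectors)
// plus four channel-sized vectors. For example (made-up shape),
// x = [8, 5, 5, 16] in NHWC gives y = [8, 5, 5, 16] and four vectors of
// shape [16].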
Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
  string data_format_str;
  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }
  const int rank =
      (data_format_str == "NDHWC" || data_format_str == "NCDHW") ? 5 : 4;
  ShapeHandle x;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &x));

  bool is_training;
  TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
  float exponential_avg_factor;
  if (!c->GetAttr("exponential_avg_factor", &exponential_avg_factor).ok()) {
    exponential_avg_factor = 1.0f;  // default value
  }
  int number_inputs = (is_training && exponential_avg_factor == 1.0f) ? 3 : 5;

  int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
  DimensionHandle channel_dim = c->Dim(x, channel_dim_index);

  // Covers scale and offset, and, when all five inputs are present, mean and
  // variance as well.
  for (int i = 1; i < number_inputs; ++i) {
    ShapeHandle vec;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
    TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
  }

  ShapeHandle y;
  TF_RETURN_IF_ERROR(c->ReplaceDim(x, channel_dim_index, channel_dim, &y));
  c->set_output(0, y);
  ShapeHandle vector_shape = c->Vector(channel_dim);
  c->set_output(1, vector_shape);
  c->set_output(2, vector_shape);
  c->set_output(3, vector_shape);
  c->set_output(4, vector_shape);
  return Status::OK();
}

Status FusedBatchNormV3Shape(shape_inference::InferenceContext* c) {
  TF_RETURN_IF_ERROR(FusedBatchNormShape(c));
  c->set_output(5, c->UnknownShape());
  return Status::OK();
}

Status FusedBatchNormExShape(shape_inference::InferenceContext* c) {
  TF_RETURN_IF_ERROR(FusedBatchNormV3Shape(c));

  string data_format_str;
  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }
  ShapeHandle x;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));

  int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
  DimensionHandle channel_dim = c->Dim(x, channel_dim_index);

  // This is a cuDNN implementation constraint.
  if (c->ValueKnown(channel_dim) && c->Value(channel_dim) % 4 != 0) {
    return errors::InvalidArgument(
        "_FusedBatchNormEx channel dimension must be divisible by 4.");
  }

  return Status::OK();
}

Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
  string data_format_str;
  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }
  const int rank =
      (data_format_str == "NDHWC" || data_format_str == "NCDHW") ? 5 : 4;
  ShapeHandle y_backprop;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &y_backprop));
  ShapeHandle x;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &x));

  bool is_training;
  TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));

  int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
  DimensionHandle channel_dim = c->Dim(y_backprop, channel_dim_index);
  TF_RETURN_IF_ERROR(
      c->Merge(channel_dim, c->Dim(x, channel_dim_index), &channel_dim));

  // Covers scale, mean (reserve_space_1), and variance (reserve_space_2).
  for (int i = 2; i < 5; ++i) {
    ShapeHandle vec;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
    TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
  }

  ShapeHandle x_backprop;
  TF_RETURN_IF_ERROR(
      c->ReplaceDim(y_backprop, channel_dim_index, channel_dim, &x_backprop));
  c->set_output(0, x_backprop);
  c->set_output(1, c->Vector(channel_dim));
  c->set_output(2, c->Vector(channel_dim));
  c->set_output(3, c->Vector(0));
  c->set_output(4, c->Vector(0));
  return Status::OK();
}

Status ReadDiagIndex(InferenceContext* c, const Tensor* diag_index_tensor,
                     int32* lower_diag_index, int32* upper_diag_index) {
  // This function assumes that the shape of diag_index_tensor is fully
  // defined.
  if (diag_index_tensor->dims() == 0) {
    *lower_diag_index = diag_index_tensor->scalar<int32>()();
    *upper_diag_index = *lower_diag_index;
  } else {
    int32 num_elements = diag_index_tensor->dim_size(0);
    if (num_elements == 1) {
      *lower_diag_index = diag_index_tensor->vec<int32>()(0);
      *upper_diag_index = *lower_diag_index;
    } else if (num_elements == 2) {
      *lower_diag_index = diag_index_tensor->vec<int32>()(0);
      *upper_diag_index = diag_index_tensor->vec<int32>()(1);
    } else {
      return errors::InvalidArgument(
          "diag_index must be a vector with one or two elements. It has ",
          num_elements, " elements.");
    }
  }
  return Status::OK();
}

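// Worked example for the function below (illustrative numbers): input shape
// [2, 4, 5] with diag_index k = (-1, 2) keeps the batch dimension [2],
// appends num_diags = 2 - (-1) + 1 = 4, and max_diag_len =
// min(4 + min(2, 0), 5 - max(-1, 0)) = 4, giving an output of [2, 4, 4].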
Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape, diag_index_shape, unused_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
  TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape));

  const Tensor* diag_index_tensor = c->input_tensor(1);
  if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) ||
      diag_index_tensor == nullptr) {
    c->set_output(0, c->UnknownShape());
    return Status::OK();
  }
  int32 lower_diag_index = 0;
  int32 upper_diag_index = 0;
  TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
                                   &upper_diag_index));
  if (lower_diag_index > upper_diag_index) {
    return errors::InvalidArgument(
        "lower_diag_index is greater than upper_diag_index");
  }

  // Validates lower_diag_index and upper_diag_index.
  const int32 input_rank = c->Rank(input_shape);
  const int32 num_rows = c->Value(c->Dim(input_shape, input_rank - 2));
  const int32 num_cols = c->Value(c->Dim(input_shape, input_rank - 1));
  int32 max_diag_len = InferenceContext::kUnknownDim;
  if (num_rows != InferenceContext::kUnknownDim &&
      num_cols != InferenceContext::kUnknownDim) {
    if (lower_diag_index != 0 &&  // For when num_rows or num_cols == 0.
        (-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) {
      return errors::InvalidArgument("lower_diag_index is out of bounds.");
    }
    if (upper_diag_index != 0 &&  // For when num_rows or num_cols == 0.
        (-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) {
      return errors::InvalidArgument("upper_diag_index is out of bounds.");
    }
    max_diag_len = std::min(num_rows + std::min(upper_diag_index, 0),
                            num_cols - std::max(lower_diag_index, 0));
  }

  std::vector<DimensionHandle> dims;
  dims.reserve(input_rank - 2);
  for (int i = 0; i < input_rank - 2; ++i) {
    dims.push_back(c->Dim(input_shape, i));
  }
  if (lower_diag_index < upper_diag_index) {
    dims.push_back(c->MakeDim(upper_diag_index - lower_diag_index + 1));
  }
  dims.push_back(c->MakeDim(max_diag_len));
  c->set_output(0, c->MakeShape(dims));
  return Status::OK();
}

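// Example for the inverse mapping below (a sketch): a diagonal input of
// shape [3] with k = 0 and unspecified num_rows/num_cols implies
// min_num_rows = min_num_cols = 3, so the output defaults to the square
// [3, 3]. With k = (0, 1) and input [2, 3], min_num_rows = 3 - min(1, 0) = 3
// and min_num_cols = 3 + max(0, 0) = 3, again giving [3, 3].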
MatrixDiagV2Shape(shape_inference::InferenceContext * c)1321 Status MatrixDiagV2Shape(shape_inference::InferenceContext* c) {
1322 // Checks input ranks.
1323 ShapeHandle input_shape, diag_index_shape, unused_shape;
1324 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input_shape));
1325 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape));
1326 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape));
1327 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
1328 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
1329
1330 // Reads the diagonal indices.
1331 const Tensor* diag_index_tensor = c->input_tensor(1);
1332 if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) ||
1333 diag_index_tensor == nullptr) {
1334 c->set_output(0, c->UnknownShape());
1335 return Status::OK();
1336 }
1337 int32 lower_diag_index = 0;
1338 int32 upper_diag_index = 0;
1339 TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
1340 &upper_diag_index));
1341 if (lower_diag_index > upper_diag_index) {
1342 return errors::InvalidArgument(
1343 "lower_diag_index is greater than upper_diag_index");
1344 }
1345
  // Checks if the number of diagonals provided matches what we imply from
  // lower_diag_index and upper_diag_index.
  const int32 input_rank = c->Rank(input_shape);
  if (lower_diag_index < upper_diag_index) {
    const int32 num_diags = c->Value(c->Dim(input_shape, input_rank - 2));
    const int32 other_dim = c->Value(c->Dim(input_shape, input_rank - 1));

    if (num_diags != (upper_diag_index - lower_diag_index + 1)) {
      return errors::InvalidArgument(
          "The number of rows of `diagonal` doesn't match the number of "
          "diagonals implied from `d_lower` and `d_upper`.\n",
          "num_diags = ", num_diags, ", d_lower = ", lower_diag_index,
          ", d_upper = ", upper_diag_index, ", input_rank = ", input_rank,
          ", other_dim = ", other_dim);
    }
  }

  // Reads num_rows and num_cols.
  const Tensor* num_rows_tensor = c->input_tensor(2);
  const Tensor* num_cols_tensor = c->input_tensor(3);
  int64 num_rows = -1;
  int64 num_cols = -1;
  if (num_rows_tensor != nullptr) {
    TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_rows_tensor, &num_rows));
  }
  if (num_cols_tensor != nullptr) {
    TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_cols_tensor, &num_cols));
  }

  // Infers the missing num_rows or num_cols: If both are missing, assume
  // output is square. Otherwise, use the smallest possible value. Also
  // validates the provided values.
  const int32 max_diag_len = c->Value(c->Dim(input_shape, input_rank - 1));
  const int32 min_num_rows = max_diag_len - std::min(upper_diag_index, 0);
  const int32 min_num_cols = max_diag_len + std::max(lower_diag_index, 0);
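  // Worked example (comment added for illustration): a single superdiagonal,
  // k = (1, 1), with max_diag_len = 3 requires at least
  // min_num_rows = 3 - min(1, 0) = 3 and min_num_cols = 3 + max(1, 0) = 4.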
  if (num_rows == -1 && num_cols == -1) {  // Special case.
    num_rows = std::max(min_num_rows, min_num_cols);
    num_cols = num_rows;
  }
  if (num_rows == -1) {
    num_rows = min_num_rows;
  } else if (num_rows < min_num_rows) {
    return errors::InvalidArgument("num_rows is too small.");
  }
  if (num_cols == -1) {
    num_cols = min_num_cols;
  } else if (num_cols < min_num_cols) {
    return errors::InvalidArgument("num_cols is too small.");
  }
  // At least one of them must match the minimum length.
  if (num_rows != min_num_rows && num_cols != min_num_cols) {
    return errors::InvalidArgument(
        "num_rows and num_cols are not consistent with lower_diag_index, "
        "upper_diag_index, and the length of the given diagonals.\n",
        "num_rows = ", num_rows, " != min_num_rows = ", min_num_rows,
        ", num_cols = ", num_cols, " != min_num_cols = ", min_num_cols);
  }

  // Sets output shape.
  ShapeHandle output_shape;
  const DimensionHandle output_row_dim = c->MakeDim(num_rows);
  const DimensionHandle output_col_dim = c->MakeDim(num_cols);
  if (lower_diag_index == upper_diag_index) {
    TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 1,
                                     output_row_dim, &output_shape));
    TF_RETURN_IF_ERROR(
        c->Concatenate(output_shape, c->Vector(output_col_dim), &output_shape));
  } else {
    TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 2,
                                     output_row_dim, &output_shape));
    TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape, input_rank - 1,
                                     output_col_dim, &output_shape));
  }
  c->set_output(0, output_shape);
  return Status::OK();
}

Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape, diag_shape, diag_index_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag_shape));
  TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &diag_index_shape));

  int32 lower_diag_index = 0;
  int32 upper_diag_index = 0;
  bool diag_index_known = false;
  const Tensor* diag_index_tensor = c->input_tensor(2);
  if (diag_index_tensor != nullptr && c->FullyDefined(diag_index_shape)) {
    diag_index_known = true;
    TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
                                     &upper_diag_index));
    if (lower_diag_index > upper_diag_index) {
      return errors::InvalidArgument(
          "lower_diag_index is greater than upper_diag_index");
    }
  }

  // Do more checks when input rank is known.
  if (c->RankKnown(input_shape)) {
    int32 input_rank = c->Rank(input_shape);

    // If diag_index is set, we know the exact rank of diagonal.
    if (diag_index_known) {
      TF_RETURN_IF_ERROR(c->WithRank(
          c->input(1),
          (lower_diag_index == upper_diag_index) ? input_rank - 1 : input_rank,
          &diag_shape));
    } else {
      TF_RETURN_IF_ERROR(
          c->WithRankAtLeast(c->input(1), input_rank - 1, &diag_shape));
      TF_RETURN_IF_ERROR(
          c->WithRankAtMost(c->input(1), input_rank, &diag_shape));
    }

    // Validates lower_diag_index and upper_diag_index.
    const int32 num_rows = c->Value(c->Dim(input_shape, input_rank - 2));
    const int32 num_cols = c->Value(c->Dim(input_shape, input_rank - 1));
    if (num_rows != InferenceContext::kUnknownDim &&
        num_cols != InferenceContext::kUnknownDim) {
      if (lower_diag_index != 0 &&  // For when num_rows or num_cols == 0.
          (-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) {
        return errors::InvalidArgument("lower_diag_index is out of bounds.");
      }
      if (upper_diag_index != 0 &&  // For when num_rows or num_cols == 0.
          (-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) {
        return errors::InvalidArgument("upper_diag_index is out of bounds.");
      }
    }
  }

  ShapeHandle output_shape = input_shape;
  if (c->RankKnown(diag_shape) && !c->FullyDefined(input_shape)) {
    // Try to infer parts of shape from diag.
    ShapeHandle diag_prefix;
    TF_RETURN_IF_ERROR(c->Subshape(
        diag_shape, 0, (lower_diag_index == upper_diag_index) ? -1 : -2,
        &diag_prefix));

    // The inner matrices can be rectangular, so we can't pinpoint their
    // exact height and width by just lower_diag_index, upper_diag_index,
    // and the longest length of given diagonals.
    TF_RETURN_IF_ERROR(
        c->Concatenate(diag_prefix, c->UnknownShapeOfRank(2), &diag_shape));
    TF_RETURN_IF_ERROR(c->Merge(input_shape, diag_shape, &output_shape));
  }
  c->set_output(0, output_shape);
  return Status::OK();
}

Status MaxPoolShapeImpl(shape_inference::InferenceContext* c,
                        bool supports_explicit_padding) {
  string data_format_str;
  TensorFormat data_format;
  Status s = c->GetAttr("data_format", &data_format_str);
  if (s.ok()) {
    // FormatFromString returns false for an unrecognized format string;
    // surface that as an error instead of silently ignoring it.
    if (!FormatFromString(data_format_str, &data_format)) {
      return errors::InvalidArgument("Invalid data format string: ",
                                     data_format_str);
    }
  } else {
    data_format = FORMAT_NHWC;
  }

  const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));

  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the stride attribute to contain 4 values, but got: ",
        strides.size());
  }

  std::vector<int32> kernel_sizes;
  TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
  if (kernel_sizes.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the ksize attribute to contain 4 values, but got: ",
        kernel_sizes.size());
  }

  int32 stride_depth = GetTensorDim(strides, data_format, 'C');
  int32 stride_rows = GetTensorDim(strides, data_format, 'H');
  int32 stride_cols = GetTensorDim(strides, data_format, 'W');
  int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
  int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
  int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');

  constexpr int num_spatial_dims = 2;
  DimensionHandle batch_size_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
  DimensionHandle in_rows_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
  DimensionHandle in_cols_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
  DimensionHandle in_depth_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  std::vector<int64> explicit_paddings;
  if (supports_explicit_padding) {
    Status status = c->GetAttr("explicit_paddings", &explicit_paddings);
    // Use the default value, which is an empty list, if the attribute is not
    // found. Otherwise return the error to the caller.
    if (!status.ok() && !errors::IsNotFound(status)) {
      return status;
    }
    TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
                                         /*num_dims=*/4, data_format));
  } else {
    DCHECK(padding != Padding::EXPLICIT);
  }

  ShapeHandle output_shape;
  DimensionHandle output_rows, output_cols, output_depth;
  int64 pad_rows_before = -1, pad_rows_after = -1;
  int64 pad_cols_before = -1, pad_cols_after = -1;
  if (padding == Padding::EXPLICIT) {
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
                             &pad_rows_before, &pad_rows_after);
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
                             &pad_cols_before, &pad_cols_after);
  }
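  // Worked example (comment added for illustration): 7 input rows with
  // kernel_rows = 3 and stride_rows = 2 give ceil((7 - 3 + 1) / 2) = 3
  // output rows under VALID padding and ceil(7 / 2) = 4 under SAME.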
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_rows_dim, kernel_rows, /*dilation_rate=*/1, stride_rows, padding,
      pad_rows_before, pad_rows_after, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_cols_dim, kernel_cols, /*dilation_rate=*/1, stride_cols, padding,
      pad_cols_before, pad_cols_after, &output_cols));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_depth_dim, kernel_depth, /*dilation_rate=*/1, stride_depth, padding,
      /*pad_before=*/0, /*pad_after=*/0, &output_depth));

  TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
                                         {output_rows, output_cols},
                                         output_depth, &output_shape, c));

  c->set_output(0, output_shape);
  return Status::OK();
}

Status MaxPoolShape(shape_inference::InferenceContext* c) {
  return MaxPoolShapeImpl(c, /*supports_explicit_padding=*/false);
}

Status MaxPoolGradShape(shape_inference::InferenceContext* c) {
  return UnchangedShapeWithRank(c, 4);
}

Status MaxPoolShapeWithExplicitPadding(shape_inference::InferenceContext* c) {
  return MaxPoolShapeImpl(c, /*supports_explicit_padding=*/true);
}

Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) {
  string data_format_str;
  TensorFormat data_format;
  Status s = c->GetAttr("data_format", &data_format_str);
  if (s.ok()) {
    // As above, check FormatFromString's result rather than ignoring it.
    if (!FormatFromString(data_format_str, &data_format)) {
      return errors::InvalidArgument("Invalid data format string: ",
                                     data_format_str);
    }
  } else {
    data_format = FORMAT_NHWC;
  }

  const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));

  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));

  std::vector<int32> kernel_sizes;
  std::vector<int32> strides;

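  // Note (added for clarity): callers pass the input count of the V2 form of
  // the op, where ksize and strides are the last two inputs. If the context
  // has two fewer inputs than that, this is presumably the attr-based form,
  // so ksize and strides are read from attrs instead.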
  if (c->num_inputs() + 2 == num_inputs) {
    TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
    TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  } else {
    // Verify shape of ksize and strides input.
    ShapeHandle size;
    DimensionHandle unused;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 2), 1, &size));
    TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));
    TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &size));
    TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));

    const Tensor* kernel_sizes_tensor = c->input_tensor(c->num_inputs() - 2);
    if (kernel_sizes_tensor == nullptr) {
      c->set_output(0, c->UnknownShape());
      return Status::OK();
    }
    kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements());
    auto kernel_sizes_vec = kernel_sizes_tensor->flat<int32>();
    std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(),
                kernel_sizes.begin());

    const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1);
    if (strides_tensor == nullptr) {
      c->set_output(0, c->UnknownShape());
      return Status::OK();
    }
    strides.resize(strides_tensor->shape().num_elements());
    auto strides_vec = strides_tensor->flat<int32>();
    std::copy_n(&strides_vec(0), strides.size(), strides.begin());
  }

  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the stride attribute to contain 4 values, but "
        "got: ",
        strides.size());
  }
  if (kernel_sizes.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the ksize attribute to contain 4 values, but got: ",
        kernel_sizes.size());
  }

  int32 stride_depth = GetTensorDim(strides, data_format, 'C');
  int32 stride_rows = GetTensorDim(strides, data_format, 'H');
  int32 stride_cols = GetTensorDim(strides, data_format, 'W');
  int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
  int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
  int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');

  constexpr int num_spatial_dims = 2;
  DimensionHandle batch_size_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
  DimensionHandle in_rows_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
  DimensionHandle in_cols_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
  DimensionHandle in_depth_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  ShapeHandle output_shape;
  DimensionHandle output_rows, output_cols, output_depth;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth));

  TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
                                         {output_rows, output_cols},
                                         output_depth, &output_shape, c));

  c->set_output(0, output_shape);
  return Status::OK();
}

Status Pool3DShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));

  string data_format;
  Status s = c->GetAttr("data_format", &data_format);

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 5) {
    return errors::InvalidArgument(
        "Pool3D ops require the stride attribute to contain 5 values, but "
        "got: ",
        strides.size());
  }

  std::vector<int32> kernel_sizes;
  TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
  if (kernel_sizes.size() != 5) {
    return errors::InvalidArgument(
        "Pool3D requires the ksize attribute to contain 5 values, but got: ",
        kernel_sizes.size());
  }

  int32 stride_planes, stride_rows, stride_cols;
  int32 kernel_planes, kernel_rows, kernel_cols;

  if (s.ok() && data_format == "NCDHW") {
    // Convert input_shape to NDHWC.
    auto dim = [&](char dimension) {
      return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
    };
    input_shape =
        c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
    stride_planes = strides[2];
    stride_rows = strides[3];
    stride_cols = strides[4];
    kernel_planes = kernel_sizes[2];
    kernel_rows = kernel_sizes[3];
    kernel_cols = kernel_sizes[4];
  } else {
    stride_planes = strides[1];
    stride_rows = strides[2];
    stride_cols = strides[3];
    kernel_planes = kernel_sizes[1];
    kernel_rows = kernel_sizes[2];
    kernel_cols = kernel_sizes[3];
  }
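  // Note (added for clarity): NCDHW lays dimensions out as
  // [batch, channels, planes, rows, cols], so the spatial values sit at
  // indices 2..4; in NDHWC they sit at indices 1..3. Either way, input_shape
  // is in NDHWC order from here on.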

  DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
  DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
  DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
  DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
  DimensionHandle output_depth_dim = c->Dim(input_shape, 4);

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  // TODO(mrry,shlens): Raise an error if the stride would cause
  // information in the input to be ignored. This will require a change
  // in the kernel implementation.
  DimensionHandle output_planes, output_rows, output_cols;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_planes_dim, kernel_planes, stride_planes, padding, &output_planes));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));

  ShapeHandle output_shape;
  if (data_format == "NCDHW") {
    output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
                                 output_planes, output_rows, output_cols});
  } else {
    output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
                                 output_cols, output_depth_dim});
  }

  c->set_output(0, output_shape);
  return Status::OK();
}

Status MaxPool3DGradShape(shape_inference::InferenceContext* c) {
  return UnchangedShapeWithRank(c, 5);
}

Status AvgPool3DGradShape(shape_inference::InferenceContext* c) {
  ShapeHandle s;
  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
  TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
  c->set_output(0, s);
  return Status::OK();
}

Status UnknownShape(shape_inference::InferenceContext* c) {
  for (int i = 0; i < c->num_outputs(); ++i) {
    c->set_output(i, c->UnknownShape());
  }
  return Status::OK();
}

template <typename T>
Status ReductionShapeHelper(const Tensor* reduction_indices_t,
                            const int32 input_rank,
                            std::set<int64>* true_indices) {
  auto reduction_indices = reduction_indices_t->flat<T>();
  for (int i = 0; i < reduction_indices_t->NumElements(); ++i) {
    const T reduction_index = reduction_indices(i);
    if (reduction_index < -input_rank || reduction_index >= input_rank) {
      return errors::InvalidArgument("Invalid reduction dimension ",
                                     reduction_index, " for input with ",
                                     input_rank, " dimensions.");
    }

    auto wrapped_index = reduction_index;
    if (wrapped_index < 0) {
      wrapped_index += input_rank;
    }

    true_indices->insert(wrapped_index);
  }
  return Status::OK();
}

Status ReductionShape(InferenceContext* c) {
  ShapeHandle input = c->input(0);

  ShapeHandle indices;
  // Older versions of TensorFlow accidentally allowed higher rank tensors like
  // [[1,2]] or [[1],[2]] to represent axis=[1,2].
  if (c->graph_def_version() < 21) {
    indices = c->input(1);
  } else {
    TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices));
  }

  bool keep_dims;
  TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));

  const Tensor* reduction_indices_t = c->input_tensor(1);
  if (reduction_indices_t == nullptr || !c->RankKnown(input)) {
    // If we do not have the reduction values at runtime, or the
    // rank of the input, we don't know the output shape.

    if (keep_dims && c->RankKnown(input)) {
      // output rank matches input rank if <keep_dims>.
      c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
      return Status::OK();
    } else {
      return shape_inference::UnknownShape(c);
    }
  }

  const int32 input_rank = c->Rank(input);
  std::set<int64> true_indices;
  if (reduction_indices_t->dtype() == DataType::DT_INT32) {
    TF_RETURN_IF_ERROR(ReductionShapeHelper<int32>(reduction_indices_t,
                                                   input_rank, &true_indices));
  } else if (reduction_indices_t->dtype() == DataType::DT_INT64) {
    TF_RETURN_IF_ERROR(ReductionShapeHelper<int64>(reduction_indices_t,
                                                   input_rank, &true_indices));
  } else {
    return errors::InvalidArgument(
        "reduction_indices can only be int32 or int64");
  }
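  // Example (comment added for illustration): with input rank 3 and
  // reduction index 1, the loop below produces [d0, d2] when keep_dims is
  // false and [d0, 1, d2] when keep_dims is true.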

  std::vector<DimensionHandle> dims;
  for (int i = 0; i < input_rank; ++i) {
    if (true_indices.count(i) > 0) {
      if (keep_dims) {
        dims.emplace_back(c->MakeDim(1));
      }
    } else {
      dims.emplace_back(c->Dim(input, i));
    }
  }

  c->set_output(0, c->MakeShape(dims));
  return Status::OK();
}

Status ConcatShapeHelper(InferenceContext* c, int start_value_index,
                         int end_value_index, int dim_index) {
  ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(dim_index), 0, &unused));
  const Tensor* concat_dim_t = c->input_tensor(dim_index);
  if (concat_dim_t == nullptr) {
    // Return an unknown shape with same rank as inputs, or an unknown rank
    // if no input's rank is known.

    // Find rank.
    int32 rank = InferenceContext::kUnknownRank;
    for (int i = start_value_index; i < end_value_index; ++i) {
      if (rank == InferenceContext::kUnknownRank) rank = c->Rank(c->input(i));
      if (rank != InferenceContext::kUnknownRank) {
        break;
      }
    }
    if (rank == InferenceContext::kUnknownRank) {
      c->set_output(0, c->UnknownShape());
      return Status::OK();
    } else if (rank == 0) {
      return errors::InvalidArgument(
          "Can't concatenate scalars (use tf.stack instead)");
    } else {
      for (int i = start_value_index; i < end_value_index; ++i) {
        // Check that all the inputs are of the correct rank.
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), rank, &unused));
      }
    }
    // Build result of <rank> different unknown dims.
    std::vector<DimensionHandle> dims;
    dims.reserve(rank);
    for (int i = 0; i < rank; ++i) dims.push_back(c->UnknownDim());
    c->set_output(0, c->MakeShape(dims));
    return Status::OK();
  }

  // Merge all the non-concat dims, and sum the concat dim to make an output
  // shape.
  int64 concat_dim;
  if (concat_dim_t->dtype() == DT_INT32) {
    concat_dim = static_cast<int64>(concat_dim_t->flat<int32>()(0));
  } else {
    concat_dim = concat_dim_t->flat<int64>()(0);
  }

  // Minimum required number of dimensions.
  const int min_rank = concat_dim < 0 ? -concat_dim : concat_dim + 1;
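  // E.g. (comment added for illustration): concat_dim = 2 requires every
  // input to have at least 3 dims, while concat_dim = -2 requires at least 2,
  // counting from the end.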

  ShapeHandle output_before;
  ShapeHandle output_after;

  ShapeHandle input = c->input(end_value_index - 1);
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
  TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &output_before));
  DimensionHandle output_middle = c->Dim(input, concat_dim);
  if (concat_dim == -1) {
    output_after = c->Scalar();  // no dimensions.
  } else {
    TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &output_after));
  }

  for (int i = end_value_index - 2; i >= start_value_index; --i) {
    ShapeHandle before;
    ShapeHandle after;
    input = c->input(i);
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
    TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &before));
    DimensionHandle middle = c->Dim(input, concat_dim);
    if (concat_dim == -1) {
      after = c->Scalar();
    } else {
      TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &after));
    }

    TF_RETURN_IF_ERROR(c->Merge(before, output_before, &output_before));
    TF_RETURN_IF_ERROR(c->Add(output_middle, middle, &output_middle));
    TF_RETURN_IF_ERROR(c->Merge(after, output_after, &output_after));
  }

  ShapeHandle s;
  TF_RETURN_IF_ERROR(
      c->Concatenate(output_before, c->Vector(output_middle), &s));
  TF_RETURN_IF_ERROR(c->Concatenate(s, output_after, &s));
  c->set_output(0, s);
  return Status::OK();
}

Status ConcatShape(InferenceContext* c, int num_inputs_to_concat) {
  return ConcatShapeHelper(c, 1 /* start_value_index */,
                           1 + num_inputs_to_concat /* end_value_index */,
                           0 /* dim_index */);
}

Status ConcatV2Shape(InferenceContext* c) {
  return ConcatShapeHelper(c, 0 /* start_value_index */,
                           c->num_inputs() - 1 /* end_value_index */,
                           c->num_inputs() - 1 /* dim_index */);
}

Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat) {
  return ConcatShapeHelper(c, 0 /* start_value_index */,
                           num_inputs_to_concat /* end_value_index */,
                           num_inputs_to_concat /* dim_index */);
}

Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
                                            ShapeHandle shape_x,
                                            ShapeHandle shape_y,
                                            bool incompatible_shape_error,
                                            ShapeHandle* out) {
  CHECK_NOTNULL(out);
  if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
    *out = c->UnknownShape();
    return Status::OK();
  }
  const int32 rank_x = c->Rank(shape_x);
  const int32 rank_y = c->Rank(shape_y);
  const int32 rank_out = std::max(rank_x, rank_y);

  // To compute the broadcast dimensions, we zip together shape_x and shape_y
  // and pad with 1 to make them the same length.
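  // For instance (comment added for illustration): broadcasting [2, 3, 1]
  // against [3, 5] zips ([2, 3, 1], [1, 3, 5]) after padding and yields
  // [2, 3, 5].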
  std::vector<DimensionHandle> dims;
  DimensionHandle dim_one;
  if (rank_x != rank_y) dim_one = c->MakeDim(1);
  for (int i = 0; i < rank_out; ++i) {
    const auto dim_x = i < (rank_out - rank_x)
                           ? dim_one
                           : c->Dim(shape_x, i - (rank_out - rank_x));
    const bool dim_y_is_one = (i < (rank_out - rank_y));
    const auto dim_y =
        dim_y_is_one ? dim_one : c->Dim(shape_y, i - (rank_out - rank_y));
    if (!c->ValueKnown(dim_x) || !c->ValueKnown(dim_y)) {
      // One or both dimensions are unknown.
      //
      // - If either dimension is greater than 1, we assume that the program is
      //   correct, and the other dimension will be broadcast to match it.
      //   TODO(cwhipkey): For shape inference, if we eliminate the shape
      //   checks in C++ op code, we must still assert that the unknown dim is
      //   either 1 or the same as the known dim.
      // - If either dimension is 1, the other dimension is the output.
      // - If both are unknown, the output dimension is unknown.
      if (c->Value(dim_x) > 1) {
        if (!incompatible_shape_error) {
          *out = c->UnknownShape();
          return Status::OK();
        }
        dims.push_back(dim_x);
      } else if (c->Value(dim_y) > 1) {
        if (!incompatible_shape_error) {
          *out = c->UnknownShape();
          return Status::OK();
        }
        dims.push_back(dim_y);
      } else if (c->Value(dim_x) == 1) {
        dims.push_back(dim_y);
      } else if (c->Value(dim_y) == 1) {
        dims.push_back(dim_x);
      } else if (dim_y.SameHandle(dim_x)) {
        dims.push_back(dim_x);
      } else if (!c->ValueKnown(dim_x) && !c->ValueKnown(dim_y)) {
        dims.push_back(c->UnknownDim());
      } else {
        if (!incompatible_shape_error) {
          *out = c->UnknownShape();
          return Status::OK();
        }
        dims.push_back(c->UnknownDim());
      }
    } else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) {
      if (c->Value(dim_x) == 1 && !dim_y_is_one) {
        // We will broadcast dim_x to dim_y.
        dims.push_back(dim_y);
      } else {
        DCHECK_EQ(c->Value(dim_y), 1);
        // We will broadcast dim_y to dim_x.
        dims.push_back(dim_x);
      }
    } else {
      DimensionHandle dim;
      Status s = c->Merge(dim_x, dim_y, &dim);
      if (!s.ok()) {
        if (!incompatible_shape_error) {
          *out = c->MakeShape({});
          return Status::OK();
        }
        return s;
      }
      dims.push_back(dim);
    }
  }

  *out = c->MakeShape(dims);
  return Status::OK();
}

Status RandomShape(shape_inference::InferenceContext* c) {
  shape_inference::ShapeHandle out;
  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
  c->set_output(0, out);
  return Status::OK();
}

Status UnsortedSegmentReductionShapeFn(InferenceContext* c) {
  ShapeHandle s_data = c->input(0);
  ShapeHandle s_segment_ids = c->input(1);
  ShapeHandle s_num_segments = c->input(2);
  TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments));

  ShapeHandle out;

  // Leading dimensions of data must be compatible with dimensions of
  // <s_segment_ids>.
  if (c->RankKnown(s_segment_ids)) {
    TF_RETURN_IF_ERROR(
        c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids));

    // Get the value of the num_segments input tensor.
    DimensionHandle num_segments_dim;
    TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim));

    // Output is {segment_id_rank} + s_data[segment_id_rank:].
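    // E.g. (comment added for illustration): data [s0, s1, s2, s3] with
    // segment_ids [s0, s1] and num_segments = k gives output [k, s2, s3].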
    ShapeHandle s_data_suffix;
    TF_RETURN_IF_ERROR(
        c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix));
    TF_RETURN_IF_ERROR(
        c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out));
  } else {
    out = c->UnknownShape();
  }
  c->set_output(0, out);
  return Status::OK();
}

namespace {

// This SliceHelper processes the output shape of the `slice`
// when the tensor of `sizes` is available.
template <typename T>
Status SliceHelper(InferenceContext* c, ShapeHandle begin_value,
                   const Tensor* sizes_value,
                   std::vector<DimensionHandle>* dims) {
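  // Example (comment added for illustration): input [5, 6] with
  // begin = [1, 0] and sizes = [2, -1] produces dims [2, 6]; a size of -1
  // means "to the end", computed below as dim - begin.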
  auto sizes_vec = sizes_value->vec<T>();
  for (int i = 0; i < sizes_value->NumElements(); ++i) {
    DimensionHandle dim = c->Dim(c->input(0), i);
    if (sizes_vec(i) != -1) {
      auto dim_val = c->Value(dim);
      if (sizes_vec(i) < 0) {
        return errors::InvalidArgument(
            "Out of bounds slicing on dimension ", i, " of length ", dim_val,
            ": sizes vector cannot be < -1, but was ", sizes_vec(i));
      }

      dims->emplace_back(c->MakeDim(sizes_vec(i)));
    } else {
      DimensionHandle result;
      TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result));
      dims->emplace_back(result);
    }
  }

  return Status::OK();
}
}  // namespace

Status SliceShape(InferenceContext* c) {
  ShapeHandle input = c->input(0);
  ShapeHandle begin_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
  ShapeHandle sizes_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape));

  // Merge to check compatibility of begin and sizes tensors.
  TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape));

  DimensionHandle ndims = c->Dim(begin_shape, 0);
  if (c->ValueKnown(ndims)) {
    TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input));
  }

  // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known
  // values, even though the `begin` value does not represent a shape.
  ShapeHandle begin_value;
  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value));

  // We check the tensor value here and will only use
  // `MakeShapeFromShapeTensor` when `sizes_value` is null.
  // The reason is that `sizes` might contain -1, which can't
  // be represented (-1 in the ShapeHandle would mean "unknown").
  const Tensor* sizes_value = c->input_tensor(2);

  if (sizes_value != nullptr) {
    TF_RETURN_IF_ERROR(
        c->WithRank(begin_value, sizes_value->NumElements(), &begin_value));
    std::vector<DimensionHandle> dims;
    // If the begin and sizes tensors are available, then
    // we can be precise about the shape of the output.
    if (sizes_value->dtype() == DT_INT64) {
      TF_RETURN_IF_ERROR(
          SliceHelper<int64>(c, begin_value, sizes_value, &dims));
    } else {
      TF_RETURN_IF_ERROR(
          SliceHelper<int32>(c, begin_value, sizes_value, &dims));
    }
    c->set_output(0, c->MakeShape(dims));
    return Status::OK();
  } else {
    // In case `sizes` is not available (`sizes_value` is null),
    // we could try to use `MakeShapeFromShapeTensor` here.
    // If sizes contain -1, we will simply consider it as `Unknown`.
    // This is less than ideal but still an improvement of shape inference.
    // The following is an example that returns [None, 1, None] with this
    // code path:
    //   z = tf.zeros((1, 2, 3))
    //   m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1])
    //   m.get_shape().as_list()
    ShapeHandle sizes_value;
    TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value));
    if (c->RankKnown(sizes_value)) {
      TF_RETURN_IF_ERROR(
          c->WithRank(begin_value, c->Rank(sizes_value), &begin_value));
      std::vector<DimensionHandle> dims;
      dims.reserve(c->Rank(sizes_value));
      for (int i = 0; i < c->Rank(sizes_value); ++i) {
        dims.emplace_back(c->Dim(sizes_value, i));
      }
      c->set_output(0, c->MakeShape(dims));
      return Status::OK();
    }
    // We might know the rank of the input.
    if (c->RankKnown(input)) {
      c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
      return Status::OK();
    } else {
      return shape_inference::UnknownShape(c);
    }
  }

  return Status::OK();
}

Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
                            ShapeHandle values_shape, ShapeHandle shape_shape) {
  // Validate ranks.
  ShapeHandle unused_shape;
  TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &unused_shape));
  TF_RETURN_IF_ERROR(c->WithRank(values_shape, 1, &unused_shape));
  TF_RETURN_IF_ERROR(c->WithRank(shape_shape, 1, &unused_shape));

  // Number of elements in indices and values must match.
  DimensionHandle num_index_elements_dim = c->Dim(indices_shape, 0);
  if (c->ValueKnown(num_index_elements_dim)) {
    DimensionHandle num_values_elements_dim = c->Dim(values_shape, 0);
    if (c->ValueKnown(num_values_elements_dim)) {
      int64 num_index_elements = c->Value(num_index_elements_dim);
      int64 num_values_elements = c->Value(num_values_elements_dim);
      if (num_index_elements != num_values_elements) {
        return errors::InvalidArgument("Number of elements in index (",
                                       num_index_elements, ") and values (",
                                       num_values_elements, ") do not match.");
      }
    }
  }

  // Rank embedded in indices must match shape.
  DimensionHandle index_rank_dim = c->Dim(indices_shape, 1);
  if (c->ValueKnown(index_rank_dim)) {
    DimensionHandle shape_rank_dim = c->Dim(shape_shape, 0);
    if (c->ValueKnown(shape_rank_dim)) {
      int64 index_rank = c->Value(index_rank_dim);
      // Use int64 to match c->Value's return type and avoid narrowing.
      int64 shape_rank = c->Value(shape_rank_dim);
      if (index_rank != shape_rank) {
        return errors::InvalidArgument("Index rank (", index_rank,
                                       ") and shape rank (", shape_rank,
                                       ") do not match.");
      }
    }
  }

  return Status::OK();
}

Status ValidateVariableResourceHandle(
    InferenceContext* c, std::vector<ShapeAndType>* shape_and_type) {
  auto* handle_data = c->input_handle_shapes_and_types(0);
  if (handle_data == nullptr || handle_data->empty()) {
    shape_and_type->emplace_back(c->UnknownShape(), DT_INVALID);
  } else {
    *shape_and_type = *handle_data;
    DataType value_dtype;
    TF_RETURN_IF_ERROR(c->GetAttr("dtype", &value_dtype));
    if (shape_and_type->at(0).dtype != value_dtype) {
      return errors::InvalidArgument(
          "Trying to read variable with wrong dtype. "
          "Expected ",
          DataTypeString(shape_and_type->at(0).dtype), " got ",
          DataTypeString(value_dtype));
    }
  }
  return Status::OK();
}

Status GatherNdShape(InferenceContext* c) {
  ShapeHandle params;
  std::vector<ShapeAndType> handle_shape_and_type;
  if (c->input_handle_shapes_and_types(0) != nullptr) {
    TF_RETURN_IF_ERROR(
        ValidateVariableResourceHandle(c, &handle_shape_and_type));
    params = handle_shape_and_type[0].shape;
  } else {
    params = c->input(0);
  }
  ShapeHandle indices;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices));
  DimensionHandle r_dim = c->Dim(indices, -1);

  if (!c->RankKnown(params) || !c->ValueKnown(r_dim)) {
    c->set_output(0, c->UnknownShape());
    return Status::OK();
  }

  if (c->Value(r_dim) > c->Rank(params)) {
    return errors::InvalidArgument(
        "indices.shape[-1] must be <= params.rank, but saw indices shape: ",
        c->DebugString(indices), " and params shape: ", c->DebugString(params));
  }

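  // E.g. (comment added for illustration): params [p0, p1, p2] with indices
  // [n, 2] gives output [n, p2]: each row of indices addresses the first
  // r = 2 dims of params, and the remaining params dims are appended.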
  // Remove r_dim from indices to get output.
  ShapeHandle indices_slice;
  ShapeHandle params_slice;
  TF_RETURN_IF_ERROR(c->Subshape(indices, 0, -1, &indices_slice));
  TF_RETURN_IF_ERROR(c->Subshape(params, c->Value(r_dim), &params_slice));
  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->Concatenate(indices_slice, params_slice, &out));
  c->set_output(0, out);
  return Status::OK();
}

Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape,
                            ShapeHandle updates_shape,
                            ShapeHandle input_shape) {
  if (c->Value(c->NumElements(input_shape)) == 0 &&
      (c->Value(c->NumElements(indices_shape)) > 0 ||
       c->Value(c->NumElements(updates_shape)) > 0)) {
    return errors::InvalidArgument(
        "Indices and updates specified for empty input");
  }

  if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) {
    const int64 outer_dims = c->Rank(indices_shape) - 1;
    const DimensionHandle ixdim = c->Dim(indices_shape, -1);

    // We can only do more validation if the last dimension of indices
    // is a known value.
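    // E.g. (comment added for illustration): indices [4, 2], updates [4, 5],
    // input [6, 7, 5]: outer_dims = 1 and ix = 2, so indices[:1] must merge
    // with updates[:1] ([4] vs. [4]) and input[2:] must merge with
    // updates[1:] ([5] vs. [5]).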
    if (c->ValueKnown(ixdim)) {
      int64 ix = c->Value(ixdim);
      ShapeHandle unused;
      ShapeHandle prefix_indices;
      TF_RETURN_IF_ERROR(
          c->Subshape(indices_shape, 0, outer_dims, &prefix_indices));
      ShapeHandle prefix_updates;
      TF_RETURN_IF_ERROR(
          c->Subshape(updates_shape, 0, outer_dims, &prefix_updates));

      Status s = c->Merge(prefix_indices, prefix_updates, &unused);
      if (!s.ok()) {
        return errors::InvalidArgument(
            "Dimensions [0,", outer_dims,
            ") of indices[shape=", c->DebugString(indices_shape),
            "] = ", c->DebugString(prefix_indices),
            " must match dimensions [0,", outer_dims,
            ") of updates[shape=", c->DebugString(updates_shape),
            "] = ", c->DebugString(prefix_updates), ": ", s.error_message());
      }

      ShapeHandle suffix_output;
      TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &suffix_output));
      ShapeHandle suffix_updates;
      TF_RETURN_IF_ERROR(
          c->Subshape(updates_shape, outer_dims, &suffix_updates));
      s = c->Merge(suffix_output, suffix_updates, &unused);
      if (!s.ok()) {
        return errors::InvalidArgument(
            "Dimensions [", ix, ",", c->Rank(input_shape),
            ") of input[shape=", c->DebugString(input_shape),
            "] = ", c->DebugString(suffix_output), " must match dimensions [",
            outer_dims, ",", c->Rank(updates_shape),
            ") of updates[shape=", c->DebugString(updates_shape),
            "] = ", c->DebugString(suffix_updates), ": ", s.error_message());
      }
    }
  }

  if (c->input_handle_shapes_and_types(0) == nullptr && c->num_outputs() > 0) {
    // This is called for tf.scatter_nd; output is a tensor with this shape.
    c->set_output(0, input_shape);
  }
  return Status::OK();
}

Status ExplicitShape(InferenceContext* c) {
  PartialTensorShape shape;
  TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
  ShapeHandle output_shape;
  TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &output_shape));
  c->set_output(0, output_shape);
  return Status::OK();
}

Status ExplicitShapes(InferenceContext* c) {
  std::vector<PartialTensorShape> shapes;
  TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
  if (shapes.empty()) {
    return errors::Internal("shapes attribute is empty");
  }
  for (int i = 0, end = shapes.size(); i < end; ++i) {
    ShapeHandle output_shape;
    TF_RETURN_IF_ERROR(
        c->MakeShapeFromPartialTensorShape(shapes[i], &output_shape));
    c->set_output(i, output_shape);
  }
  return Status::OK();
}

Status SparseReduceShapeFn(InferenceContext* c) {
  // Input 0: input_indices
  // Input 1: input_values
  // Input 2: input_shape
  // Input 3: reduction_axes
  // Attr: keep_dims
  bool keep_dims = false;
  TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));

  const Tensor* shape_tensor = c->input_tensor(2);
  const Tensor* axes_tensor = c->input_tensor(3);
  if (shape_tensor != nullptr && axes_tensor != nullptr) {
    auto shape_vec = shape_tensor->flat<int64>();
    auto axes_vec = axes_tensor->flat<int32>();

    int64 ndims = shape_vec.size();
    absl::flat_hash_set<int64> axes;
    for (int i = 0; i < axes_vec.size(); i++) {
      // Wrap negative axes, e.g. axis -1 becomes ndims - 1.
      axes.insert((axes_vec(i) + ndims) % ndims);
    }

    std::vector<DimensionHandle> dims;
    if (keep_dims) {
      dims.reserve(ndims);
      for (int d = 0; d < ndims; ++d) {
        if (axes.find(d) == axes.end()) {
          dims.push_back(c->MakeDim(shape_vec(d)));
        } else {
          dims.push_back(c->MakeDim(1));
        }
      }
    } else {
      for (int d = 0; d < ndims; ++d) {
        if (axes.find(d) == axes.end()) {
          dims.push_back(c->MakeDim(shape_vec(d)));
        }
      }
    }

    c->set_output(0, c->MakeShape(dims));
    return Status::OK();
  }
  return UnknownShape(c);
}

}  // namespace shape_inference

}  // namespace tensorflow