1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/common_shape_fns.h"
16
17 #include "absl/container/flat_hash_map.h"
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/strings/match.h"
20 #include "absl/strings/str_split.h"
21 #include "absl/strings/string_view.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/shape_inference.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/gtl/inlined_vector.h"
26 #include "tensorflow/core/util/einsum_op_util.h"
27
28 namespace tensorflow {
29
30 namespace shape_inference {
31
32 // The V2 version computes windowed output size with arbitrary dilation_rate and
33 // explicit padding, while the original version only handles the cases where
34 // dilation_rates equal to 1 and the padding is SAME or VALID.
GetWindowedOutputSizeFromDimsV2(shape_inference::InferenceContext * c,shape_inference::DimensionHandle input_size,shape_inference::DimensionOrConstant filter_size,int64_t dilation_rate,int64_t stride,Padding padding_type,int64_t padding_before,int64_t padding_after,shape_inference::DimensionHandle * output_size)35 Status GetWindowedOutputSizeFromDimsV2(
36 shape_inference::InferenceContext* c,
37 shape_inference::DimensionHandle input_size,
38 shape_inference::DimensionOrConstant filter_size, int64_t dilation_rate,
39 int64_t stride, Padding padding_type, int64_t padding_before,
40 int64_t padding_after, shape_inference::DimensionHandle* output_size) {
41 if (stride <= 0) {
42 return errors::InvalidArgument("Stride must be > 0, but got ", stride);
43 }
44
45 if (dilation_rate < 1) {
46 return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
47 dilation_rate);
48 }
49
50 // See also the parallel implementation in GetWindowedOutputSizeVerbose.
51 switch (padding_type) {
52 case Padding::VALID:
53 padding_before = padding_after = 0;
54 TF_FALLTHROUGH_INTENDED;
55 case Padding::EXPLICIT:
56 TF_RETURN_IF_ERROR(
57 c->Add(input_size, padding_before + padding_after, &input_size));
58 if (dilation_rate > 1) {
59 DimensionHandle window_size;
60 TF_RETURN_IF_ERROR(
61 c->Subtract(c->MakeDim(filter_size), 1, &window_size));
62 TF_RETURN_IF_ERROR(
63 c->Multiply(window_size, dilation_rate, &window_size));
64 TF_RETURN_IF_ERROR(c->Add(window_size, 1, &window_size));
65 TF_RETURN_IF_ERROR(c->Subtract(input_size, window_size, output_size));
66 } else {
67 TF_RETURN_IF_ERROR(c->Subtract(input_size, filter_size, output_size));
68 }
69 TF_RETURN_IF_ERROR(c->Add(*output_size, stride, output_size));
70 TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
71 /*evenly_divisible=*/false, output_size));
72 break;
73 case Padding::SAME:
74 TF_RETURN_IF_ERROR(c->Add(input_size, stride - 1, output_size));
75 TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
76 /*evenly_divisible=*/false, output_size));
77 break;
78 }
79 return OkStatus();
80 }
81
GetWindowedOutputSizeFromDims(shape_inference::InferenceContext * c,shape_inference::DimensionHandle input_size,shape_inference::DimensionOrConstant filter_size,int64_t stride,Padding padding_type,shape_inference::DimensionHandle * output_size)82 Status GetWindowedOutputSizeFromDims(
83 shape_inference::InferenceContext* c,
84 shape_inference::DimensionHandle input_size,
85 shape_inference::DimensionOrConstant filter_size, int64_t stride,
86 Padding padding_type, shape_inference::DimensionHandle* output_size) {
87 if (padding_type == Padding::EXPLICIT) {
88 return errors::Internal(
89 "GetWindowedOutputSizeFromDims does not handle EXPLICIT padding; call "
90 "GetWindowedOutputSizeFromDimsV2 instead");
91 }
92 return GetWindowedOutputSizeFromDimsV2(c, input_size, filter_size,
93 /*dilation_rate=*/1, stride,
94 padding_type,
95 // Give dummy values of -1 to
96 // padding_before and padding_after,
97 // since explicit padding is not used.
98 -1, -1, output_size);
99 }
100
UnchangedShape(shape_inference::InferenceContext * c)101 Status UnchangedShape(shape_inference::InferenceContext* c) {
102 c->set_output(0, c->input(0));
103 auto* handle_data = c->input_handle_shapes_and_types(0);
104 if (handle_data != nullptr) {
105 c->set_output_handle_shapes_and_types(0, *handle_data);
106 }
107 return OkStatus();
108 }
109
MatMulShape(shape_inference::InferenceContext * c)110 Status MatMulShape(shape_inference::InferenceContext* c) {
111 ShapeHandle a;
112 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a));
113
114 ShapeHandle b;
115 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b));
116
117 bool transpose_a, transpose_b;
118 TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a));
119 TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b));
120 DimensionHandle output_rows = transpose_a ? c->Dim(a, 1) : c->Dim(a, 0);
121 DimensionHandle output_cols = transpose_b ? c->Dim(b, 0) : c->Dim(b, 1);
122
123 // Validate that the inner shapes are compatible.
124 DimensionHandle inner_a = transpose_a ? c->Dim(a, 0) : c->Dim(a, 1);
125 DimensionHandle inner_b = transpose_b ? c->Dim(b, 1) : c->Dim(b, 0);
126 DimensionHandle merged;
127 TF_RETURN_IF_ERROR(c->Merge(inner_a, inner_b, &merged));
128
129 c->set_output(0, c->Matrix(output_rows, output_cols));
130 return OkStatus();
131 }
132
133 namespace {
134
135 // Validate that an Einsum subscript contains exactly one or zero ellipsis; and
136 // that periods (.) occur only within an ellipses (...).
ValidateEinsumEllipsis(absl::string_view subscript,bool * found_ellipsis)137 Status ValidateEinsumEllipsis(absl::string_view subscript,
138 bool* found_ellipsis) {
139 const int num_periods = absl::c_count(subscript, '.');
140 if (num_periods != 0 && num_periods != 3) {
141 return errors::InvalidArgument(
142 "Expected at most one ellipsis (...), but found ", num_periods,
143 " periods (.) in the input subscript: ", subscript);
144 }
145 if (num_periods == 3 && !absl::StrContains(subscript, "...")) {
146 return errors::InvalidArgument(
147 "Periods found outside of ellipsis in subscript: ", subscript);
148 }
149 *found_ellipsis = num_periods > 0;
150 return OkStatus();
151 }
152
153 } // namespace
154
EinsumShape(shape_inference::InferenceContext * c)155 Status EinsumShape(shape_inference::InferenceContext* c) {
156 // We assume that the equation has a valid format. Either (x),(y)->(z)
157 // or (x)->(z), where each of (x), (y) and (z) are concatenation of zero or
158 // more latin alphabets and contains at most one ellipsis ('...').
159 string equation;
160 TF_RETURN_IF_ERROR(c->GetAttr("equation", &equation));
161 gtl::InlinedVector<string, 2> input_labels;
162 string output_labels;
163 TF_RETURN_IF_ERROR(
164 ValidateEinsumEquation(equation, &input_labels, &output_labels));
165
166 if (c->num_inputs() == 0 || c->num_inputs() > 2) {
167 return errors::InvalidArgument("Expected either 1 or 2 inputs but got: ",
168 c->num_inputs());
169 }
170 const int input_labels_size = input_labels.size();
171 if (c->num_inputs() != input_labels_size) {
172 return errors::InvalidArgument("Expected ", input_labels.size(),
173 " inputs for equation ", equation,
174 " but got: ", c->num_inputs());
175 }
176
177 // Validate input subscripts, build the label to dimension mapping and obtain
178 // the broadcast shapes that map to ellipsis.
179 absl::flat_hash_map<char, DimensionHandle> label_to_dimension;
180 gtl::InlinedVector<ShapeHandle, 2> input_bcast_shapes(c->num_inputs());
181 for (int i = 0, end = c->num_inputs(); i < end; ++i) {
182 bool has_ellipsis = false;
183 TF_RETURN_IF_ERROR(ValidateEinsumEllipsis(input_labels[i], &has_ellipsis));
184 ShapeHandle input_shape = c->input(i);
185 // Validate that the input rank is sufficient for the given number of named
186 // labels.
187 if (c->RankKnown(input_shape)) {
188 if (has_ellipsis) {
189 const int num_named_labels =
190 static_cast<int>(input_labels[i].size()) - 3;
191 TF_RETURN_WITH_CONTEXT_IF_ERROR(
192 c->WithRankAtLeast(input_shape, num_named_labels, &input_shape),
193 " for ", i, "th input and equation: ", equation);
194 } else {
195 const int num_named_labels = static_cast<int>(input_labels[i].size());
196 TF_RETURN_WITH_CONTEXT_IF_ERROR(
197 c->WithRank(input_shape, num_named_labels, &input_shape), " for ",
198 i, "th input and equation: ", equation);
199 }
200 }
201
202 bool seen_ellipsis = false;
203 input_bcast_shapes[i] = c->Scalar();
204 // Run through the input labels; populate label_to_dimension mapping and
205 // compute the broadcast shapes corresponding to the ellipsis (if present).
206 for (int label_idx = 0, end = input_labels[i].size(); label_idx < end;
207 ++label_idx) {
208 const char label = input_labels[i][label_idx];
209 // Calculate the input axis that the current label is referring to. After
210 // the ellipsis, the axis may be found by using negative indices; i.e the
211 // (rank - k)th dimension corresponds to the (num_labels - k)th label.
212 const int64_t axis_before_ellipsis = label_idx;
213 const int64_t axis_after_ellipsis =
214 c->RankKnown(input_shape)
215 ? label_idx + c->Rank(input_shape) - input_labels[i].size()
216 : -1;
217
218 // Populate the input broadcast shape when we encounter an ellipsis (...).
219 if (label == '.') {
220 if (!c->RankKnown(input_shape)) {
221 input_bcast_shapes[i] = c->UnknownShape();
222 } else {
223 // The broadcast shape runs till the named label right after the
224 // ellipsis, the label with index (label_idx + 3).
225 TF_RETURN_IF_ERROR(c->Subshape(input_shape, axis_before_ellipsis,
226 axis_after_ellipsis + 3,
227 &input_bcast_shapes[i]));
228 }
229 label_idx += 2; // Skip the rest of the ellipsis.
230 seen_ellipsis = true;
231 continue;
232 }
233 // Obtain the dimension that the current label corresponds to.
234 int64_t axis = seen_ellipsis ? axis_after_ellipsis : axis_before_ellipsis;
235 DimensionHandle new_dim = c->RankKnown(input_shape)
236 ? c->Dim(input_shape, axis)
237 : c->UnknownDim();
238 // If we've seen this label before, make sure previous and current
239 // dimensions are compatible.
240 if (label_to_dimension.contains(label)) {
241 DimensionHandle merged;
242 TF_RETURN_IF_ERROR(
243 c->Merge(label_to_dimension[label], new_dim, &merged));
244 label_to_dimension[label] = merged;
245 } else {
246 label_to_dimension[label] = new_dim;
247 }
248 }
249 }
250
251 // For two inputs, broadcast the two input broadcast shapes to create the
252 // output broadcast shape. For one input, just copy the single broadcast
253 // shape.
254 ShapeHandle output_bcast_shape;
255 if (input_bcast_shapes.size() == 1) {
256 output_bcast_shape = input_bcast_shapes[0];
257 } else if (input_bcast_shapes.size() == 2) {
258 TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
259 c, input_bcast_shapes[0], input_bcast_shapes[1], true,
260 &output_bcast_shape));
261 }
262
263 bool output_has_ellipsis = false;
264 TF_RETURN_IF_ERROR(
265 ValidateEinsumEllipsis(output_labels, &output_has_ellipsis));
266 if (output_has_ellipsis) {
267 // If the output subscript has ellipsis and the output broadcast rank is
268 // unknown, then the output shape should have unknown rank.
269 if (!c->RankKnown(output_bcast_shape)) {
270 c->set_output(0, c->UnknownShape());
271 return OkStatus();
272 }
273 } else {
274 // If the output subscripts don't have ellipsis then make sure the output
275 // broadcasting shape is empty.
276 TF_RETURN_WITH_CONTEXT_IF_ERROR(
277 c->WithRankAtMost(output_bcast_shape, 0, &output_bcast_shape),
278 " for einsum equation '", equation,
279 "' without ellipsis (...) in the output subscripts where input(s) have "
280 "non-empty broadcasting shape");
281 output_bcast_shape = c->Scalar();
282 }
283
284 // Create the output shape from output labels and label_to_dimension mapping.
285 std::vector<DimensionHandle> output_dims;
286 for (int label_idx = 0, end = output_labels.size(); label_idx < end;
287 ++label_idx) {
288 const char label = output_labels[label_idx];
289 // Append the output_bcast_shape when the ellipsis is encountered.
290 if (label == '.') {
291 for (int k = 0; k < c->Rank(output_bcast_shape); ++k) {
292 output_dims.push_back(c->Dim(output_bcast_shape, k));
293 }
294 label_idx += 2; // Skip the rest of the ellipsis.
295 continue;
296 }
297 auto dimension_it = label_to_dimension.find(label);
298 if (dimension_it == label_to_dimension.end()) {
299 return errors::InvalidArgument(
300 "Einsum output subscripts for equation '", equation, "' has label '",
301 label, "' which is not present in the input subscripts");
302 }
303 output_dims.push_back(dimension_it->second);
304 }
305 c->set_output(0, c->MakeShape(output_dims));
306 return OkStatus();
307 }
308
BatchMatMulV2Shape(shape_inference::InferenceContext * c)309 Status BatchMatMulV2Shape(shape_inference::InferenceContext* c) {
310 ShapeHandle a_shape;
311 ShapeHandle b_shape;
312 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &a_shape));
313 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape));
314
315 // Determine output rows and columns.
316 bool adj_x;
317 bool adj_y;
318 TF_RETURN_IF_ERROR(c->GetAttr("adj_x", &adj_x));
319 TF_RETURN_IF_ERROR(c->GetAttr("adj_y", &adj_y));
320 DimensionHandle output_rows = c->Dim(a_shape, adj_x ? -1 : -2);
321 DimensionHandle output_cols = c->Dim(b_shape, adj_y ? -2 : -1);
322
323 // Inner dimensions should be compatible.
324 DimensionHandle inner_merged;
325 TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1),
326 c->Dim(b_shape, adj_y ? -1 : -2), &inner_merged));
327
328 // Batch dimensions should broadcast with each other.
329 ShapeHandle a_batch_shape;
330 ShapeHandle b_batch_shape;
331 ShapeHandle output_batch_shape;
332 TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_shape));
333 TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_shape));
334
335 TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
336 c, a_batch_shape, b_batch_shape, true, &output_batch_shape));
337
338 ShapeHandle output_shape;
339 TF_RETURN_IF_ERROR(c->Concatenate(
340 output_batch_shape, c->Matrix(output_rows, output_cols), &output_shape));
341
342 c->set_output(0, output_shape);
343 return OkStatus();
344 }
345
BatchMatMulShape(shape_inference::InferenceContext * c)346 Status BatchMatMulShape(shape_inference::InferenceContext* c) {
347 ShapeHandle a_shape;
348 ShapeHandle b_shape;
349 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &a_shape));
350 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape));
351
352 // Determine output rows and cols.
353 bool adj_x;
354 bool adj_y;
355 TF_RETURN_IF_ERROR(c->GetAttr("adj_x", &adj_x));
356 TF_RETURN_IF_ERROR(c->GetAttr("adj_y", &adj_y));
357 DimensionHandle output_rows = c->Dim(a_shape, adj_x ? -1 : -2);
358 DimensionHandle output_cols = c->Dim(b_shape, adj_y ? -2 : -1);
359
360 // Batch dims match between inputs.
361 ShapeHandle a_batch_dims;
362 ShapeHandle b_batch_dims;
363 ShapeHandle batch_dims;
364 TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims));
365 TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims));
366 TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims));
367
368 // Assert inner dims match.
369 DimensionHandle unused;
370 TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1),
371 c->Dim(b_shape, adj_y ? -1 : -2), &unused));
372
373 ShapeHandle out;
374 TF_RETURN_IF_ERROR(
375 c->Concatenate(batch_dims, c->Matrix(output_rows, output_cols), &out));
376 c->set_output(0, out);
377 return OkStatus();
378 }
379
380 // --------------------------------------------------------------------------
381
BiasAddShape(shape_inference::InferenceContext * c)382 Status BiasAddShape(shape_inference::InferenceContext* c) {
383 ShapeHandle input_shape;
384
385 // Fetch the data_format attribute, which may not exist.
386 string data_format;
387 Status s = c->GetAttr("data_format", &data_format);
388
389 if (s.ok() && data_format == "NCHW") {
390 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
391 } else {
392 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
393 }
394
395 ShapeHandle bias_shape;
396 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &bias_shape));
397 DimensionHandle bias_dim = c->Dim(bias_shape, 0);
398
399 // If rank unknown, return unknown shape.
400 if (!c->RankKnown(input_shape)) {
401 c->set_output(0, c->UnknownShape());
402 return OkStatus();
403 }
404
405 // Output has the same shape as the input, and matches the length of
406 // the bias in its bias dimension.
407 ShapeHandle output_shape;
408 if (s.ok() && data_format == "NCHW") {
409 // Merge the length of bias_shape into the third to last dimension
410 ShapeHandle first;
411 TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, 1, &first));
412
413 ShapeHandle last;
414 TF_RETURN_IF_ERROR(c->Subshape(input_shape, 2, &last));
415
416 DimensionHandle input_bias_dim = c->Dim(input_shape, 1);
417 DimensionHandle merged_bias_dim;
418 TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));
419 ShapeHandle merged_bias = c->Vector(merged_bias_dim);
420
421 ShapeHandle temp;
422 TF_RETURN_IF_ERROR(c->Concatenate(first, merged_bias, &temp));
423 TF_RETURN_IF_ERROR(c->Concatenate(temp, last, &output_shape));
424 } else {
425 ShapeHandle all_but_bias;
426 TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -1, &all_but_bias));
427
428 DimensionHandle input_bias_dim = c->Dim(input_shape, -1);
429 DimensionHandle merged_bias_dim;
430 TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));
431
432 ShapeHandle merged_bias = c->Vector(merged_bias_dim);
433 TF_RETURN_IF_ERROR(
434 c->Concatenate(all_but_bias, merged_bias, &output_shape));
435 }
436
437 c->set_output(0, output_shape);
438 return OkStatus();
439 }
440
BiasAddGradShape(shape_inference::InferenceContext * c)441 Status BiasAddGradShape(shape_inference::InferenceContext* c) {
442 ShapeHandle input_shape;
443 // Fetch the data_format attribute, which may not exist.
444 string data_format;
445 Status s = c->GetAttr("data_format", &data_format);
446
447 if (s.ok() && data_format == "NCHW") {
448 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
449 c->set_output(0, c->Vector(c->Dim(input_shape, 1)));
450 } else {
451 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
452 c->set_output(0, c->Vector(c->Dim(input_shape, -1)));
453 }
454
455 return OkStatus();
456 }
457
CheckFormatConstraintsOnShape(const TensorFormat tensor_format,const ShapeHandle shape_handle,const string & tensor_name,shape_inference::InferenceContext * c)458 Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format,
459 const ShapeHandle shape_handle,
460 const string& tensor_name,
461 shape_inference::InferenceContext* c) {
462 if (tensor_format == FORMAT_NCHW_VECT_C) {
463 // Check that the vect dim has size 4 or 32.
464 const int num_dims = c->Rank(shape_handle);
465 DimensionHandle vect_dim = c->Dim(
466 shape_handle, GetTensorInnerFeatureDimIndex(num_dims, tensor_format));
467 int64_t vect_dim_val = c->Value(vect_dim);
468 if (vect_dim_val != 4 && vect_dim_val != 32) {
469 return errors::InvalidArgument(
470 "VECT_C dimension must be 4 or 32, but is ", vect_dim_val);
471 }
472 }
473
474 return OkStatus();
475 }
476
DatasetIteratorShape(shape_inference::InferenceContext * c)477 Status DatasetIteratorShape(shape_inference::InferenceContext* c) {
478 shape_inference::ShapeHandle unused;
479 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
480 std::vector<PartialTensorShape> output_shapes;
481 TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
482 const int output_shapes_size = output_shapes.size();
483 if (output_shapes_size != c->num_outputs()) {
484 return errors::InvalidArgument(
485 "`output_shapes` must be the same length as `output_types` (",
486 output_shapes.size(), " vs. ", c->num_outputs());
487 }
488 for (size_t i = 0; i < output_shapes.size(); ++i) {
489 shape_inference::ShapeHandle output_shape_handle;
490 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
491 output_shapes[i], &output_shape_handle));
492 c->set_output(static_cast<int>(i), output_shape_handle);
493 }
494 return OkStatus();
495 }
496
MakeShapeFromFormat(TensorFormat format,DimensionOrConstant N,const std::vector<DimensionOrConstant> & spatial,DimensionOrConstant C,ShapeHandle * out,shape_inference::InferenceContext * context)497 Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
498 const std::vector<DimensionOrConstant>& spatial,
499 DimensionOrConstant C, ShapeHandle* out,
500 shape_inference::InferenceContext* context) {
501 const int num_dims = GetTensorDimsFromSpatialDims(spatial.size(), format);
502 std::vector<DimensionHandle> dims_actual(num_dims);
503 dims_actual[GetTensorBatchDimIndex(num_dims, format)] = context->MakeDim(N);
504 int outer_c_index = GetTensorFeatureDimIndex(num_dims, format);
505 dims_actual[outer_c_index] = context->MakeDim(C);
506 if (format == FORMAT_NCHW_VECT_C) {
507 dims_actual[GetTensorInnerFeatureDimIndex(num_dims, format)] =
508 context->MakeDim(4);
509 } else if (format == FORMAT_NHWC_VECT_W) {
510 dims_actual[GetTensorInnerWidthDimIndex(num_dims, format)] =
511 context->MakeDim(4);
512 }
513 for (int spatial_dim = 0, end = spatial.size(); spatial_dim < end;
514 spatial_dim++) {
515 dims_actual[GetTensorSpatialDimIndex(num_dims, format, spatial_dim)] =
516 context->MakeDim(spatial[spatial_dim]);
517 }
518 *out = context->MakeShape(dims_actual);
519 return OkStatus();
520 }
521
DimensionsFromShape(ShapeHandle shape,TensorFormat format,DimensionHandle * batch_dim,gtl::MutableArraySlice<DimensionHandle> spatial_dims,DimensionHandle * filter_dim,InferenceContext * context)522 Status DimensionsFromShape(ShapeHandle shape, TensorFormat format,
523 DimensionHandle* batch_dim,
524 gtl::MutableArraySlice<DimensionHandle> spatial_dims,
525 DimensionHandle* filter_dim,
526 InferenceContext* context) {
527 const int32_t rank =
528 GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
529 // Batch.
530 *batch_dim = context->Dim(shape, GetTensorBatchDimIndex(rank, format));
531 // Spatial.
532 for (int spatial_dim_index = 0, end = spatial_dims.size();
533 spatial_dim_index < end; ++spatial_dim_index) {
534 spatial_dims[spatial_dim_index] = context->Dim(
535 shape, GetTensorSpatialDimIndex(rank, format, spatial_dim_index));
536 }
537 // Channel.
538 *filter_dim = context->Dim(shape, GetTensorFeatureDimIndex(rank, format));
539 if (format == FORMAT_NCHW_VECT_C) {
540 TF_RETURN_IF_ERROR(context->Multiply(
541 *filter_dim,
542 context->Dim(shape, GetTensorInnerFeatureDimIndex(rank, format)),
543 filter_dim));
544 }
545 return OkStatus();
546 }
547
548 // vect_size must be provided if format is NCHW_VECT_C.
ShapeFromDimensions(DimensionHandle batch_dim,gtl::ArraySlice<DimensionHandle> spatial_dims,DimensionHandle filter_dim,TensorFormat format,absl::optional<DimensionHandle> vect_size,InferenceContext * context,ShapeHandle * shape)549 Status ShapeFromDimensions(DimensionHandle batch_dim,
550 gtl::ArraySlice<DimensionHandle> spatial_dims,
551 DimensionHandle filter_dim, TensorFormat format,
552 absl::optional<DimensionHandle> vect_size,
553 InferenceContext* context, ShapeHandle* shape) {
554 const int32_t rank =
555 GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
556 std::vector<DimensionHandle> out_dims(rank);
557
558 // Batch.
559 out_dims[tensorflow::GetTensorBatchDimIndex(rank, format)] = batch_dim;
560 // Spatial.
561 for (int spatial_dim_index = 0, end = spatial_dims.size();
562 spatial_dim_index < end; ++spatial_dim_index) {
563 out_dims[tensorflow::GetTensorSpatialDimIndex(
564 rank, format, spatial_dim_index)] = spatial_dims[spatial_dim_index];
565 }
566 // Channel.
567 if (format == tensorflow::FORMAT_NCHW_VECT_C) {
568 // When format is NCHW_VECT_C, factor the feature map count into the outer
569 // feature count and the inner feature count (4 or 32).
570 CHECK(vect_size.has_value()); // Crash ok.
571 TF_RETURN_IF_ERROR(context->Divide(
572 filter_dim, *vect_size, /*evenly_divisible=*/true,
573 &out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)]));
574 out_dims[GetTensorInnerFeatureDimIndex(rank, format)] = *vect_size;
575 } else {
576 out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)] = filter_dim;
577 }
578
579 *shape = context->MakeShape(out_dims);
580 return OkStatus();
581 }
582
namespace {

// Shared implementation of Conv2DShape and Conv2DShapeWithExplicitPadding.
//
// Reads these attrs:
// - "data_format" (default "NHWC" when it cannot be read);
// - "filter_format" (default "HWIO" when it cannot be read).
// Validates, against those formats:
// - input 0 (the convolution input) and input 1 (the filter);
// - the "dilations", "strides" and "padding" attrs.
// Sets output 0 from the batch size, the windowed output rows/cols and the
// filter's output depth, laid out according to data_format. For NCHW_VECT_C
// the depth is split using the input's inner feature (vect) dimension.
//
// supports_explicit_padding:
// - true: padding may be EXPLICIT, with the amounts read from the optional
//   "explicit_paddings" attr;
// - false: EXPLICIT padding is rejected, but a non-empty optional
//   "padding_list" attr switches the computation to explicit padding.
Status Conv2DShapeImpl(shape_inference::InferenceContext* c,
                       bool supports_explicit_padding) {
  // Both format attrs are optional; fall back to defaults if they can't be
  // read.
  string data_format_str, filter_format_str;
  if (!c->GetAttr("data_format", &data_format_str).ok()) {
    data_format_str = "NHWC";
  }
  if (!c->GetAttr("filter_format", &filter_format_str).ok()) {
    filter_format_str = "HWIO";
  }

  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }
  FilterTensorFormat filter_format;
  if (!FilterFormatFromString(filter_format_str, &filter_format)) {
    return errors::InvalidArgument("Invalid filter format string: ",
                                   filter_format_str);
  }

  constexpr int num_spatial_dims = 2;
  const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
  ShapeHandle conv_input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &conv_input_shape));
  TF_RETURN_IF_ERROR(CheckFormatConstraintsOnShape(
      data_format, conv_input_shape, "conv_input", c));

  // The filter rank should match the input (4 for NCHW, 5 for NCHW_VECT_C).
  ShapeHandle filter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape));
  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, filter_shape, "filter", c));

  // Like strides below, dilations always has 4 entries, even when the input
  // has rank 5.
  std::vector<int32> dilations;
  TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations));

  if (dilations.size() != 4) {
    return errors::InvalidArgument(
        "Conv2D requires the dilation attribute to contain 4 values, but got: ",
        dilations.size());
  }

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));

  // strides.size() should be 4 (NCHW) even if the input is 5 (NCHW_VECT_C).
  if (strides.size() != 4) {
    return errors::InvalidArgument("Conv2D on data format ", data_format_str,
                                   " requires the stride attribute to contain"
                                   " 4 values, but got: ",
                                   strides.size());
  }

  const int32_t stride_rows = GetTensorDim(strides, data_format, 'H');
  const int32_t stride_cols = GetTensorDim(strides, data_format, 'W');
  const int32_t dilation_rows = GetTensorDim(dilations, data_format, 'H');
  const int32_t dilation_cols = GetTensorDim(dilations, data_format, 'W');

  // input_spatial_dims[0] is paired with the filter rows and [1] with the
  // filter cols below. For NCHW_VECT_C, input_depth_dim is already the
  // combined outer * inner feature count (see DimensionsFromShape).
  DimensionHandle batch_size_dim;
  DimensionHandle input_depth_dim;
  gtl::InlinedVector<DimensionHandle, 2> input_spatial_dims(2);
  TF_RETURN_IF_ERROR(DimensionsFromShape(
      conv_input_shape, data_format, &batch_size_dim,
      absl::MakeSpan(input_spatial_dims), &input_depth_dim, c));

  DimensionHandle output_depth_dim = c->Dim(
      filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
  DimensionHandle filter_rows_dim = c->Dim(
      filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'H'));
  DimensionHandle filter_cols_dim = c->Dim(
      filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'W'));
  // For OIHW_VECT_I the filter's input depth is split into outer * inner
  // axes; recombine them so it is comparable with input_depth_dim.
  DimensionHandle filter_input_depth_dim;
  if (filter_format == FORMAT_OIHW_VECT_I) {
    TF_RETURN_IF_ERROR(c->Multiply(
        c->Dim(filter_shape,
               GetFilterDimIndex<num_spatial_dims>(filter_format, 'I')),
        c->Dim(filter_shape,
               GetFilterTensorInnerInputChannelsDimIndex(rank, filter_format)),
        &filter_input_depth_dim));
  } else {
    filter_input_depth_dim = c->Dim(
        filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'I'));
  }

  // Check that the input tensor and the filter tensor agree on the channel
  // count. The input depth may be a multiple of the filter's input depth
  // (grouped convolution); the number of groups must then divide the output
  // depth.
  if (c->ValueKnown(input_depth_dim) && c->ValueKnown(filter_input_depth_dim)) {
    int64_t input_depth_value = c->Value(input_depth_dim),
            filter_input_depth_value = c->Value(filter_input_depth_dim);
    if (filter_input_depth_value == 0)
      return errors::InvalidArgument("Depth of filter must not be 0");
    if (input_depth_value % filter_input_depth_value != 0)
      return errors::InvalidArgument(
          "Depth of input (", input_depth_value,
          ") is not a multiple of input depth of filter (",
          filter_input_depth_value, ")");
    if (input_depth_value != filter_input_depth_value) {
      int64_t num_groups = input_depth_value / filter_input_depth_value;
      if (c->ValueKnown(output_depth_dim)) {
        int64_t output_depth_value = c->Value(output_depth_dim);
        // num_groups is 0 when the input depth is 0 (and the filter's is not);
        // guard the modulo below against it.
        if (num_groups == 0)
          return errors::InvalidArgument("Number of groups must not be 0");
        if (output_depth_value % num_groups != 0)
          return errors::InvalidArgument(
              "Depth of output (", output_depth_value,
              ") is not a multiple of the number of groups (", num_groups, ")");
      }
    }
  }

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
  std::vector<int64_t> explicit_paddings;
  if (supports_explicit_padding) {
    Status s = c->GetAttr("explicit_paddings", &explicit_paddings);
    // Use the default value, which is an empty list, if the attribute is not
    // found. Otherwise return the error to the caller.
    if (!s.ok() && !errors::IsNotFound(s)) {
      return s;
    }
    TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
                                         /*num_dims=*/4, data_format));
  } else {
    if (padding == Padding::EXPLICIT) {
      return errors::InvalidArgument(
          "Expected non-explicit padding but got explicit padding");
    }
    std::vector<int64_t> p_list;
    // `padding_list` attribute is used by Fused int8 convolutions to support
    // explicit paddings.
    Status s_p_list = c->GetAttr("padding_list", &p_list);
    if (!s_p_list.ok() && !errors::IsNotFound(s_p_list)) {
      return s_p_list;
    }
    // A non-empty padding_list overrides the "padding" attr.
    if (s_p_list.ok() && !p_list.empty()) {
      padding = Padding::EXPLICIT;
      explicit_paddings = p_list;
      TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
                                           /*num_dims=*/4, data_format));
    }
  }

  // -1 is a placeholder: GetWindowedOutputSizeFromDimsV2 only reads the
  // padding amounts for EXPLICIT padding.
  DimensionHandle output_rows, output_cols;
  int64_t pad_rows_before = -1, pad_rows_after = -1;
  int64_t pad_cols_before = -1, pad_cols_after = -1;
  if (padding == Padding::EXPLICIT) {
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
                             &pad_rows_before, &pad_rows_after);
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
                             &pad_cols_before, &pad_cols_after);
  }
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, input_spatial_dims[0], filter_rows_dim, dilation_rows, stride_rows,
      padding, pad_rows_before, pad_rows_after, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, input_spatial_dims[1], filter_cols_dim, dilation_cols, stride_cols,
      padding, pad_cols_before, pad_cols_after, &output_cols));

  // ShapeFromDimensions requires vect_size for NCHW_VECT_C; reuse the
  // input's inner feature dimension.
  absl::optional<DimensionHandle> vect_size;
  if (data_format == FORMAT_NCHW_VECT_C) {
    vect_size.emplace(c->Dim(conv_input_shape,
                             GetTensorInnerFeatureDimIndex(rank, data_format)));
  }
  ShapeHandle output_shape;
  TF_RETURN_IF_ERROR(ShapeFromDimensions(
      batch_size_dim, {output_rows, output_cols}, output_depth_dim, data_format,
      vect_size, c, &output_shape));
  c->set_output(0, output_shape);
  return OkStatus();
}

}  // namespace
758
759 // Shape function for Conv2D-like operations that support explicit padding.
Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext * c)760 Status Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext* c) {
761 return Conv2DShapeImpl(c, true);
762 }
763
764 // Shape function for Conv2D-like operations that do not support explicit
765 // padding.
Conv2DShape(shape_inference::InferenceContext * c)766 Status Conv2DShape(shape_inference::InferenceContext* c) {
767 return Conv2DShapeImpl(c, false);
768 }
769
770 // TODO(mjanusz): Unify all conv/pooling shape functions.
Conv3DShape(shape_inference::InferenceContext * c)771 Status Conv3DShape(shape_inference::InferenceContext* c) {
772 ShapeHandle input_shape;
773 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
774 ShapeHandle filter_shape;
775 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &filter_shape));
776
777 string data_format;
778 Status s = c->GetAttr("data_format", &data_format);
779
780 std::vector<int32> dilations;
781 TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations));
782
783 if (dilations.size() != 5) {
784 return errors::InvalidArgument(
785 "Conv3D requires the dilation attribute to contain 5 values, but got: ",
786 dilations.size());
787 }
788
789 std::vector<int32> strides;
790 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
791 if (strides.size() != 5) {
792 return errors::InvalidArgument(
793 "Conv3D requires the stride attribute to contain 5 values, but got: ",
794 strides.size());
795 }
796
797 int32_t stride_planes, stride_rows, stride_cols;
798 int32_t dilation_planes, dilation_rows, dilation_cols;
799 if (s.ok() && data_format == "NCDHW") {
800 // Convert input_shape to NDHWC.
801 auto dim = [&](char dimension) {
802 return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
803 };
804 input_shape =
805 c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
806 stride_planes = strides[2];
807 stride_rows = strides[3];
808 stride_cols = strides[4];
809 dilation_planes = dilations[2];
810 dilation_cols = dilations[3];
811 dilation_rows = dilations[4];
812 } else {
813 stride_planes = strides[1];
814 stride_rows = strides[2];
815 stride_cols = strides[3];
816 dilation_planes = dilations[1];
817 dilation_cols = dilations[2];
818 dilation_rows = dilations[3];
819 }
820
821 DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
822 DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
823 DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
824 DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
825 DimensionHandle input_depth_dim = c->Dim(input_shape, 4);
826
827 DimensionHandle filter_planes_dim = c->Dim(filter_shape, 0);
828 DimensionHandle filter_rows_dim = c->Dim(filter_shape, 1);
829 DimensionHandle filter_cols_dim = c->Dim(filter_shape, 2);
830 DimensionHandle filter_input_depth_dim = c->Dim(filter_shape, 3);
831 DimensionHandle output_depth_dim = c->Dim(filter_shape, 4);
832
833 // Check that the input tensor and the filter tensor agree on the channel
834 // count.
835 if (c->ValueKnown(input_depth_dim) && c->ValueKnown(filter_input_depth_dim)) {
836 int64_t input_depth_value = c->Value(input_depth_dim),
837 filter_input_depth_value = c->Value(filter_input_depth_dim);
838 if (filter_input_depth_value == 0)
839 return errors::InvalidArgument("Depth of filter must not be 0");
840 if (input_depth_value % filter_input_depth_value != 0)
841 return errors::InvalidArgument(
842 "Depth of input (", input_depth_value,
843 ") is not a multiple of input depth of filter (",
844 filter_input_depth_value, ")");
845 if (input_depth_value != filter_input_depth_value) {
846 int64_t num_groups = input_depth_value / filter_input_depth_value;
847 if (c->ValueKnown(output_depth_dim)) {
848 int64_t output_depth_value = c->Value(output_depth_dim);
849 if (num_groups == 0)
850 return errors::InvalidArgument("Number of groups must not be 0");
851 if (output_depth_value % num_groups != 0)
852 return errors::InvalidArgument(
853 "Depth of output (", output_depth_value,
854 ") is not a multiple of the number of groups (", num_groups, ")");
855 }
856 }
857 }
858
859 Padding padding;
860 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
861 DimensionHandle output_planes, output_rows, output_cols;
862
863 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
864 c, in_planes_dim, filter_planes_dim, dilation_planes, stride_planes,
865 padding, -1, -1, &output_planes));
866 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
867 c, in_rows_dim, filter_rows_dim, dilation_rows, stride_rows, padding, -1,
868 -1, &output_rows));
869 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
870 c, in_cols_dim, filter_cols_dim, dilation_cols, stride_cols, padding, -1,
871 -1, &output_cols));
872
873 ShapeHandle output_shape;
874 if (data_format == "NCDHW") {
875 output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
876 output_planes, output_rows, output_cols});
877 } else {
878 output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
879 output_cols, output_depth_dim});
880 }
881 c->set_output(0, output_shape);
882 return OkStatus();
883 }
884
Conv2DBackpropInputShape(shape_inference::InferenceContext * c)885 Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c) {
886 string data_format_str;
887 if (!c->GetAttr("data_format", &data_format_str).ok()) {
888 data_format_str = "NHWC";
889 }
890 TensorFormat data_format;
891 if (!FormatFromString(data_format_str, &data_format)) {
892 return errors::InvalidArgument("Invalid data format string: ",
893 data_format_str);
894 }
895
896 // For the rest of this function, output_grad_* describes out_backprop and
897 // input_grad_* describes in_backprop.
898 ShapeHandle output_grad_shape = c->input(2);
899 TF_RETURN_IF_ERROR(c->WithRank(output_grad_shape, 4, &output_grad_shape));
900 ShapeHandle filter_shape = c->input(1);
901 TF_RETURN_IF_ERROR(c->WithRank(filter_shape, 4, &filter_shape));
902
903 DimensionHandle batch_size_dim;
904 DimensionHandle output_grad_depth_dim;
905 gtl::InlinedVector<DimensionHandle, 2> output_grad_spatial_dims(2);
906 TF_RETURN_IF_ERROR(DimensionsFromShape(
907 output_grad_shape, data_format, &batch_size_dim,
908 absl::MakeSpan(output_grad_spatial_dims), &output_grad_depth_dim, c));
909 DimensionHandle unused;
910 TF_RETURN_IF_ERROR(
911 c->Merge(output_grad_depth_dim, c->Dim(filter_shape, 3), &unused));
912
913 ShapeHandle specified_input_grad_shape;
914 TF_RETURN_IF_ERROR(
915 c->MakeShapeFromShapeTensor(0, &specified_input_grad_shape));
916 if (c->Rank(specified_input_grad_shape) == InferenceContext::kUnknownRank) {
917 TF_RETURN_IF_ERROR(c->WithRank(specified_input_grad_shape, 4,
918 &specified_input_grad_shape));
919 }
920
921 // input_grad_depth_dim doesn't equal c->Dim(filter_shape,2) when the number
922 // of groups is larger than 1. If input_sizes is a 4D shape, we collect
923 // input_grad_depth_dim from input_sizes; otherwise we compute it as
924 // c->Dim(filter_shape,2).
925 DimensionHandle input_grad_depth_dim;
926 gtl::InlinedVector<DimensionHandle, 2> specified_input_grad_spatial_dims(2);
927 int specified_input_grad_rank = c->Rank(specified_input_grad_shape);
928 if (specified_input_grad_rank == 4) {
929 DimensionHandle specified_batch_size_dim;
930 TF_RETURN_IF_ERROR(DimensionsFromShape(
931 specified_input_grad_shape, data_format, &specified_batch_size_dim,
932 absl::MakeSpan(specified_input_grad_spatial_dims),
933 &input_grad_depth_dim, c));
934 TF_RETURN_IF_ERROR(
935 c->Merge(specified_batch_size_dim, batch_size_dim, &unused));
936 } else if (specified_input_grad_rank == 2) {
937 specified_input_grad_spatial_dims[0] =
938 c->Dim(specified_input_grad_shape, 0);
939 specified_input_grad_spatial_dims[1] =
940 c->Dim(specified_input_grad_shape, 1);
941 input_grad_depth_dim = c->Dim(filter_shape, 2);
942 } else {
943 return errors::InvalidArgument(
944 "Conv2DBackpropInput requires input_sizes to contain 4 values or 2 "
945 "values, but got: ",
946 specified_input_grad_rank);
947 }
948
949 ShapeHandle input_grad_shape;
950 TF_RETURN_IF_ERROR(ShapeFromDimensions(
951 batch_size_dim, specified_input_grad_spatial_dims, input_grad_depth_dim,
952 data_format, /*vect_size=*/absl::nullopt, c, &input_grad_shape));
953 c->set_output(0, input_grad_shape);
954 return OkStatus();
955 }
956
Conv2DBackpropFilterWithBiasShape(shape_inference::InferenceContext * c)957 Status Conv2DBackpropFilterWithBiasShape(shape_inference::InferenceContext* c) {
958 ShapeHandle input_shape;
959 // Fetch the data_format attribute, which may not exist.
960 string data_format;
961 Status s = c->GetAttr("data_format", &data_format);
962
963 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
964 if (s.ok() && data_format == "NCHW") {
965 c->set_output(1, c->Vector(c->Dim(input_shape, -3)));
966 } else {
967 c->set_output(1, c->Vector(c->Dim(input_shape, -1)));
968 }
969 ShapeHandle sh;
970 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &sh));
971 TF_RETURN_IF_ERROR(c->WithRank(sh, 4, &sh));
972 c->set_output(0, sh);
973 return OkStatus();
974 }
975
976 namespace {
977
DepthwiseConv2DNativeShapeImpl(shape_inference::InferenceContext * c,bool supports_explicit_padding)978 Status DepthwiseConv2DNativeShapeImpl(shape_inference::InferenceContext* c,
979 bool supports_explicit_padding) {
980 ShapeHandle input_shape;
981 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
982 ShapeHandle filter_shape;
983 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape));
984
985 std::vector<int32> strides;
986 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
987
988 if (strides.size() != 4) {
989 return errors::InvalidArgument(
990 "DepthwiseConv2D requires the stride attribute to contain 4 values, "
991 "but got: ",
992 strides.size());
993 }
994
995 std::vector<int32> dilations;
996 if (!c->GetAttr("dilations", &dilations).ok()) {
997 dilations.resize(4, 1);
998 }
999
1000 if (dilations.size() != 4) {
1001 return errors::InvalidArgument(
1002 "DepthwiseConv2D requires the dilations attribute to contain 4 values, "
1003 "but got: ",
1004 dilations.size());
1005 }
1006
1007 string data_format_str;
1008 Status s = c->GetAttr("data_format", &data_format_str);
1009 TensorFormat data_format;
1010 if (!s.ok() || !FormatFromString(data_format_str, &data_format)) {
1011 data_format = FORMAT_NHWC;
1012 }
1013 int32_t stride_rows;
1014 int32_t stride_cols;
1015 int32_t dilation_rows;
1016 int32_t dilation_cols;
1017 if (data_format == FORMAT_NCHW) {
1018 // Canonicalize input shape to NHWC so the shape inference code below can
1019 // process it.
1020 input_shape =
1021 c->MakeShape({{c->Dim(input_shape, 0), c->Dim(input_shape, 2),
1022 c->Dim(input_shape, 3), c->Dim(input_shape, 1)}});
1023 stride_rows = strides[2];
1024 stride_cols = strides[3];
1025 dilation_rows = dilations[2];
1026 dilation_cols = dilations[3];
1027 } else {
1028 stride_rows = strides[1];
1029 stride_cols = strides[2];
1030 dilation_rows = dilations[1];
1031 dilation_cols = dilations[2];
1032 }
1033
1034 DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
1035 DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
1036 DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
1037
1038 DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
1039 DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
1040 DimensionHandle input_depth = c->Dim(filter_shape, 2);
1041 DimensionHandle depth_multiplier = c->Dim(filter_shape, 3);
1042
1043 // Check that the input depths are compatible.
1044 TF_RETURN_IF_ERROR(
1045 c->Merge(c->Dim(input_shape, 3), input_depth, &input_depth));
1046
1047 DimensionHandle output_depth;
1048 TF_RETURN_IF_ERROR(c->Multiply(input_depth, depth_multiplier, &output_depth));
1049
1050 Padding padding;
1051 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
1052
1053 std::vector<int64_t> explicit_paddings;
1054 if (supports_explicit_padding) {
1055 Status status = c->GetAttr("explicit_paddings", &explicit_paddings);
1056 // Use the default value, which is an empty list, if the attribute is not
1057 // found. Otherwise return the error to the caller.
1058 if (!status.ok() && !errors::IsNotFound(status)) {
1059 return status;
1060 }
1061 TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
1062 /*num_dims=*/4, data_format));
1063 } else {
1064 DCHECK(padding != Padding::EXPLICIT);
1065 }
1066
1067 // TODO(mrry,shlens): Raise an error if the stride would cause
1068 // information in the input to be ignored. This will require a change
1069 // in the kernel implementation.
1070 DimensionHandle output_rows, output_cols;
1071 int64_t pad_rows_before = -1, pad_rows_after = -1;
1072 int64_t pad_cols_before = -1, pad_cols_after = -1;
1073 if (padding == Padding::EXPLICIT) {
1074 GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
1075 &pad_rows_before, &pad_rows_after);
1076 GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
1077 &pad_cols_before, &pad_cols_after);
1078 }
1079 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
1080 c, in_rows_dim, filter_rows_dim, dilation_rows, stride_rows, padding,
1081 pad_rows_before, pad_rows_after, &output_rows));
1082 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
1083 c, in_cols_dim, filter_cols_dim, dilation_cols, stride_cols, padding,
1084 pad_cols_before, pad_cols_after, &output_cols));
1085
1086 ShapeHandle output_shape;
1087 if (data_format == FORMAT_NCHW) {
1088 output_shape =
1089 c->MakeShape({batch_size_dim, output_depth, output_rows, output_cols});
1090 } else {
1091 output_shape =
1092 c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth});
1093 }
1094 c->set_output(0, output_shape);
1095 return OkStatus();
1096 }
1097
1098 }; // namespace
1099
DepthwiseConv2DNativeShape(shape_inference::InferenceContext * c)1100 Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
1101 return DepthwiseConv2DNativeShapeImpl(c, false);
1102 }
1103
DepthwiseConv2DNativeShapeWithExplicitPadding(shape_inference::InferenceContext * c)1104 Status DepthwiseConv2DNativeShapeWithExplicitPadding(
1105 shape_inference::InferenceContext* c) {
1106 return DepthwiseConv2DNativeShapeImpl(c, true);
1107 }
1108
AvgPoolShape(shape_inference::InferenceContext * c)1109 Status AvgPoolShape(shape_inference::InferenceContext* c) {
1110 string data_format_str;
1111 TensorFormat data_format;
1112 Status s = c->GetAttr("data_format", &data_format_str);
1113 if (s.ok()) {
1114 FormatFromString(data_format_str, &data_format);
1115 } else {
1116 data_format = FORMAT_NHWC;
1117 }
1118
1119 const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
1120 ShapeHandle input_shape;
1121 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
1122
1123 TF_RETURN_IF_ERROR(
1124 CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));
1125
1126 std::vector<int32> strides;
1127 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
1128 if (strides.size() != 4) {
1129 return errors::InvalidArgument(
1130 "AvgPool requires the stride attribute to contain 4 values, but got: ",
1131 strides.size());
1132 }
1133
1134 std::vector<int32> kernel_sizes;
1135 TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
1136 if (kernel_sizes.size() != 4) {
1137 return errors::InvalidArgument(
1138 "AvgPool requires the ksize attribute to contain 4 values, but got: ",
1139 kernel_sizes.size());
1140 }
1141
1142 int32_t stride_rows = GetTensorDim(strides, data_format, 'H');
1143 int32_t stride_cols = GetTensorDim(strides, data_format, 'W');
1144 int32_t kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
1145 int32_t kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');
1146
1147 constexpr int num_spatial_dims = 2;
1148 DimensionHandle batch_size_dim = c->Dim(
1149 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
1150 DimensionHandle in_rows_dim = c->Dim(
1151 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
1152 DimensionHandle in_cols_dim = c->Dim(
1153 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
1154 DimensionHandle depth_dim = c->Dim(
1155 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
1156
1157 Padding padding;
1158 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
1159
1160 // TODO(mrry,shlens): Raise an error if the stride would cause
1161 // information in the input to be ignored. This will require a change
1162 // in the kernel implementation.
1163
1164 DimensionHandle output_rows, output_cols;
1165 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1166 c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
1167 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1168 c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
1169
1170 ShapeHandle output_shape;
1171 TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
1172 {output_rows, output_cols}, depth_dim,
1173 &output_shape, c));
1174 c->set_output(0, output_shape);
1175 return OkStatus();
1176 }
1177
AvgPoolGradShape(shape_inference::InferenceContext * c)1178 Status AvgPoolGradShape(shape_inference::InferenceContext* c) {
1179 ShapeHandle s;
1180 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
1181 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
1182 c->set_output(0, s);
1183 return OkStatus();
1184 }
1185
FusedBatchNormShape(shape_inference::InferenceContext * c)1186 Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
1187 string data_format_str;
1188 TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
1189 TensorFormat data_format;
1190 if (!FormatFromString(data_format_str, &data_format)) {
1191 return errors::InvalidArgument("Invalid data format string: ",
1192 data_format_str);
1193 }
1194 const int rank =
1195 (data_format_str == "NDHWC" || data_format_str == "NCDHW") ? 5 : 4;
1196 ShapeHandle x;
1197 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &x));
1198
1199 bool is_training;
1200 TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
1201 float exponential_avg_factor;
1202 if (!c->GetAttr("exponential_avg_factor", &exponential_avg_factor).ok()) {
1203 exponential_avg_factor = 1.0f; // default value
1204 }
1205 int number_inputs = (is_training && exponential_avg_factor == 1.0f) ? 3 : 5;
1206
1207 int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
1208 DimensionHandle channel_dim = c->Dim(x, channel_dim_index);
1209
1210 // covers scale, offset, and if is_training is false, mean, variance
1211 for (int i = 1; i < number_inputs; ++i) {
1212 ShapeHandle vec;
1213 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
1214 TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
1215 }
1216
1217 ShapeHandle y;
1218 TF_RETURN_IF_ERROR(c->ReplaceDim(x, channel_dim_index, channel_dim, &y));
1219 c->set_output(0, y);
1220 ShapeHandle vector_shape = c->Vector(channel_dim);
1221 c->set_output(1, vector_shape);
1222 c->set_output(2, vector_shape);
1223 c->set_output(3, vector_shape);
1224 c->set_output(4, vector_shape);
1225 return OkStatus();
1226 }
1227
FusedBatchNormV3Shape(shape_inference::InferenceContext * c)1228 Status FusedBatchNormV3Shape(shape_inference::InferenceContext* c) {
1229 TF_RETURN_IF_ERROR(FusedBatchNormShape(c));
1230 c->set_output(5, c->UnknownShape());
1231 return OkStatus();
1232 }
1233
FusedBatchNormExShape(shape_inference::InferenceContext * c)1234 Status FusedBatchNormExShape(shape_inference::InferenceContext* c) {
1235 TF_RETURN_IF_ERROR(FusedBatchNormV3Shape(c));
1236
1237 string data_format_str;
1238 TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
1239 TensorFormat data_format;
1240 if (!FormatFromString(data_format_str, &data_format)) {
1241 return errors::InvalidArgument("Invalid data format string: ",
1242 data_format_str);
1243 }
1244 ShapeHandle x;
1245 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));
1246
1247 int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
1248 DimensionHandle channel_dim = c->Dim(x, channel_dim_index);
1249
1250 // This is a cuDNN implementation constraint.
1251 if (c->ValueKnown(channel_dim) && c->Value(channel_dim) % 4 != 0) {
1252 return errors::InvalidArgument(
1253 "_FusedBatchNormEx channel dimension must be divisible by 4.");
1254 }
1255
1256 return OkStatus();
1257 }
1258
FusedBatchNormGradShape(shape_inference::InferenceContext * c)1259 Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
1260 string data_format_str;
1261 TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
1262 TensorFormat data_format;
1263 if (!FormatFromString(data_format_str, &data_format)) {
1264 return errors::InvalidArgument("Invalid data format string: ",
1265 data_format_str);
1266 }
1267 const int rank =
1268 (data_format_str == "NDHWC" || data_format_str == "NCDHW") ? 5 : 4;
1269 ShapeHandle y_backprop;
1270 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &y_backprop));
1271 ShapeHandle x;
1272 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &x));
1273
1274 bool is_training;
1275 TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
1276
1277 int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
1278 DimensionHandle channel_dim = c->Dim(y_backprop, channel_dim_index);
1279 TF_RETURN_IF_ERROR(
1280 c->Merge(channel_dim, c->Dim(x, channel_dim_index), &channel_dim));
1281
1282 // covers scale, mean (reserve_space_1), variance (reserve_space_2)
1283 for (int i = 2; i < 5; ++i) {
1284 ShapeHandle vec;
1285 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
1286 TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
1287 }
1288
1289 ShapeHandle x_backprop;
1290 TF_RETURN_IF_ERROR(
1291 c->ReplaceDim(y_backprop, channel_dim_index, channel_dim, &x_backprop));
1292 c->set_output(0, x_backprop);
1293 c->set_output(1, c->Vector(channel_dim));
1294 c->set_output(2, c->Vector(channel_dim));
1295 c->set_output(3, c->Vector(0));
1296 c->set_output(4, c->Vector(0));
1297 return OkStatus();
1298 }
1299
FusedBatchNormGradExShape(shape_inference::InferenceContext * c)1300 Status FusedBatchNormGradExShape(shape_inference::InferenceContext* c) {
1301 TF_RETURN_IF_ERROR(FusedBatchNormGradShape(c));
1302
1303 int num_side_inputs;
1304 TF_RETURN_IF_ERROR(c->GetAttr("num_side_inputs", &num_side_inputs));
1305 if (num_side_inputs == 0) {
1306 return OkStatus();
1307 }
1308
1309 string data_format_str;
1310 TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
1311 TensorFormat data_format;
1312 if (!FormatFromString(data_format_str, &data_format)) {
1313 return errors::InvalidArgument("Invalid data format string: ",
1314 data_format_str);
1315 }
1316 const int rank =
1317 (data_format_str == "NDHWC" || data_format_str == "NCDHW") ? 5 : 4;
1318 ShapeHandle y_backprop;
1319 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &y_backprop));
1320 ShapeHandle x;
1321 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &x));
1322
1323 int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
1324 DimensionHandle channel_dim = c->Dim(y_backprop, channel_dim_index);
1325 TF_RETURN_IF_ERROR(
1326 c->Merge(channel_dim, c->Dim(x, channel_dim_index), &channel_dim));
1327
1328 ShapeHandle side_input_backprop;
1329 TF_RETURN_IF_ERROR(c->ReplaceDim(y_backprop, channel_dim_index, channel_dim,
1330 &side_input_backprop));
1331
1332 c->set_output(5, side_input_backprop);
1333 return OkStatus();
1334 }
1335
ReadDiagIndex(InferenceContext * c,const Tensor * diag_index_tensor,int32 * lower_diag_index,int32 * upper_diag_index)1336 Status ReadDiagIndex(InferenceContext* c, const Tensor* diag_index_tensor,
1337 int32* lower_diag_index, int32* upper_diag_index) {
1338 // This function assumes that the shape of diag_index_tensor is fully defined.
1339 if (diag_index_tensor->dims() == 0) {
1340 *lower_diag_index = diag_index_tensor->scalar<int32>()();
1341 *upper_diag_index = *lower_diag_index;
1342 } else {
1343 int32_t num_elements = diag_index_tensor->dim_size(0);
1344 if (num_elements == 1) {
1345 *lower_diag_index = diag_index_tensor->vec<int32>()(0);
1346 *upper_diag_index = *lower_diag_index;
1347 } else if (num_elements == 2) {
1348 *lower_diag_index = diag_index_tensor->vec<int32>()(0);
1349 *upper_diag_index = diag_index_tensor->vec<int32>()(1);
1350 } else {
1351 return errors::InvalidArgument(
1352 "diag_index must be a vector with one or two elements. It has ",
1353 num_elements, " elements.");
1354 }
1355 }
1356 return OkStatus();
1357 }
1358
MatrixDiagPartV2Shape(shape_inference::InferenceContext * c)1359 Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c) {
1360 ShapeHandle input_shape, diag_index_shape, unused_shape;
1361 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
1362 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape));
1363 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape));
1364
1365 const Tensor* diag_index_tensor = c->input_tensor(1);
1366 if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) ||
1367 diag_index_tensor == nullptr) {
1368 c->set_output(0, c->UnknownShape());
1369 return OkStatus();
1370 }
1371 int32_t lower_diag_index = 0;
1372 int32_t upper_diag_index = 0;
1373 TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
1374 &upper_diag_index));
1375 if (lower_diag_index > upper_diag_index) {
1376 return errors::InvalidArgument(
1377 "lower_diag_index is greater than upper_diag_index");
1378 }
1379
1380 // Validates lower_diag_index and upper_diag_index.
1381 const int32_t input_rank = c->Rank(input_shape);
1382 const int32_t num_rows = c->Value(c->Dim(input_shape, input_rank - 2));
1383 const int32_t num_cols = c->Value(c->Dim(input_shape, input_rank - 1));
1384 int32_t max_diag_len = InferenceContext::kUnknownDim;
1385 if (num_rows != InferenceContext::kUnknownDim &&
1386 num_cols != InferenceContext::kUnknownDim) {
1387 if (lower_diag_index != 0 && // For when num_rows or num_cols == 0.
1388 (-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) {
1389 return errors::InvalidArgument("lower_diag_index is out of bound.");
1390 }
1391 if (upper_diag_index != 0 && // For when num_rows or num_cols == 0.
1392 (-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) {
1393 return errors::InvalidArgument("upper_diag_index is out of bound.");
1394 }
1395 max_diag_len = std::min(num_rows + std::min(upper_diag_index, 0),
1396 num_cols - std::max(lower_diag_index, 0));
1397 }
1398
1399 std::vector<DimensionHandle> dims;
1400 dims.reserve(input_rank - 2);
1401 for (int i = 0; i < input_rank - 2; ++i) {
1402 dims.push_back(c->Dim(input_shape, i));
1403 }
1404 if (lower_diag_index < upper_diag_index) {
1405 dims.push_back(c->MakeDim(upper_diag_index - lower_diag_index + 1));
1406 }
1407 dims.push_back(c->MakeDim(max_diag_len));
1408 c->set_output(0, c->MakeShape(dims));
1409 return OkStatus();
1410 }
1411
MatrixDiagV2Shape(shape_inference::InferenceContext * c)1412 Status MatrixDiagV2Shape(shape_inference::InferenceContext* c) {
1413 // Checks input ranks.
1414 ShapeHandle input_shape, diag_index_shape, unused_shape;
1415 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input_shape));
1416 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape));
1417 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape));
1418 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
1419 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
1420
1421 // Reads the diagonal indices.
1422 const Tensor* diag_index_tensor = c->input_tensor(1);
1423 if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) ||
1424 diag_index_tensor == nullptr) {
1425 c->set_output(0, c->UnknownShape());
1426 return OkStatus();
1427 }
1428 int32_t lower_diag_index = 0;
1429 int32_t upper_diag_index = 0;
1430 TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
1431 &upper_diag_index));
1432 if (lower_diag_index > upper_diag_index) {
1433 return errors::InvalidArgument(
1434 "lower_diag_index is greater than upper_diag_index");
1435 }
1436
1437 // Checks if the number of diagonals provided matches what we imply from
1438 // lower_diag_index and upper_diag_index.
1439 const int32_t input_rank = c->Rank(input_shape);
1440 if (lower_diag_index < upper_diag_index) {
1441 const int32_t num_diags = c->Value(c->Dim(input_shape, input_rank - 2));
1442 const int32_t other_dim = c->Value(c->Dim(input_shape, input_rank - 1));
1443
1444 if (num_diags != (upper_diag_index - lower_diag_index + 1)) {
1445 return errors::InvalidArgument(
1446 "The number of rows of `diagonal` doesn't match the number of "
1447 "diagonals implied from `d_lower` and `d_upper`.\n",
1448 "num_diags = ", num_diags, ", d_lower = ", lower_diag_index,
1449 ", d_upper = ", upper_diag_index, " ", input_rank, " ", other_dim);
1450 }
1451 }
1452
1453 // Reads num_rows and num_cols.
1454 const Tensor* num_rows_tensor = c->input_tensor(2);
1455 const Tensor* num_cols_tensor = c->input_tensor(3);
1456 int64_t num_rows = -1;
1457 int64_t num_cols = -1;
1458 if (num_rows_tensor != nullptr) {
1459 TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_rows_tensor, &num_rows));
1460 }
1461 if (num_cols_tensor != nullptr) {
1462 TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_cols_tensor, &num_cols));
1463 }
1464
1465 // Infers the missing num_rows or num_cols: If both are missing, assume
1466 // output is square. Otherwise, use the smallest possible value. Also
1467 // validates the provided values.
1468 const int32_t max_diag_len = c->Value(c->Dim(input_shape, input_rank - 1));
1469 const int32_t min_num_rows = max_diag_len - std::min(upper_diag_index, 0);
1470 const int32_t min_num_cols = max_diag_len + std::max(lower_diag_index, 0);
1471 if (num_rows == -1 && num_cols == -1) { // Special case.
1472 num_rows = std::max(min_num_rows, min_num_cols);
1473 num_cols = num_rows;
1474 }
1475 if (num_rows == -1) {
1476 num_rows = min_num_rows;
1477 } else if (num_rows < min_num_rows) {
1478 return errors::InvalidArgument("num_rows is too small");
1479 }
1480 if (num_cols == -1) {
1481 num_cols = min_num_cols;
1482 } else if (num_cols < min_num_cols) {
1483 return errors::InvalidArgument("num_cols is too small.");
1484 }
1485 // At least one of them must match the minimum length.
1486 if (num_rows != min_num_rows && num_cols != min_num_cols) {
1487 return errors::InvalidArgument(
1488 "num_rows and num_cols are not consistent with lower_diag_index, "
1489 "upper_diag_index, and the length of the given diagonals.\n",
1490 "num_rows = ", num_rows, " != min_num_rows = ", min_num_rows,
1491 ", num_cols = ", num_cols, " != min_num_cols = ", min_num_cols);
1492 }
1493
1494 // Sets output shape.
1495 ShapeHandle output_shape;
1496 const DimensionHandle output_row_dim = c->MakeDim(num_rows);
1497 const DimensionHandle output_col_dim = c->MakeDim(num_cols);
1498 if (lower_diag_index == upper_diag_index) {
1499 TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 1,
1500 output_row_dim, &output_shape));
1501 TF_RETURN_IF_ERROR(
1502 c->Concatenate(output_shape, c->Vector(output_col_dim), &output_shape));
1503 } else {
1504 TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 2,
1505 output_row_dim, &output_shape));
1506 TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape, input_rank - 1,
1507 output_col_dim, &output_shape));
1508 }
1509 c->set_output(0, output_shape);
1510 return OkStatus();
1511 }
1512
MatrixSetDiagV2Shape(shape_inference::InferenceContext * c)1513 Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c) {
1514 ShapeHandle input_shape, diag_shape, diag_index_shape;
1515 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
1516 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag_shape));
1517 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &diag_index_shape));
1518
1519 int32_t lower_diag_index = 0;
1520 int32_t upper_diag_index = 0;
1521 bool diag_index_known = false;
1522 const Tensor* diag_index_tensor = c->input_tensor(2);
1523 if (diag_index_tensor != nullptr && c->FullyDefined(diag_index_shape)) {
1524 diag_index_known = true;
1525 TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
1526 &upper_diag_index));
1527 if (lower_diag_index > upper_diag_index) {
1528 return errors::InvalidArgument(
1529 "lower_diag_index is greater than upper_diag_index");
1530 }
1531 }
1532
1533 // Do more checks when input rank is known.
1534 if (c->RankKnown(input_shape)) {
1535 int32_t input_rank = c->Rank(input_shape);
1536
1537 // If diag_index is set, we know the exact rank of diagonal.
1538 if (diag_index_known) {
1539 TF_RETURN_IF_ERROR(c->WithRank(
1540 c->input(1),
1541 (lower_diag_index == upper_diag_index) ? input_rank - 1 : input_rank,
1542 &diag_shape));
1543 } else {
1544 TF_RETURN_IF_ERROR(
1545 c->WithRankAtLeast(c->input(1), input_rank - 1, &diag_shape));
1546 TF_RETURN_IF_ERROR(
1547 c->WithRankAtMost(c->input(1), input_rank, &diag_shape));
1548 }
1549
1550 // Validates lower_diag_index and upper_diag_index.
1551 const int32_t num_rows = c->Value(c->Dim(input_shape, input_rank - 2));
1552 const int32_t num_cols = c->Value(c->Dim(input_shape, input_rank - 1));
1553 if (num_rows != InferenceContext::kUnknownDim &&
1554 num_cols != InferenceContext::kUnknownDim) {
1555 if (lower_diag_index != 0 && // For when num_rows or num_cols == 0.
1556 (-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) {
1557 return errors::InvalidArgument("lower_diag_index is out of bound.");
1558 }
1559 if (upper_diag_index != 0 && // For when num_rows or num_cols == 0.
1560 (-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) {
1561 return errors::InvalidArgument("upper_diag_index is out of bound.");
1562 }
1563 }
1564 }
1565
1566 ShapeHandle output_shape = input_shape;
1567 if (c->RankKnown(diag_shape) && !c->FullyDefined(input_shape)) {
1568 // Try to infer parts of shape from diag.
1569 ShapeHandle diag_prefix;
1570 TF_RETURN_IF_ERROR(c->Subshape(
1571 diag_shape, 0, (lower_diag_index == upper_diag_index) ? -1 : -2,
1572 &diag_prefix));
1573
1574 // The inner matrices can be rectangular, so we can't pinpoint their
1575 // exact height and width by just lower_diag_index, upper_diag_index,
1576 // and the longest length of given diagonals.
1577 TF_RETURN_IF_ERROR(
1578 c->Concatenate(diag_prefix, c->UnknownShapeOfRank(2), &diag_shape));
1579 TF_RETURN_IF_ERROR(c->Merge(input_shape, diag_shape, &output_shape));
1580 }
1581 c->set_output(0, output_shape);
1582 return OkStatus();
1583 }
1584
MaxPoolShapeImpl(shape_inference::InferenceContext * c,bool supports_explicit_padding)1585 Status MaxPoolShapeImpl(shape_inference::InferenceContext* c,
1586 bool supports_explicit_padding) {
1587 string data_format_str;
1588 TensorFormat data_format;
1589 Status s = c->GetAttr("data_format", &data_format_str);
1590 if (s.ok()) {
1591 FormatFromString(data_format_str, &data_format);
1592 } else {
1593 data_format = FORMAT_NHWC;
1594 }
1595
1596 const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
1597 ShapeHandle input_shape;
1598 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
1599
1600 TF_RETURN_IF_ERROR(
1601 CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));
1602
1603 std::vector<int32> strides;
1604 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
1605 if (strides.size() != 4) {
1606 return errors::InvalidArgument(
1607 "MaxPool requires the stride attribute to contain 4 values, but got: ",
1608 strides.size());
1609 }
1610
1611 std::vector<int32> kernel_sizes;
1612 TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
1613 if (kernel_sizes.size() != 4) {
1614 return errors::InvalidArgument(
1615 "MaxPool requires the ksize attribute to contain 4 values, but got: ",
1616 kernel_sizes.size());
1617 }
1618
1619 int32_t stride_depth = GetTensorDim(strides, data_format, 'C');
1620 int32_t stride_rows = GetTensorDim(strides, data_format, 'H');
1621 int32_t stride_cols = GetTensorDim(strides, data_format, 'W');
1622 int32_t kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
1623 int32_t kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
1624 int32_t kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');
1625
1626 constexpr int num_spatial_dims = 2;
1627 DimensionHandle batch_size_dim = c->Dim(
1628 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
1629 DimensionHandle in_rows_dim = c->Dim(
1630 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
1631 DimensionHandle in_cols_dim = c->Dim(
1632 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
1633 DimensionHandle in_depth_dim = c->Dim(
1634 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
1635
1636 Padding padding;
1637 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
1638
1639 std::vector<int64_t> explicit_paddings;
1640 if (supports_explicit_padding) {
1641 Status status = c->GetAttr("explicit_paddings", &explicit_paddings);
1642 // Use the default value, which is an empty list, if the attribute is not
1643 // found. Otherwise return the error to the caller.
1644 if (!status.ok() && !errors::IsNotFound(status)) {
1645 return status;
1646 }
1647 TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
1648 /*num_dims=*/4, data_format));
1649 } else {
1650 DCHECK(padding != Padding::EXPLICIT);
1651 }
1652
1653 ShapeHandle output_shape;
1654 DimensionHandle output_rows, output_cols, output_depth;
1655 int64_t pad_rows_before = -1, pad_rows_after = -1;
1656 int64_t pad_cols_before = -1, pad_cols_after = -1;
1657 if (padding == Padding::EXPLICIT) {
1658 GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
1659 &pad_rows_before, &pad_rows_after);
1660 GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
1661 &pad_cols_before, &pad_cols_after);
1662 }
1663 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
1664 c, in_rows_dim, kernel_rows, /*dilation_rate=*/1, stride_rows, padding,
1665 pad_rows_before, pad_rows_after, &output_rows));
1666 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
1667 c, in_cols_dim, kernel_cols, /*dilation_rate=*/1, stride_cols, padding,
1668 pad_cols_before, pad_cols_after, &output_cols));
1669 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
1670 c, in_depth_dim, kernel_depth, /*dilation_rate=*/1, stride_depth, padding,
1671 /*pad_before*/ 0, /*pad_after*/ 0, &output_depth));
1672
1673 TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
1674 {output_rows, output_cols},
1675 output_depth, &output_shape, c));
1676
1677 c->set_output(0, output_shape);
1678 return OkStatus();
1679 }
1680
MaxPoolShape(shape_inference::InferenceContext * c)1681 Status MaxPoolShape(shape_inference::InferenceContext* c) {
1682 return MaxPoolShapeImpl(c, /*supports_explicit_padding=*/false);
1683 }
1684
MaxPoolGradShape(shape_inference::InferenceContext * c)1685 Status MaxPoolGradShape(shape_inference::InferenceContext* c) {
1686 return UnchangedShapeWithRank(c, 4);
1687 }
1688
MaxPoolShapeWithExplicitPadding(shape_inference::InferenceContext * c)1689 Status MaxPoolShapeWithExplicitPadding(shape_inference::InferenceContext* c) {
1690 return MaxPoolShapeImpl(c, /*supports_explicit_padding=*/true);
1691 }
1692
MaxPoolV2Shape(shape_inference::InferenceContext * c,int num_inputs)1693 Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) {
1694 string data_format_str;
1695 TensorFormat data_format;
1696 Status s = c->GetAttr("data_format", &data_format_str);
1697 if (s.ok()) {
1698 FormatFromString(data_format_str, &data_format);
1699 } else {
1700 data_format = FORMAT_NHWC;
1701 }
1702
1703 const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
1704 ShapeHandle input_shape;
1705 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
1706
1707 TF_RETURN_IF_ERROR(
1708 CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));
1709
1710 std::vector<int32> kernel_sizes;
1711 std::vector<int32> strides;
1712
1713 if (c->num_inputs() + 2 == num_inputs) {
1714 TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
1715
1716 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
1717 } else {
1718 // Verify shape of ksize and strides input.
1719 ShapeHandle size;
1720 DimensionHandle unused;
1721 TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 2), 1, &size));
1722 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));
1723 TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &size));
1724 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));
1725
1726 const Tensor* kernel_sizes_tensor = c->input_tensor(c->num_inputs() - 2);
1727 if (kernel_sizes_tensor == nullptr) {
1728 c->set_output(0, c->UnknownShape());
1729 return OkStatus();
1730 }
1731 kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements());
1732 auto kernel_sizes_vec = kernel_sizes_tensor->flat<int32>();
1733 std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(),
1734 kernel_sizes.begin());
1735
1736 const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1);
1737 if (strides_tensor == nullptr) {
1738 c->set_output(0, c->UnknownShape());
1739 return OkStatus();
1740 }
1741 strides.resize(strides_tensor->shape().num_elements());
1742 auto strides_vec = strides_tensor->flat<int32>();
1743 std::copy_n(&strides_vec(0), strides.size(), strides.begin());
1744 }
1745
1746 if (strides.size() != 4) {
1747 return errors::InvalidArgument(
1748 "MaxPool requires the stride attribute to contain 4 values, but "
1749 "got: ",
1750 strides.size());
1751 }
1752 if (kernel_sizes.size() != 4) {
1753 return errors::InvalidArgument(
1754 "MaxPool requires the ksize attribute to contain 4 values, but got: ",
1755 kernel_sizes.size());
1756 }
1757
1758 int32_t stride_depth = GetTensorDim(strides, data_format, 'C');
1759 int32_t stride_rows = GetTensorDim(strides, data_format, 'H');
1760 int32_t stride_cols = GetTensorDim(strides, data_format, 'W');
1761 int32_t kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
1762 int32_t kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
1763 int32_t kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');
1764
1765 constexpr int num_spatial_dims = 2;
1766 DimensionHandle batch_size_dim = c->Dim(
1767 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
1768 DimensionHandle in_rows_dim = c->Dim(
1769 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
1770 DimensionHandle in_cols_dim = c->Dim(
1771 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
1772 DimensionHandle in_depth_dim = c->Dim(
1773 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
1774
1775 Padding padding;
1776 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
1777
1778 ShapeHandle output_shape;
1779 DimensionHandle output_rows, output_cols, output_depth;
1780 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1781 c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
1782 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1783 c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
1784 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1785 c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth));
1786
1787 TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
1788 {output_rows, output_cols},
1789 output_depth, &output_shape, c));
1790
1791 c->set_output(0, output_shape);
1792 return OkStatus();
1793 }
1794
Pool3DShape(shape_inference::InferenceContext * c)1795 Status Pool3DShape(shape_inference::InferenceContext* c) {
1796 ShapeHandle input_shape;
1797 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
1798
1799 string data_format;
1800 Status s = c->GetAttr("data_format", &data_format);
1801
1802 std::vector<int32> strides;
1803 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
1804 if (strides.size() != 5) {
1805 return errors::InvalidArgument(
1806 "Pool3D ops require the stride attribute to contain 5 values, but "
1807 "got: ",
1808 strides.size());
1809 }
1810
1811 std::vector<int32> kernel_sizes;
1812 TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
1813 if (kernel_sizes.size() != 5) {
1814 return errors::InvalidArgument(
1815 "Pool3D requires the ksize attribute to contain 5 values, but got: ",
1816 kernel_sizes.size());
1817 }
1818
1819 int32_t stride_planes, stride_rows, stride_cols;
1820 int32_t kernel_planes, kernel_rows, kernel_cols;
1821
1822 if (s.ok() && data_format == "NCDHW") {
1823 // Convert input_shape to NDHWC.
1824 auto dim = [&](char dimension) {
1825 return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
1826 };
1827 input_shape =
1828 c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
1829 stride_planes = strides[2];
1830 stride_rows = strides[3];
1831 stride_cols = strides[4];
1832 kernel_planes = kernel_sizes[2];
1833 kernel_rows = kernel_sizes[3];
1834 kernel_cols = kernel_sizes[4];
1835 } else {
1836 stride_planes = strides[1];
1837 stride_rows = strides[2];
1838 stride_cols = strides[3];
1839 kernel_planes = kernel_sizes[1];
1840 kernel_rows = kernel_sizes[2];
1841 kernel_cols = kernel_sizes[3];
1842 }
1843
1844 DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
1845 DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
1846 DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
1847 DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
1848 DimensionHandle output_depth_dim = c->Dim(input_shape, 4);
1849
1850 Padding padding;
1851 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
1852
1853 // TODO(mrry,shlens): Raise an error if the stride would cause
1854 // information in the input to be ignored. This will require a change
1855 // in the kernel implementation.
1856 DimensionHandle output_planes, output_rows, output_cols;
1857 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1858 c, in_planes_dim, kernel_planes, stride_planes, padding, &output_planes));
1859 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1860 c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
1861 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1862 c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
1863
1864 ShapeHandle output_shape;
1865 if (data_format == "NCDHW") {
1866 output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
1867 output_planes, output_rows, output_cols});
1868 } else {
1869 output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
1870 output_cols, output_depth_dim});
1871 }
1872
1873 c->set_output(0, output_shape);
1874 return OkStatus();
1875 }
1876
MaxPool3DGradShape(shape_inference::InferenceContext * c)1877 Status MaxPool3DGradShape(shape_inference::InferenceContext* c) {
1878 return UnchangedShapeWithRank(c, 5);
1879 }
1880
AvgPool3DGradShape(shape_inference::InferenceContext * c)1881 Status AvgPool3DGradShape(shape_inference::InferenceContext* c) {
1882 ShapeHandle s;
1883 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
1884 TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
1885 c->set_output(0, s);
1886 return OkStatus();
1887 }
1888
UnknownShape(shape_inference::InferenceContext * c)1889 Status UnknownShape(shape_inference::InferenceContext* c) {
1890 for (int i = 0; i < c->num_outputs(); ++i) {
1891 c->set_output(i, c->UnknownShape());
1892 }
1893 return OkStatus();
1894 }
1895
1896 template <typename T>
ReductionShapeHelper(const Tensor * reduction_indices_t,const int32_t input_rank,std::set<int64_t> * true_indices)1897 Status ReductionShapeHelper(const Tensor* reduction_indices_t,
1898 const int32_t input_rank,
1899 std::set<int64_t>* true_indices) {
1900 auto reduction_indices = reduction_indices_t->flat<T>();
1901 for (int i = 0; i < reduction_indices_t->NumElements(); ++i) {
1902 const T reduction_index = reduction_indices(i);
1903 if (reduction_index < -input_rank || reduction_index >= input_rank) {
1904 return errors::InvalidArgument("Invalid reduction dimension ",
1905 reduction_index, " for input with ",
1906 input_rank, " dimensions.");
1907 }
1908
1909 auto wrapped_index = reduction_index;
1910 if (wrapped_index < 0) {
1911 wrapped_index += input_rank;
1912 }
1913
1914 true_indices->insert(wrapped_index);
1915 }
1916 return OkStatus();
1917 }
1918
ReductionShape(InferenceContext * c)1919 Status ReductionShape(InferenceContext* c) {
1920 ShapeHandle input = c->input(0);
1921
1922 ShapeHandle indices;
1923 // Older versions of TensorFlow accidentally allowed higher rank tensors like
1924 // [[1,2]] or [[1],[2]] to represent axis=[1,2].
1925 if (c->graph_def_version() < 21) {
1926 indices = c->input(1);
1927 } else {
1928 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices));
1929 }
1930
1931 bool keep_dims;
1932 TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));
1933
1934 const Tensor* reduction_indices_t = c->input_tensor(1);
1935 if (reduction_indices_t == nullptr || !c->RankKnown(input)) {
1936 // If we do not have the reduction values at runtime, or the
1937 // rank of the input, we don't know the output shape.
1938
1939 if (keep_dims && c->RankKnown(input)) {
1940 // output rank matches input input if <keep_dims>.
1941 c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
1942 return OkStatus();
1943 } else {
1944 return shape_inference::UnknownShape(c);
1945 }
1946 }
1947
1948 const int32_t input_rank = c->Rank(input);
1949 std::set<int64_t> true_indices;
1950 if (reduction_indices_t->dtype() == DataType::DT_INT32) {
1951 TF_RETURN_IF_ERROR(ReductionShapeHelper<int32>(reduction_indices_t,
1952 input_rank, &true_indices));
1953 } else if (reduction_indices_t->dtype() == DataType::DT_INT64) {
1954 TF_RETURN_IF_ERROR(ReductionShapeHelper<int64_t>(
1955 reduction_indices_t, input_rank, &true_indices));
1956 } else {
1957 return errors::InvalidArgument(
1958 "reduction_indices can only be int32 or int64");
1959 }
1960
1961 std::vector<DimensionHandle> dims;
1962 for (int i = 0; i < input_rank; ++i) {
1963 if (true_indices.count(i) > 0) {
1964 if (keep_dims) {
1965 dims.emplace_back(c->MakeDim(1));
1966 }
1967 } else {
1968 dims.emplace_back(c->Dim(input, i));
1969 }
1970 }
1971
1972 c->set_output(0, c->MakeShape(dims));
1973 return OkStatus();
1974 }
1975
ConcatShapeHelper(InferenceContext * c,int start_value_index,int end_value_index,int dim_index)1976 Status ConcatShapeHelper(InferenceContext* c, int start_value_index,
1977 int end_value_index, int dim_index) {
1978 ShapeHandle unused;
1979 TF_RETURN_IF_ERROR(c->WithRank(c->input(dim_index), 0, &unused));
1980 const Tensor* concat_dim_t = c->input_tensor(dim_index);
1981 if (concat_dim_t == nullptr) {
1982 // Return an unknown shape with same rank as inputs, or an unknown rank
1983 // if no input's rank is known.
1984
1985 // Find rank.
1986 int32_t rank = InferenceContext::kUnknownRank;
1987 for (int i = start_value_index; i < end_value_index; ++i) {
1988 if (rank == InferenceContext::kUnknownRank) rank = c->Rank(c->input(i));
1989 if (rank != InferenceContext::kUnknownRank) {
1990 break;
1991 }
1992 }
1993 if (rank == InferenceContext::kUnknownRank) {
1994 c->set_output(0, c->UnknownShape());
1995 return OkStatus();
1996 } else if (rank == 0) {
1997 return errors::InvalidArgument(
1998 "Can't concatenate scalars (use tf.stack instead)");
1999 } else {
2000 for (int i = start_value_index; i < end_value_index; ++i) {
2001 // Check that all the inputs are of the correct rank.
2002 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), rank, &unused));
2003 }
2004 }
2005 // Build result of <rank> different unknown dims.
2006 std::vector<DimensionHandle> dims;
2007 dims.reserve(rank);
2008 for (int i = 0; i < rank; ++i) dims.push_back(c->UnknownDim());
2009 c->set_output(0, c->MakeShape(dims));
2010 return OkStatus();
2011 }
2012
2013 // Merge all the non-concat dims, and sum the concat dim to make an output
2014 // shape.
2015 int64_t concat_dim;
2016 if (concat_dim_t->dtype() == DT_INT32) {
2017 concat_dim = static_cast<int64_t>(concat_dim_t->flat<int32>()(0));
2018 } else {
2019 concat_dim = concat_dim_t->flat<int64_t>()(0);
2020 }
2021
2022 // Minimum required number of dimensions.
2023 const int64 min_rank = concat_dim < 0 ? -concat_dim : concat_dim + 1;
2024
2025 ShapeHandle output_before;
2026 ShapeHandle output_after;
2027
2028 ShapeHandle input = c->input(end_value_index - 1);
2029 TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
2030 TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &output_before));
2031 DimensionHandle output_middle = c->Dim(input, concat_dim);
2032 if (concat_dim == -1) {
2033 output_after = c->Scalar(); // no dimensions.
2034 } else {
2035 TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &output_after));
2036 }
2037
2038 for (int i = end_value_index - 2; i >= start_value_index; --i) {
2039 ShapeHandle before;
2040 ShapeHandle after;
2041 input = c->input(i);
2042 TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
2043 TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &before));
2044 DimensionHandle middle = c->Dim(input, concat_dim);
2045 if (concat_dim == -1) {
2046 after = c->Scalar();
2047 } else {
2048 TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &after));
2049 }
2050
2051 TF_RETURN_IF_ERROR(c->Merge(before, output_before, &output_before));
2052 TF_RETURN_IF_ERROR(c->Add(output_middle, middle, &output_middle));
2053 TF_RETURN_IF_ERROR(c->Merge(after, output_after, &output_after));
2054 }
2055
2056 ShapeHandle s;
2057 TF_RETURN_IF_ERROR(
2058 c->Concatenate(output_before, c->Vector(output_middle), &s));
2059 TF_RETURN_IF_ERROR(c->Concatenate(s, output_after, &s));
2060 c->set_output(0, s);
2061 return OkStatus();
2062 }
2063
ConcatShape(InferenceContext * c,int num_inputs_to_concat)2064 Status ConcatShape(InferenceContext* c, int num_inputs_to_concat) {
2065 return ConcatShapeHelper(c, 1 /* start_value_index */,
2066 1 + num_inputs_to_concat /* end_value_index */,
2067 0 /* dim_index */);
2068 }
2069
ConcatV2Shape(InferenceContext * c)2070 Status ConcatV2Shape(InferenceContext* c) {
2071 return ConcatShapeHelper(c, 0 /* start_value_index */,
2072 c->num_inputs() - 1 /* end_value_index */,
2073 c->num_inputs() - 1 /* dim_index */);
2074 }
2075
QuantizedConcatV2Shape(InferenceContext * c,int num_inputs_to_concat)2076 Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat) {
2077 return ConcatShapeHelper(c, 0 /* start_value_index */,
2078 num_inputs_to_concat /* end_value_index */,
2079 num_inputs_to_concat /* dim_index */);
2080 }
2081
// Computes the broadcast of `shape_x` and `shape_y` into `*out`.
//
// The shapes are right-aligned, and the shorter one is treated as if padded
// on the left with dims of size 1. For each output position:
//   - both dims known: if either is 1 the other is used, otherwise the two
//     are merged (and must be equal);
//   - at least one dim unknown: see the rules in that branch below.
// If either input rank is unknown, `*out` is an unknown shape.
//
// `incompatible_shape_error` controls what happens on a mismatch:
//   - true: mismatched known dims return the error from Merge;
//   - false: mismatched known dims make `*out` a scalar shape, and ambiguous
//     cases involving an unknown dim make `*out` an unknown shape. Always
//     returns OK in this mode.
//
// `out` must be non-null (enforced with CHECK_NOTNULL).
Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
                                            ShapeHandle shape_x,
                                            ShapeHandle shape_y,
                                            bool incompatible_shape_error,
                                            ShapeHandle* out) {
  CHECK_NOTNULL(out);
  if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
    *out = c->UnknownShape();
    return OkStatus();
  }
  const int32_t rank_x = c->Rank(shape_x);
  const int32_t rank_y = c->Rank(shape_y);
  const int32_t rank_out = std::max(rank_x, rank_y);

  // To compute the broadcast dimensions, we zip together shape_x and shape_y
  // and pad with 1 to make them the same length.
  std::vector<DimensionHandle> dims;
  // The shared size-1 dim used for padding; only needed (and only created)
  // when the ranks differ.
  DimensionHandle dim_one;
  if (rank_x != rank_y) dim_one = c->MakeDim(1);
  for (int i = 0; i < rank_out; ++i) {
    // Leading positions beyond a shape's own rank read the padding dim.
    const auto dim_x = i < (rank_out - rank_x)
                           ? dim_one
                           : c->Dim(shape_x, i - (rank_out - rank_x));
    // True when dim_y is padding rather than a real dim of shape_y.
    const bool dim_y_is_one = (i < (rank_out - rank_y));
    const auto dim_y =
        dim_y_is_one ? dim_one : c->Dim(shape_y, i - (rank_out - rank_y));
    if (!c->ValueKnown(dim_x) || !c->ValueKnown(dim_y)) {
      // One or both dimensions is unknown.
      //
      // - If either dimension is greater than 1, we assume that the program is
      // correct, and the other dimension will be broadcast to match it.
      // TODO(cwhipkey): For shape inference, if we eliminate the shape checks
      // in C++ op code, we must still assert that the unknown dim is either 1
      // or the same as the known dim.
      // - If either dimension is 1, the other dimension is the output.
      // - If both are unknown then dimension is unknown
      if (c->Value(dim_x) > 1) {
        if (!incompatible_shape_error) {
          *out = c->UnknownShape();
          return OkStatus();
        }
        dims.push_back(dim_x);
      } else if (c->Value(dim_y) > 1) {
        if (!incompatible_shape_error) {
          *out = c->UnknownShape();
          return OkStatus();
        }
        dims.push_back(dim_y);
      } else if (c->Value(dim_x) == 1) {
        dims.push_back(dim_y);
      } else if (c->Value(dim_y) == 1) {
        dims.push_back(dim_x);
      } else if (dim_y.SameHandle(dim_x)) {
        // Same unknown dim on both sides: it must equal itself.
        dims.push_back(dim_x);
      } else if (!c->ValueKnown(dim_x) && !c->ValueKnown(dim_y)) {
        dims.push_back(c->UnknownDim());
      } else {
        if (!incompatible_shape_error) {
          *out = c->UnknownShape();
          return OkStatus();
        }
        dims.push_back(c->UnknownDim());
      }
    } else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) {
      if (c->Value(dim_x) == 1 && !dim_y_is_one) {
        // We will broadcast dim_x to dim_y.
        dims.push_back(dim_y);
      } else {
        DCHECK_EQ(c->Value(dim_y), 1);
        // We will broadcast dim_y to dim_x.
        dims.push_back(dim_x);
      }
    } else {
      // Both dims known and neither is 1: they must be equal.
      DimensionHandle dim;
      Status s = c->Merge(dim_x, dim_y, &dim);
      if (!s.ok()) {
        if (!incompatible_shape_error) {
          // Known, incompatible dims: report a scalar output instead of an
          // error (presumably for ops that return a scalar result in this
          // case; confirm with the callers).
          *out = c->MakeShape({});
          return OkStatus();
        }
        return s;
      }
      dims.push_back(dim);
    }
  }

  *out = c->MakeShape(dims);
  return OkStatus();
}
2172
RandomShape(shape_inference::InferenceContext * c)2173 Status RandomShape(shape_inference::InferenceContext* c) {
2174 shape_inference::ShapeHandle out;
2175 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
2176 c->set_output(0, out);
2177 return OkStatus();
2178 }
2179
UnsortedSegmentReductionShapeFn(InferenceContext * c)2180 Status UnsortedSegmentReductionShapeFn(InferenceContext* c) {
2181 ShapeHandle s_data = c->input(0);
2182 ShapeHandle s_segment_ids = c->input(1);
2183 ShapeHandle s_num_segments = c->input(2);
2184 TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments));
2185
2186 ShapeHandle out;
2187
2188 // Leading dimensions of data must be compatible with dimensions of
2189 // <s_segment_ids>.
2190 if (c->RankKnown(s_segment_ids)) {
2191 TF_RETURN_IF_ERROR(
2192 c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids));
2193
2194 // Get the value of the num_segments input tensor.
2195 DimensionHandle num_segments_dim;
2196 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim));
2197
2198 // Output is {segment_id_rank} + s_data[segment_id_rank:].
2199 ShapeHandle s_data_suffix;
2200 TF_RETURN_IF_ERROR(
2201 c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix));
2202 TF_RETURN_IF_ERROR(
2203 c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out));
2204 } else {
2205 out = c->UnknownShape();
2206 }
2207 c->set_output(0, out);
2208 return OkStatus();
2209 }
2210
2211 namespace {
2212
2213 // This SliceHelper processes the output shape of the `slice`
2214 // when the tensor of `sizes` is available.
2215 template <typename T>
SliceHelper(InferenceContext * c,ShapeHandle begin_value,const Tensor * sizes_value,std::vector<DimensionHandle> * dims)2216 Status SliceHelper(InferenceContext* c, ShapeHandle begin_value,
2217 const Tensor* sizes_value,
2218 std::vector<DimensionHandle>* dims) {
2219 auto sizes_vec = sizes_value->vec<T>();
2220 for (int i = 0; i < sizes_value->NumElements(); ++i) {
2221 DimensionHandle dim = c->Dim(c->input(0), i);
2222 if (sizes_vec(i) != -1) {
2223 auto dim_val = c->Value(dim);
2224 if (sizes_vec(i) < 0) {
2225 return errors::InvalidArgument(
2226 "Out of bounds slicing on dimension ", i, " of length ", dim_val,
2227 ": sizes vector cannot be < -1, but was ", sizes_vec(i));
2228 }
2229
2230 dims->emplace_back(c->MakeDim(sizes_vec(i)));
2231 } else {
2232 DimensionHandle result;
2233 TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result));
2234 dims->emplace_back(result);
2235 }
2236 }
2237
2238 return OkStatus();
2239 }
2240 } // namespace
2241
SliceShape(InferenceContext * c)2242 Status SliceShape(InferenceContext* c) {
2243 ShapeHandle input = c->input(0);
2244 ShapeHandle begin_shape;
2245 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
2246 ShapeHandle sizes_shape;
2247 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape));
2248
2249 // Merge to check compatibility of begin and sizes tensors.
2250 TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape));
2251
2252 DimensionHandle ndims = c->Dim(begin_shape, 0);
2253 if (c->ValueKnown(ndims)) {
2254 TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input));
2255 }
2256
2257 // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known
2258 // values, even though the `begin` value does not represent a shape.
2259 ShapeHandle begin_value;
2260 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value));
2261
2262 // We check the tensor value here and will only use
2263 // `MakeShapeFromShapeTensor` when `sizes_value` is null.
2264 // The reason is that `sizes` might contain -1, which can't
2265 // be represented (-1 in the ShapeHandle would mean "unknown").
2266 const Tensor* sizes_value = c->input_tensor(2);
2267
2268 if (sizes_value != nullptr) {
2269 TF_RETURN_IF_ERROR(
2270 c->WithRank(begin_value, sizes_value->NumElements(), &begin_value));
2271 std::vector<DimensionHandle> dims;
2272 // If the begin and sizes tensors are available, then
2273 // we can be precise about the shape of the output.
2274 if (sizes_value->dtype() == DT_INT64) {
2275 TF_RETURN_IF_ERROR(
2276 SliceHelper<int64_t>(c, begin_value, sizes_value, &dims));
2277 } else {
2278 TF_RETURN_IF_ERROR(
2279 SliceHelper<int32>(c, begin_value, sizes_value, &dims));
2280 }
2281 c->set_output(0, c->MakeShape(dims));
2282 return OkStatus();
2283 } else {
2284 // In case `sizes` is not available (`sizes_value` is null),
2285 // we could try to use `MakeShapeFromShapeTensor` here.
2286 // If sizes contain -1, we will simply consider it as `Unknown`.
2287 // This is less than ideal but still an improvement of shape inference.
2288 // The following is an example that returns [None, 1, None] with this
2289 // code path:
2290 // z = tf.zeros((1, 2, 3))
2291 // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1])
2292 // m.get_shape().as_list()
2293 ShapeHandle sizes_value;
2294 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value));
2295 if (c->RankKnown(sizes_value)) {
2296 TF_RETURN_IF_ERROR(
2297 c->WithRank(begin_value, c->Rank(sizes_value), &begin_value));
2298 std::vector<DimensionHandle> dims;
2299 dims.reserve(c->Rank(sizes_value));
2300 for (int i = 0; i < c->Rank(sizes_value); ++i) {
2301 dims.emplace_back(c->Dim(sizes_value, i));
2302 }
2303 c->set_output(0, c->MakeShape(dims));
2304 return OkStatus();
2305 }
2306 // We might know the rank of the input.
2307 if (c->RankKnown(input)) {
2308 c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
2309 return OkStatus();
2310 } else {
2311 return shape_inference::UnknownShape(c);
2312 }
2313 }
2314
2315 return OkStatus();
2316 }
2317
ValidateSparseTensor(InferenceContext * c,ShapeHandle indices_shape,ShapeHandle values_shape,ShapeHandle shape_shape)2318 Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
2319 ShapeHandle values_shape, ShapeHandle shape_shape) {
2320 // Validate ranks.
2321 ShapeHandle unused_shape;
2322 TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &unused_shape));
2323 TF_RETURN_IF_ERROR(c->WithRank(values_shape, 1, &unused_shape));
2324 TF_RETURN_IF_ERROR(c->WithRank(shape_shape, 1, &unused_shape));
2325
2326 // Number of elements in indices and values must match.
2327 DimensionHandle num_index_elements_dim = c->Dim(indices_shape, 0);
2328 if (c->ValueKnown(num_index_elements_dim)) {
2329 DimensionHandle num_values_elements_dim = c->Dim(values_shape, 0);
2330 if (c->ValueKnown(num_values_elements_dim)) {
2331 int64_t num_index_elements = c->Value(num_index_elements_dim);
2332 int64_t num_values_elements = c->Value(num_values_elements_dim);
2333 if (num_index_elements != num_values_elements) {
2334 return errors::InvalidArgument("Number of elements in index (",
2335 num_index_elements, ") and values (",
2336 num_values_elements, ") do not match.");
2337 }
2338 }
2339 }
2340
2341 // Rank embedded in indices must match shape.
2342 DimensionHandle index_rank_dim = c->Dim(indices_shape, 1);
2343 if (c->ValueKnown(index_rank_dim)) {
2344 DimensionHandle shape_rank_dim = c->Dim(shape_shape, 0);
2345 if (c->ValueKnown(shape_rank_dim)) {
2346 int64_t index_rank = c->Value(index_rank_dim);
2347 int32_t shape_rank = c->Value(shape_rank_dim);
2348 if (index_rank != shape_rank) {
2349 return errors::InvalidArgument("Index rank (", index_rank,
2350 ") and shape rank (", shape_rank,
2351 ") do not match.");
2352 }
2353 }
2354 }
2355
2356 return OkStatus();
2357 }
2358
ValidateVariableResourceHandle(InferenceContext * c,std::vector<ShapeAndType> * shape_and_type)2359 Status ValidateVariableResourceHandle(
2360 InferenceContext* c, std::vector<ShapeAndType>* shape_and_type) {
2361 auto* handle_data = c->input_handle_shapes_and_types(0);
2362 if (handle_data == nullptr || handle_data->empty()) {
2363 shape_and_type->emplace_back(c->UnknownShape(), DT_INVALID);
2364 } else {
2365 *shape_and_type = *handle_data;
2366 DataType value_dtype;
2367 TF_RETURN_IF_ERROR(c->GetAttr("dtype", &value_dtype));
2368 if (shape_and_type->at(0).dtype != value_dtype) {
2369 return errors::InvalidArgument(
2370 "Trying to read variable with wrong dtype. "
2371 "Expected ",
2372 DataTypeString(shape_and_type->at(0).dtype), " got ",
2373 DataTypeString(value_dtype));
2374 }
2375 }
2376 return OkStatus();
2377 }
2378
GatherNdShape(InferenceContext * c)2379 Status GatherNdShape(InferenceContext* c) {
2380 ShapeHandle params;
2381 std::vector<ShapeAndType> handle_shape_and_type;
2382 if (c->input_handle_shapes_and_types(0) != nullptr) {
2383 TF_RETURN_IF_ERROR(
2384 ValidateVariableResourceHandle(c, &handle_shape_and_type));
2385 params = handle_shape_and_type[0].shape;
2386 } else {
2387 params = c->input(0);
2388 }
2389 ShapeHandle indices;
2390 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices));
2391 DimensionHandle r_dim = c->Dim(indices, -1);
2392
2393 if (!c->RankKnown(params) || !c->ValueKnown(r_dim)) {
2394 c->set_output(0, c->UnknownShape());
2395 return OkStatus();
2396 }
2397
2398 if (c->Value(r_dim) > c->Rank(params)) {
2399 return errors::InvalidArgument(
2400 "indices.shape[-1] must be <= params.rank, but saw indices shape: ",
2401 c->DebugString(indices), " and params shape: ", c->DebugString(params));
2402 }
2403
2404 // Remove r_dim from indices to get output.
2405 ShapeHandle indices_slice;
2406 ShapeHandle params_slice;
2407 TF_RETURN_IF_ERROR(c->Subshape(indices, 0, -1, &indices_slice));
2408 TF_RETURN_IF_ERROR(c->Subshape(params, c->Value(r_dim), ¶ms_slice));
2409 ShapeHandle out;
2410 TF_RETURN_IF_ERROR(c->Concatenate(indices_slice, params_slice, &out));
2411 c->set_output(0, out);
2412 return OkStatus();
2413 }
2414
ScatterNdShapeHelper(InferenceContext * c,ShapeHandle indices_shape,ShapeHandle updates_shape,ShapeHandle input_shape)2415 Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape,
2416 ShapeHandle updates_shape,
2417 ShapeHandle input_shape) {
2418 if (c->Value(c->NumElements(input_shape)) == 0 &&
2419 (c->Value(c->NumElements(indices_shape)) > 0 ||
2420 c->Value(c->NumElements(updates_shape)) > 0)) {
2421 return errors::InvalidArgument(
2422 "Indices and updates specified for empty input");
2423 }
2424
2425 if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape) &&
2426 c->Rank(updates_shape) != 0) {
2427 const int64_t outer_dims = c->Rank(indices_shape) - 1;
2428 const DimensionHandle ixdim = c->Dim(indices_shape, -1);
2429
2430 // We can only do more validation if the last dimension of indices
2431 // is a known value.
2432 if (c->ValueKnown(ixdim)) {
2433 int64_t ix = c->Value(ixdim);
2434 ShapeHandle unused;
2435 ShapeHandle prefix_indices;
2436 TF_RETURN_IF_ERROR(
2437 c->Subshape(indices_shape, 0, outer_dims, &prefix_indices));
2438 ShapeHandle prefix_updates;
2439 TF_RETURN_IF_ERROR(
2440 c->Subshape(updates_shape, 0, outer_dims, &prefix_updates));
2441
2442 Status s = c->Merge(prefix_indices, prefix_updates, &unused);
2443 if (!s.ok()) {
2444 return errors::InvalidArgument(
2445 "Dimensions [0,", outer_dims,
2446 ") of indices[shape=", c->DebugString(indices_shape),
2447 "] = ", c->DebugString(prefix_indices),
2448 " must match dimensions [0,", outer_dims,
2449 ") of updates[shape=", c->DebugString(updates_shape),
2450 "] = ", c->DebugString(prefix_updates), ": ", s.error_message());
2451 }
2452
2453 ShapeHandle suffix_output;
2454 TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &suffix_output));
2455 ShapeHandle suffix_updates;
2456 TF_RETURN_IF_ERROR(
2457 c->Subshape(updates_shape, outer_dims, &suffix_updates));
2458 s = c->Merge(suffix_output, suffix_updates, &unused);
2459 if (!s.ok()) {
2460 return errors::InvalidArgument(
2461 "Dimensions [", ix, ",", c->Rank(input_shape),
2462 ") of input[shape=", c->DebugString(input_shape),
2463 "] = ", c->DebugString(suffix_output), " must match dimensions [",
2464 outer_dims, ",", c->Rank(updates_shape),
2465 ") of updates[shape=", c->DebugString(updates_shape),
2466 "] = ", c->DebugString(suffix_updates), ": ", s.error_message());
2467 }
2468 }
2469 }
2470
2471 if (c->input_handle_shapes_and_types(0) == nullptr && c->num_outputs() > 0) {
2472 // This is called for tf.scatter_nd; output is a tensor with this shape.
2473 c->set_output(0, input_shape);
2474 }
2475 return OkStatus();
2476 }
2477
ExplicitShape(InferenceContext * c)2478 Status ExplicitShape(InferenceContext* c) {
2479 PartialTensorShape shape;
2480 TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
2481 ShapeHandle output_shape;
2482 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &output_shape));
2483 c->set_output(0, output_shape);
2484 return OkStatus();
2485 }
2486
ExplicitShapes(InferenceContext * c)2487 Status ExplicitShapes(InferenceContext* c) {
2488 std::vector<PartialTensorShape> shapes;
2489 TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
2490 if (shapes.empty()) {
2491 return errors::Internal("shapes attribute is empty");
2492 }
2493 for (int i = 0, end = shapes.size(); i < end; ++i) {
2494 ShapeHandle output_shape;
2495 TF_RETURN_IF_ERROR(
2496 c->MakeShapeFromPartialTensorShape(shapes[i], &output_shape));
2497 c->set_output(i, output_shape);
2498 }
2499 return OkStatus();
2500 }
2501
SparseReduceShapeFn(InferenceContext * c)2502 Status SparseReduceShapeFn(InferenceContext* c) {
2503 // Input 0: input_indices
2504 // Input 1: input_values
2505 // Input 2: input_shape
2506 // Input 3: reduction_axes
2507 // Attr: keep_dims
2508 bool keep_dims = false;
2509 TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));
2510
2511 const Tensor* shape_tensor = c->input_tensor(2);
2512 const Tensor* axes_tensor = c->input_tensor(3);
2513 if (shape_tensor != nullptr && axes_tensor != nullptr) {
2514 auto shape_vec = shape_tensor->flat<int64_t>();
2515 auto axes_vec = axes_tensor->flat<int32>();
2516
2517 int64_t ndims = shape_vec.size();
2518 absl::flat_hash_set<int64_t> axes;
2519 if (ndims == 0)
2520 return errors::InvalidArgument(
2521 "Number of dims in shape tensor must not be 0");
2522 for (int i = 0; i < axes_vec.size(); i++) {
2523 axes.insert((axes_vec(i) + ndims) % ndims);
2524 }
2525
2526 std::vector<DimensionHandle> dims;
2527 if (keep_dims) {
2528 dims.reserve(ndims);
2529 for (int d = 0; d < ndims; ++d) {
2530 if (axes.find(d) == axes.end()) {
2531 dims.push_back(c->MakeDim(shape_vec(d)));
2532 } else {
2533 dims.push_back(c->MakeDim(1));
2534 }
2535 }
2536 } else {
2537 for (int d = 0; d < ndims; ++d) {
2538 if (axes.find(d) == axes.end()) {
2539 dims.push_back(c->MakeDim(shape_vec(d)));
2540 }
2541 }
2542 }
2543
2544 c->set_output(0, c->MakeShape(dims));
2545 return OkStatus();
2546 }
2547 return UnknownShape(c);
2548 }
2549
QuantizedConv2DShape(InferenceContext * c)2550 Status QuantizedConv2DShape(InferenceContext* c) {
2551 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2552 ShapeHandle unused;
2553 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2554 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2555 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2556 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2557 c->set_output(1, c->Scalar());
2558 c->set_output(2, c->Scalar());
2559 return OkStatus();
2560 }
2561
QuantizedAvgPoolShape(InferenceContext * c)2562 Status QuantizedAvgPoolShape(InferenceContext* c) {
2563 TF_RETURN_IF_ERROR(shape_inference::AvgPoolShape(c));
2564 ShapeHandle unused;
2565 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
2566 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2567 c->set_output(1, c->Scalar());
2568 c->set_output(2, c->Scalar());
2569 return OkStatus();
2570 }
2571
QuantizeV2Shape(InferenceContext * c)2572 Status QuantizeV2Shape(InferenceContext* c) {
2573 int axis = -1;
2574 Status s = c->GetAttr("axis", &axis);
2575 if (!s.ok() && s.code() != error::NOT_FOUND) {
2576 return s;
2577 }
2578 if (axis < -1) {
2579 return errors::InvalidArgument("axis should be at least -1, got ", axis);
2580 }
2581 const int minmax_rank = (axis == -1) ? 0 : 1;
2582 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
2583 ShapeHandle minmax;
2584 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax));
2585 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), minmax_rank, &minmax));
2586 if (axis != -1) {
2587 ShapeHandle input;
2588 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
2589 DimensionHandle depth;
2590 TF_RETURN_IF_ERROR(
2591 c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
2592 }
2593 c->set_output(1, minmax);
2594 c->set_output(2, minmax);
2595 return OkStatus();
2596 }
2597
ReduceScatterShape(shape_inference::InferenceContext * c)2598 Status ReduceScatterShape(shape_inference::InferenceContext* c) {
2599 shape_inference::ShapeHandle in = c->input(0);
2600 if (!c->RankKnown(in)) {
2601 // Input shape unknown, so set unknown output shape.
2602 c->set_output(0, in);
2603 return OkStatus();
2604 }
2605
2606 shape_inference::ShapeHandle group_assignment_shape = c->input(1);
2607 if (c->Rank(group_assignment_shape) != 2)
2608 return errors::InvalidArgument(
2609 "ReduceScatter group_assignment should be rank 2");
2610
2611 const Tensor* scatter_dimension = c->input_tensor(2);
2612 if (!scatter_dimension) {
2613 c->set_output(0, c->UnknownShape());
2614 return OkStatus();
2615 }
2616 int64_t scatter_dim;
2617 TF_RETURN_IF_ERROR(c->GetScalarFromTensor(scatter_dimension, &scatter_dim));
2618
2619 std::vector<shape_inference::DimensionHandle> out_dims;
2620 out_dims.reserve(c->Rank(in));
2621 for (int i = 0; i < c->Rank(in); ++i) {
2622 // If the dimension is the scatter_dimension, then divide the dimension
2623 // by the partition size in the group_assignment.
2624 if (i == scatter_dim) {
2625 shape_inference::DimensionHandle dim = c->Dim(in, i);
2626 shape_inference::DimensionHandle out_dim;
2627 TF_RETURN_IF_ERROR(c->Divide(dim, c->Dim(group_assignment_shape, 1),
2628 /*evenly_divisible=*/true, &out_dim));
2629 out_dims.push_back(out_dim);
2630 } else {
2631 out_dims.emplace_back(c->Dim(in, i));
2632 }
2633 }
2634 c->set_output(0, c->MakeShape(out_dims));
2635 return OkStatus();
2636 }
2637
2638 } // namespace shape_inference
2639
2640 } // namespace tensorflow
2641