• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/common_shape_fns.h"
16 
17 #include "absl/container/flat_hash_map.h"
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/strings/match.h"
20 #include "absl/strings/str_split.h"
21 #include "absl/strings/string_view.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/shape_inference.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/gtl/inlined_vector.h"
26 #include "tensorflow/core/util/einsum_op_util.h"
27 
28 namespace tensorflow {
29 
30 namespace shape_inference {
31 
32 // The V2 version computes windowed output size with arbitrary dilation_rate and
33 // explicit padding, while the original version only handles the cases where
34 // dilation_rates equal to 1 and the padding is SAME or VALID.
GetWindowedOutputSizeFromDimsV2(shape_inference::InferenceContext * c,shape_inference::DimensionHandle input_size,shape_inference::DimensionOrConstant filter_size,int64_t dilation_rate,int64_t stride,Padding padding_type,int64_t padding_before,int64_t padding_after,shape_inference::DimensionHandle * output_size)35 Status GetWindowedOutputSizeFromDimsV2(
36     shape_inference::InferenceContext* c,
37     shape_inference::DimensionHandle input_size,
38     shape_inference::DimensionOrConstant filter_size, int64_t dilation_rate,
39     int64_t stride, Padding padding_type, int64_t padding_before,
40     int64_t padding_after, shape_inference::DimensionHandle* output_size) {
41   if (stride <= 0) {
42     return errors::InvalidArgument("Stride must be > 0, but got ", stride);
43   }
44 
45   if (dilation_rate < 1) {
46     return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
47                                    dilation_rate);
48   }
49 
50   // See also the parallel implementation in GetWindowedOutputSizeVerbose.
51   switch (padding_type) {
52     case Padding::VALID:
53       padding_before = padding_after = 0;
54       TF_FALLTHROUGH_INTENDED;
55     case Padding::EXPLICIT:
56       TF_RETURN_IF_ERROR(
57           c->Add(input_size, padding_before + padding_after, &input_size));
58       if (dilation_rate > 1) {
59         DimensionHandle window_size;
60         TF_RETURN_IF_ERROR(
61             c->Subtract(c->MakeDim(filter_size), 1, &window_size));
62         TF_RETURN_IF_ERROR(
63             c->Multiply(window_size, dilation_rate, &window_size));
64         TF_RETURN_IF_ERROR(c->Add(window_size, 1, &window_size));
65         TF_RETURN_IF_ERROR(c->Subtract(input_size, window_size, output_size));
66       } else {
67         TF_RETURN_IF_ERROR(c->Subtract(input_size, filter_size, output_size));
68       }
69       TF_RETURN_IF_ERROR(c->Add(*output_size, stride, output_size));
70       TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
71                                    /*evenly_divisible=*/false, output_size));
72       break;
73     case Padding::SAME:
74       TF_RETURN_IF_ERROR(c->Add(input_size, stride - 1, output_size));
75       TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
76                                    /*evenly_divisible=*/false, output_size));
77       break;
78   }
79   return Status::OK();
80 }
81 
GetWindowedOutputSizeFromDims(shape_inference::InferenceContext * c,shape_inference::DimensionHandle input_size,shape_inference::DimensionOrConstant filter_size,int64_t stride,Padding padding_type,shape_inference::DimensionHandle * output_size)82 Status GetWindowedOutputSizeFromDims(
83     shape_inference::InferenceContext* c,
84     shape_inference::DimensionHandle input_size,
85     shape_inference::DimensionOrConstant filter_size, int64_t stride,
86     Padding padding_type, shape_inference::DimensionHandle* output_size) {
87   if (padding_type == Padding::EXPLICIT) {
88     return errors::Internal(
89         "GetWindowedOutputSizeFromDims does not handle EXPLICIT padding; call "
90         "GetWindowedOutputSizeFromDimsV2 instead");
91   }
92   return GetWindowedOutputSizeFromDimsV2(c, input_size, filter_size,
93                                          /*dilation_rate=*/1, stride,
94                                          padding_type,
95                                          // Give dummy values of -1 to
96                                          // padding_before and padding_after,
97                                          // since explicit padding is not used.
98                                          -1, -1, output_size);
99 }
100 
UnchangedShape(shape_inference::InferenceContext * c)101 Status UnchangedShape(shape_inference::InferenceContext* c) {
102   c->set_output(0, c->input(0));
103   auto* handle_data = c->input_handle_shapes_and_types(0);
104   if (handle_data != nullptr) {
105     c->set_output_handle_shapes_and_types(0, *handle_data);
106   }
107   return Status::OK();
108 }
109 
MatMulShape(shape_inference::InferenceContext * c)110 Status MatMulShape(shape_inference::InferenceContext* c) {
111   ShapeHandle a;
112   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a));
113 
114   ShapeHandle b;
115   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b));
116 
117   bool transpose_a, transpose_b;
118   TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a));
119   TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b));
120   DimensionHandle output_rows = transpose_a ? c->Dim(a, 1) : c->Dim(a, 0);
121   DimensionHandle output_cols = transpose_b ? c->Dim(b, 0) : c->Dim(b, 1);
122 
123   // Validate that the inner shapes are compatible.
124   DimensionHandle inner_a = transpose_a ? c->Dim(a, 0) : c->Dim(a, 1);
125   DimensionHandle inner_b = transpose_b ? c->Dim(b, 1) : c->Dim(b, 0);
126   DimensionHandle merged;
127   TF_RETURN_IF_ERROR(c->Merge(inner_a, inner_b, &merged));
128 
129   c->set_output(0, c->Matrix(output_rows, output_cols));
130   return Status::OK();
131 }
132 
133 namespace {
134 
135 // Validate that an Einsum subscript contains exactly one or zero ellipsis; and
136 // that periods (.) occur only within an ellipses (...).
ValidateEinsumEllipsis(absl::string_view subscript,bool * found_ellipsis)137 Status ValidateEinsumEllipsis(absl::string_view subscript,
138                               bool* found_ellipsis) {
139   const int num_periods = absl::c_count(subscript, '.');
140   if (num_periods != 0 && num_periods != 3) {
141     return errors::InvalidArgument(
142         "Expected at most one ellipsis (...), but found ", num_periods,
143         " periods (.) in the input subscript: ", subscript);
144   }
145   if (num_periods == 3 && !absl::StrContains(subscript, "...")) {
146     return errors::InvalidArgument(
147         "Periods found outside of ellipsis in subscript: ", subscript);
148   }
149   *found_ellipsis = num_periods > 0;
150   return Status::OK();
151 }
152 
153 }  // namespace
154 
EinsumShape(shape_inference::InferenceContext * c)155 Status EinsumShape(shape_inference::InferenceContext* c) {
156   // We assume that the equation has a valid format. Either (x),(y)->(z)
157   // or (x)->(z), where each of (x), (y) and (z) are concatenation of zero or
158   // more latin alphabets and contains at most one ellipsis ('...').
159   string equation;
160   TF_RETURN_IF_ERROR(c->GetAttr("equation", &equation));
161   gtl::InlinedVector<string, 2> input_labels;
162   string output_labels;
163   TF_RETURN_IF_ERROR(
164       ParseEinsumEquation(equation, &input_labels, &output_labels));
165 
166   if (c->num_inputs() == 0 || c->num_inputs() > 2) {
167     return errors::InvalidArgument("Expected either 1 or 2 inputs but got: ",
168                                    c->num_inputs());
169   }
170   const int input_labels_size = input_labels.size();
171   if (c->num_inputs() != input_labels_size) {
172     return errors::InvalidArgument("Expected ", input_labels.size(),
173                                    " inputs for equation ", equation,
174                                    " but got: ", c->num_inputs());
175   }
176 
177   // Validate input subscripts, build the label to dimension mapping and obtain
178   // the broadcast shapes that map to ellipsis.
179   absl::flat_hash_map<char, DimensionHandle> label_to_dimension;
180   gtl::InlinedVector<ShapeHandle, 2> input_bcast_shapes(c->num_inputs());
181   for (int i = 0, end = c->num_inputs(); i < end; ++i) {
182     bool has_ellipsis = false;
183     TF_RETURN_IF_ERROR(ValidateEinsumEllipsis(input_labels[i], &has_ellipsis));
184     ShapeHandle input_shape = c->input(i);
185     // Validate that the input rank is sufficient for the given number of named
186     // labels.
187     if (c->RankKnown(input_shape)) {
188       if (has_ellipsis) {
189         const int num_named_labels =
190             static_cast<int>(input_labels[i].size()) - 3;
191         TF_RETURN_WITH_CONTEXT_IF_ERROR(
192             c->WithRankAtLeast(input_shape, num_named_labels, &input_shape),
193             " for ", i, "th input and equation: ", equation);
194       } else {
195         const int num_named_labels = static_cast<int>(input_labels[i].size());
196         TF_RETURN_WITH_CONTEXT_IF_ERROR(
197             c->WithRank(input_shape, num_named_labels, &input_shape), " for ",
198             i, "th input and equation: ", equation);
199       }
200     }
201 
202     bool seen_ellipsis = false;
203     input_bcast_shapes[i] = c->Scalar();
204     // Run through the input labels; populate label_to_dimension mapping and
205     // compute the broadcast shapes corresponding to the ellipsis (if present).
206     for (int label_idx = 0, end = input_labels[i].size(); label_idx < end;
207          ++label_idx) {
208       const char label = input_labels[i][label_idx];
209       // Calculate the input axis that the current label is referring to. After
210       // the ellipsis, the axis may be found by using negative indices; i.e the
211       // (rank - k)th dimension corresponds to the (num_labels - k)th label.
212       const int64_t axis_before_ellipsis = label_idx;
213       const int64_t axis_after_ellipsis =
214           c->RankKnown(input_shape)
215               ? label_idx + c->Rank(input_shape) - input_labels[i].size()
216               : -1;
217 
218       // Populate the input broadcast shape when we encounter an ellipsis (...).
219       if (label == '.') {
220         if (!c->RankKnown(input_shape)) {
221           input_bcast_shapes[i] = c->UnknownShape();
222         } else {
223           // The broadcast shape runs till the named label right after the
224           // ellipsis, the label with index (label_idx + 3).
225           TF_RETURN_IF_ERROR(c->Subshape(input_shape, axis_before_ellipsis,
226                                          axis_after_ellipsis + 3,
227                                          &input_bcast_shapes[i]));
228         }
229         label_idx += 2;  // Skip the rest of the ellipsis.
230         seen_ellipsis = true;
231         continue;
232       }
233       // Obtain the dimension that the current label corresponds to.
234       int64_t axis = seen_ellipsis ? axis_after_ellipsis : axis_before_ellipsis;
235       DimensionHandle new_dim = c->RankKnown(input_shape)
236                                     ? c->Dim(input_shape, axis)
237                                     : c->UnknownDim();
238       // If we've seen this label before, make sure previous and current
239       // dimensions are compatible.
240       if (label_to_dimension.contains(label)) {
241         DimensionHandle merged;
242         TF_RETURN_IF_ERROR(
243             c->Merge(label_to_dimension[label], new_dim, &merged));
244         label_to_dimension[label] = merged;
245       } else {
246         label_to_dimension[label] = new_dim;
247       }
248     }
249   }
250 
251   // For two inputs, broadcast the two input broadcast shapes to create the
252   // output broadcast shape. For one input, just copy the single broadcast
253   // shape.
254   ShapeHandle output_bcast_shape;
255   if (input_bcast_shapes.size() == 1) {
256     output_bcast_shape = input_bcast_shapes[0];
257   } else if (input_bcast_shapes.size() == 2) {
258     TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
259         c, input_bcast_shapes[0], input_bcast_shapes[1], true,
260         &output_bcast_shape));
261   }
262 
263   bool output_has_ellipsis = false;
264   TF_RETURN_IF_ERROR(
265       ValidateEinsumEllipsis(output_labels, &output_has_ellipsis));
266   if (output_has_ellipsis) {
267     // If the output subscript has ellipsis and the output broadcast rank is
268     // unknown, then the output shape should have unknown rank.
269     if (!c->RankKnown(output_bcast_shape)) {
270       c->set_output(0, c->UnknownShape());
271       return Status::OK();
272     }
273   } else {
274     // If the output subscripts don't have ellipsis then make sure the output
275     // broadcasting shape is empty.
276     TF_RETURN_WITH_CONTEXT_IF_ERROR(
277         c->WithRankAtMost(output_bcast_shape, 0, &output_bcast_shape),
278         " for einsum equation '", equation,
279         "' without ellipsis (...) in the output subscripts where input(s) have "
280         "non-empty broadcasting shape");
281     output_bcast_shape = c->Scalar();
282   }
283 
284   // Create the output shape from output labels and label_to_dimension mapping.
285   std::vector<DimensionHandle> output_dims;
286   for (int label_idx = 0, end = output_labels.size(); label_idx < end;
287        ++label_idx) {
288     const char label = output_labels[label_idx];
289     // Append the output_bcast_shape when the ellipsis is encountered.
290     if (label == '.') {
291       for (int k = 0; k < c->Rank(output_bcast_shape); ++k) {
292         output_dims.push_back(c->Dim(output_bcast_shape, k));
293       }
294       label_idx += 2;  // Skip the rest of the ellipsis.
295       continue;
296     }
297     auto dimension_it = label_to_dimension.find(label);
298     if (dimension_it == label_to_dimension.end()) {
299       return errors::InvalidArgument(
300           "Einsum output subscripts for equation '", equation, "' has label '",
301           label, "' which is not present in the input subscripts");
302     }
303     output_dims.push_back(dimension_it->second);
304   }
305   c->set_output(0, c->MakeShape(output_dims));
306   return Status::OK();
307 }
308 
BatchMatMulV2Shape(shape_inference::InferenceContext * c)309 Status BatchMatMulV2Shape(shape_inference::InferenceContext* c) {
310   ShapeHandle a_shape;
311   ShapeHandle b_shape;
312   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &a_shape));
313   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape));
314 
315   // Determine output rows and columns.
316   bool adj_x;
317   bool adj_y;
318   TF_RETURN_IF_ERROR(c->GetAttr("adj_x", &adj_x));
319   TF_RETURN_IF_ERROR(c->GetAttr("adj_y", &adj_y));
320   DimensionHandle output_rows = c->Dim(a_shape, adj_x ? -1 : -2);
321   DimensionHandle output_cols = c->Dim(b_shape, adj_y ? -2 : -1);
322 
323   // Inner dimensions should be compatible.
324   DimensionHandle inner_merged;
325   TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1),
326                               c->Dim(b_shape, adj_y ? -1 : -2), &inner_merged));
327 
328   // Batch dimensions should broadcast with each other.
329   ShapeHandle a_batch_shape;
330   ShapeHandle b_batch_shape;
331   ShapeHandle output_batch_shape;
332   TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_shape));
333   TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_shape));
334 
335   TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
336       c, a_batch_shape, b_batch_shape, true, &output_batch_shape));
337 
338   ShapeHandle output_shape;
339   TF_RETURN_IF_ERROR(c->Concatenate(
340       output_batch_shape, c->Matrix(output_rows, output_cols), &output_shape));
341 
342   c->set_output(0, output_shape);
343   return Status::OK();
344 }
345 
BatchMatMulShape(shape_inference::InferenceContext * c)346 Status BatchMatMulShape(shape_inference::InferenceContext* c) {
347   ShapeHandle a_shape;
348   ShapeHandle b_shape;
349   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &a_shape));
350   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape));
351 
352   // Determine output rows and cols.
353   bool adj_x;
354   bool adj_y;
355   TF_RETURN_IF_ERROR(c->GetAttr("adj_x", &adj_x));
356   TF_RETURN_IF_ERROR(c->GetAttr("adj_y", &adj_y));
357   DimensionHandle output_rows = c->Dim(a_shape, adj_x ? -1 : -2);
358   DimensionHandle output_cols = c->Dim(b_shape, adj_y ? -2 : -1);
359 
360   // Batch dims match between inputs.
361   ShapeHandle a_batch_dims;
362   ShapeHandle b_batch_dims;
363   ShapeHandle batch_dims;
364   TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims));
365   TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims));
366   TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims));
367 
368   // Assert inner dims match.
369   DimensionHandle unused;
370   TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1),
371                               c->Dim(b_shape, adj_y ? -1 : -2), &unused));
372 
373   ShapeHandle out;
374   TF_RETURN_IF_ERROR(
375       c->Concatenate(batch_dims, c->Matrix(output_rows, output_cols), &out));
376   c->set_output(0, out);
377   return Status::OK();
378 }
379 
380 // --------------------------------------------------------------------------
381 
BiasAddShape(shape_inference::InferenceContext * c)382 Status BiasAddShape(shape_inference::InferenceContext* c) {
383   ShapeHandle input_shape;
384 
385   // Fetch the data_format attribute, which may not exist.
386   string data_format;
387   Status s = c->GetAttr("data_format", &data_format);
388 
389   if (s.ok() && data_format == "NCHW") {
390     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
391   } else {
392     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
393   }
394 
395   ShapeHandle bias_shape;
396   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &bias_shape));
397   DimensionHandle bias_dim = c->Dim(bias_shape, 0);
398 
399   // If rank unknown, return unknown shape.
400   if (!c->RankKnown(input_shape)) {
401     c->set_output(0, c->UnknownShape());
402     return Status::OK();
403   }
404 
405   // Output has the same shape as the input, and matches the length of
406   // the bias in its bias dimension.
407   ShapeHandle output_shape;
408   if (s.ok() && data_format == "NCHW") {
409     // Merge the length of bias_shape into the third to last dimension
410     ShapeHandle first;
411     TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, 1, &first));
412 
413     ShapeHandle last;
414     TF_RETURN_IF_ERROR(c->Subshape(input_shape, 2, &last));
415 
416     DimensionHandle input_bias_dim = c->Dim(input_shape, 1);
417     DimensionHandle merged_bias_dim;
418     TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));
419     ShapeHandle merged_bias = c->Vector(merged_bias_dim);
420 
421     ShapeHandle temp;
422     TF_RETURN_IF_ERROR(c->Concatenate(first, merged_bias, &temp));
423     TF_RETURN_IF_ERROR(c->Concatenate(temp, last, &output_shape));
424   } else {
425     ShapeHandle all_but_bias;
426     TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -1, &all_but_bias));
427 
428     DimensionHandle input_bias_dim = c->Dim(input_shape, -1);
429     DimensionHandle merged_bias_dim;
430     TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));
431 
432     ShapeHandle merged_bias = c->Vector(merged_bias_dim);
433     TF_RETURN_IF_ERROR(
434         c->Concatenate(all_but_bias, merged_bias, &output_shape));
435   }
436 
437   c->set_output(0, output_shape);
438   return Status::OK();
439 }
440 
BiasAddGradShape(shape_inference::InferenceContext * c)441 Status BiasAddGradShape(shape_inference::InferenceContext* c) {
442   ShapeHandle input_shape;
443   // Fetch the data_format attribute, which may not exist.
444   string data_format;
445   Status s = c->GetAttr("data_format", &data_format);
446 
447   if (s.ok() && data_format == "NCHW") {
448     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
449     c->set_output(0, c->Vector(c->Dim(input_shape, 1)));
450   } else {
451     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
452     c->set_output(0, c->Vector(c->Dim(input_shape, -1)));
453   }
454 
455   return Status::OK();
456 }
457 
CheckFormatConstraintsOnShape(const TensorFormat tensor_format,const ShapeHandle shape_handle,const string & tensor_name,shape_inference::InferenceContext * c)458 Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format,
459                                      const ShapeHandle shape_handle,
460                                      const string& tensor_name,
461                                      shape_inference::InferenceContext* c) {
462   if (tensor_format == FORMAT_NCHW_VECT_C) {
463     // Check that the vect dim has size 4 or 32.
464     const int num_dims = c->Rank(shape_handle);
465     DimensionHandle vect_dim = c->Dim(
466         shape_handle, GetTensorInnerFeatureDimIndex(num_dims, tensor_format));
467     int64_t vect_dim_val = c->Value(vect_dim);
468     if (vect_dim_val != 4 && vect_dim_val != 32) {
469       return errors::InvalidArgument(
470           "VECT_C dimension must be 4 or 32, but is ", vect_dim_val);
471     }
472   }
473 
474   return Status::OK();
475 }
476 
DatasetIteratorShape(shape_inference::InferenceContext * c)477 Status DatasetIteratorShape(shape_inference::InferenceContext* c) {
478   shape_inference::ShapeHandle unused;
479   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
480   std::vector<PartialTensorShape> output_shapes;
481   TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
482   const int output_shapes_size = output_shapes.size();
483   if (output_shapes_size != c->num_outputs()) {
484     return errors::InvalidArgument(
485         "`output_shapes` must be the same length as `output_types` (",
486         output_shapes.size(), " vs. ", c->num_outputs());
487   }
488   for (size_t i = 0; i < output_shapes.size(); ++i) {
489     shape_inference::ShapeHandle output_shape_handle;
490     TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
491         output_shapes[i], &output_shape_handle));
492     c->set_output(static_cast<int>(i), output_shape_handle);
493   }
494   return Status::OK();
495 }
496 
MakeShapeFromFormat(TensorFormat format,DimensionOrConstant N,const std::vector<DimensionOrConstant> & spatial,DimensionOrConstant C,ShapeHandle * out,shape_inference::InferenceContext * context)497 Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
498                            const std::vector<DimensionOrConstant>& spatial,
499                            DimensionOrConstant C, ShapeHandle* out,
500                            shape_inference::InferenceContext* context) {
501   const int num_dims = GetTensorDimsFromSpatialDims(spatial.size(), format);
502   std::vector<DimensionHandle> dims_actual(num_dims);
503   dims_actual[GetTensorBatchDimIndex(num_dims, format)] = context->MakeDim(N);
504   int outer_c_index = GetTensorFeatureDimIndex(num_dims, format);
505   dims_actual[outer_c_index] = context->MakeDim(C);
506   if (format == FORMAT_NCHW_VECT_C) {
507     dims_actual[GetTensorInnerFeatureDimIndex(num_dims, format)] =
508         context->MakeDim(4);
509   } else if (format == FORMAT_NHWC_VECT_W) {
510     dims_actual[GetTensorInnerWidthDimIndex(num_dims, format)] =
511         context->MakeDim(4);
512   }
513   for (int spatial_dim = 0, end = spatial.size(); spatial_dim < end;
514        spatial_dim++) {
515     dims_actual[GetTensorSpatialDimIndex(num_dims, format, spatial_dim)] =
516         context->MakeDim(spatial[spatial_dim]);
517   }
518   *out = context->MakeShape(dims_actual);
519   return Status::OK();
520 }
521 
DimensionsFromShape(ShapeHandle shape,TensorFormat format,DimensionHandle * batch_dim,gtl::MutableArraySlice<DimensionHandle> spatial_dims,DimensionHandle * filter_dim,InferenceContext * context)522 Status DimensionsFromShape(ShapeHandle shape, TensorFormat format,
523                            DimensionHandle* batch_dim,
524                            gtl::MutableArraySlice<DimensionHandle> spatial_dims,
525                            DimensionHandle* filter_dim,
526                            InferenceContext* context) {
527   const int32_t rank =
528       GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
529   // Batch.
530   *batch_dim = context->Dim(shape, GetTensorBatchDimIndex(rank, format));
531   // Spatial.
532   for (int spatial_dim_index = 0, end = spatial_dims.size();
533        spatial_dim_index < end; ++spatial_dim_index) {
534     spatial_dims[spatial_dim_index] = context->Dim(
535         shape, GetTensorSpatialDimIndex(rank, format, spatial_dim_index));
536   }
537   // Channel.
538   *filter_dim = context->Dim(shape, GetTensorFeatureDimIndex(rank, format));
539   if (format == FORMAT_NCHW_VECT_C) {
540     TF_RETURN_IF_ERROR(context->Multiply(
541         *filter_dim,
542         context->Dim(shape, GetTensorInnerFeatureDimIndex(rank, format)),
543         filter_dim));
544   }
545   return Status::OK();
546 }
547 
548 // vect_size must be provided if format is NCHW_VECT_C.
ShapeFromDimensions(DimensionHandle batch_dim,gtl::ArraySlice<DimensionHandle> spatial_dims,DimensionHandle filter_dim,TensorFormat format,absl::optional<DimensionHandle> vect_size,InferenceContext * context,ShapeHandle * shape)549 Status ShapeFromDimensions(DimensionHandle batch_dim,
550                            gtl::ArraySlice<DimensionHandle> spatial_dims,
551                            DimensionHandle filter_dim, TensorFormat format,
552                            absl::optional<DimensionHandle> vect_size,
553                            InferenceContext* context, ShapeHandle* shape) {
554   const int32_t rank =
555       GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
556   std::vector<DimensionHandle> out_dims(rank);
557 
558   // Batch.
559   out_dims[tensorflow::GetTensorBatchDimIndex(rank, format)] = batch_dim;
560   // Spatial.
561   for (int spatial_dim_index = 0, end = spatial_dims.size();
562        spatial_dim_index < end; ++spatial_dim_index) {
563     out_dims[tensorflow::GetTensorSpatialDimIndex(
564         rank, format, spatial_dim_index)] = spatial_dims[spatial_dim_index];
565   }
566   // Channel.
567   if (format == tensorflow::FORMAT_NCHW_VECT_C) {
568     // When format is NCHW_VECT_C, factor the feature map count into the outer
569     // feature count and the inner feature count (4 or 32).
570     CHECK(vect_size.has_value());  // Crash ok.
571     TF_RETURN_IF_ERROR(context->Divide(
572         filter_dim, *vect_size, /*evenly_divisible=*/true,
573         &out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)]));
574     out_dims[GetTensorInnerFeatureDimIndex(rank, format)] = *vect_size;
575   } else {
576     out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)] = filter_dim;
577   }
578 
579   *shape = context->MakeShape(out_dims);
580   return tensorflow::Status::OK();
581 }
582 
583 namespace {
584 
Conv2DShapeImpl(shape_inference::InferenceContext * c,bool supports_explicit_padding)585 Status Conv2DShapeImpl(shape_inference::InferenceContext* c,
586                        bool supports_explicit_padding) {
587   string data_format_str, filter_format_str;
588   if (!c->GetAttr("data_format", &data_format_str).ok()) {
589     data_format_str = "NHWC";
590   }
591   if (!c->GetAttr("filter_format", &filter_format_str).ok()) {
592     filter_format_str = "HWIO";
593   }
594 
595   TensorFormat data_format;
596   if (!FormatFromString(data_format_str, &data_format)) {
597     return errors::InvalidArgument("Invalid data format string: ",
598                                    data_format_str);
599   }
600   FilterTensorFormat filter_format;
601   if (!FilterFormatFromString(filter_format_str, &filter_format)) {
602     return errors::InvalidArgument("Invalid filter format string: ",
603                                    filter_format_str);
604   }
605 
606   constexpr int num_spatial_dims = 2;
607   const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
608   ShapeHandle conv_input_shape;
609   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &conv_input_shape));
610   TF_RETURN_IF_ERROR(CheckFormatConstraintsOnShape(
611       data_format, conv_input_shape, "conv_input", c));
612 
613   // The filter rank should match the input (4 for NCHW, 5 for NCHW_VECT_C).
614   ShapeHandle filter_shape;
615   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape));
616   TF_RETURN_IF_ERROR(
617       CheckFormatConstraintsOnShape(data_format, filter_shape, "filter", c));
618 
619   std::vector<int32> dilations;
620   TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations));
621 
622   if (dilations.size() != 4) {
623     return errors::InvalidArgument(
624         "Conv2D requires the dilation attribute to contain 4 values, but got: ",
625         dilations.size());
626   }
627 
628   std::vector<int32> strides;
629   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
630 
631   // strides.size() should be 4 (NCHW) even if the input is 5 (NCHW_VECT_C).
632   if (strides.size() != 4) {
633     return errors::InvalidArgument("Conv2D on data format ", data_format_str,
634                                    " requires the stride attribute to contain"
635                                    " 4 values, but got: ",
636                                    strides.size());
637   }
638 
639   const int32_t stride_rows = GetTensorDim(strides, data_format, 'H');
640   const int32_t stride_cols = GetTensorDim(strides, data_format, 'W');
641   const int32_t dilation_rows = GetTensorDim(dilations, data_format, 'H');
642   const int32_t dilation_cols = GetTensorDim(dilations, data_format, 'W');
643 
644   DimensionHandle batch_size_dim;
645   DimensionHandle input_depth_dim;
646   gtl::InlinedVector<DimensionHandle, 2> input_spatial_dims(2);
647   TF_RETURN_IF_ERROR(DimensionsFromShape(
648       conv_input_shape, data_format, &batch_size_dim,
649       absl::MakeSpan(input_spatial_dims), &input_depth_dim, c));
650 
651   DimensionHandle output_depth_dim = c->Dim(
652       filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
653   DimensionHandle filter_rows_dim = c->Dim(
654       filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'H'));
655   DimensionHandle filter_cols_dim = c->Dim(
656       filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'W'));
657   DimensionHandle filter_input_depth_dim;
658   if (filter_format == FORMAT_OIHW_VECT_I) {
659     TF_RETURN_IF_ERROR(c->Multiply(
660         c->Dim(filter_shape,
661                GetFilterDimIndex<num_spatial_dims>(filter_format, 'I')),
662         c->Dim(filter_shape,
663                GetFilterTensorInnerInputChannelsDimIndex(rank, filter_format)),
664         &filter_input_depth_dim));
665   } else {
666     filter_input_depth_dim = c->Dim(
667         filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'I'));
668   }
669 
670   // Check that the input tensor and the filter tensor agree on the channel
671   // count.
672   if (c->ValueKnown(input_depth_dim) && c->ValueKnown(filter_input_depth_dim)) {
673     int64_t input_depth_value = c->Value(input_depth_dim),
674             filter_input_depth_value = c->Value(filter_input_depth_dim);
675     if (filter_input_depth_value == 0)
676       return errors::InvalidArgument("Depth of filter must not be 0");
677     if (input_depth_value % filter_input_depth_value != 0)
678       return errors::InvalidArgument(
679           "Depth of input (", input_depth_value,
680           ") is not a multiple of input depth of filter (",
681           filter_input_depth_value, ")");
682     if (input_depth_value != filter_input_depth_value) {
683       int64_t num_groups = input_depth_value / filter_input_depth_value;
684       if (c->ValueKnown(output_depth_dim)) {
685         int64_t output_depth_value = c->Value(output_depth_dim);
686         if (num_groups == 0)
687           return errors::InvalidArgument("Number of groups must not be 0");
688         if (output_depth_value % num_groups != 0)
689           return errors::InvalidArgument(
690               "Depth of output (", output_depth_value,
691               ") is not a multiple of the number of groups (", num_groups, ")");
692       }
693     }
694   }
695 
696   Padding padding;
697   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
698 
699   std::vector<int64> explicit_paddings;
700   if (supports_explicit_padding) {
701     Status s = c->GetAttr("explicit_paddings", &explicit_paddings);
702     // Use the default value, which is an empty list, if the attribute is not
703     // found. Otherwise return the error to the caller.
704     if (!s.ok() && !errors::IsNotFound(s)) {
705       return s;
706     }
707     TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
708                                          /*num_dims=*/4, data_format));
709   } else {
710     CHECK(padding != Padding::EXPLICIT);  // Crash ok.
711   }
712 
713   DimensionHandle output_rows, output_cols;
714   int64_t pad_rows_before = -1, pad_rows_after = -1;
715   int64_t pad_cols_before = -1, pad_cols_after = -1;
716   if (padding == Padding::EXPLICIT) {
717     GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
718                              &pad_rows_before, &pad_rows_after);
719     GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
720                              &pad_cols_before, &pad_cols_after);
721   }
722   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
723       c, input_spatial_dims[0], filter_rows_dim, dilation_rows, stride_rows,
724       padding, pad_rows_before, pad_rows_after, &output_rows));
725   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
726       c, input_spatial_dims[1], filter_cols_dim, dilation_cols, stride_cols,
727       padding, pad_cols_before, pad_cols_after, &output_cols));
728 
729   absl::optional<DimensionHandle> vect_size;
730   if (data_format == FORMAT_NCHW_VECT_C) {
731     vect_size.emplace(c->Dim(conv_input_shape,
732                              GetTensorInnerFeatureDimIndex(rank, data_format)));
733   }
734   ShapeHandle output_shape;
735   TF_RETURN_IF_ERROR(ShapeFromDimensions(
736       batch_size_dim, {output_rows, output_cols}, output_depth_dim, data_format,
737       vect_size, c, &output_shape));
738   c->set_output(0, output_shape);
739   return Status::OK();
740 }
741 
742 }  // namespace
743 
744 // Shape function for Conv2D-like operations that support explicit padding.
Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext * c)745 Status Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext* c) {
746   return Conv2DShapeImpl(c, true);
747 }
748 
749 // Shape function for Conv2D-like operations that do not support explicit
750 // padding.
Conv2DShape(shape_inference::InferenceContext * c)751 Status Conv2DShape(shape_inference::InferenceContext* c) {
752   return Conv2DShapeImpl(c, false);
753 }
754 
755 // TODO(mjanusz): Unify all conv/pooling shape functions.
Conv3DShape(shape_inference::InferenceContext * c)756 Status Conv3DShape(shape_inference::InferenceContext* c) {
757   ShapeHandle input_shape;
758   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
759   ShapeHandle filter_shape;
760   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &filter_shape));
761 
762   string data_format;
763   Status s = c->GetAttr("data_format", &data_format);
764 
765   std::vector<int32> dilations;
766   TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations));
767 
768   if (dilations.size() != 5) {
769     return errors::InvalidArgument(
770         "Conv3D requires the dilation attribute to contain 5 values, but got: ",
771         dilations.size());
772   }
773 
774   std::vector<int32> strides;
775   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
776   if (strides.size() != 5) {
777     return errors::InvalidArgument(
778         "Conv3D requires the stride attribute to contain 5 values, but got: ",
779         strides.size());
780   }
781 
782   int32_t stride_planes, stride_rows, stride_cols;
783   int32_t dilation_planes, dilation_rows, dilation_cols;
784   if (s.ok() && data_format == "NCDHW") {
785     // Convert input_shape to NDHWC.
786     auto dim = [&](char dimension) {
787       return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
788     };
789     input_shape =
790         c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
791     stride_planes = strides[2];
792     stride_rows = strides[3];
793     stride_cols = strides[4];
794     dilation_planes = dilations[2];
795     dilation_cols = dilations[3];
796     dilation_rows = dilations[4];
797   } else {
798     stride_planes = strides[1];
799     stride_rows = strides[2];
800     stride_cols = strides[3];
801     dilation_planes = dilations[1];
802     dilation_cols = dilations[2];
803     dilation_rows = dilations[3];
804   }
805 
806   DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
807   DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
808   DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
809   DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
810   DimensionHandle input_depth_dim = c->Dim(input_shape, 4);
811 
812   DimensionHandle filter_planes_dim = c->Dim(filter_shape, 0);
813   DimensionHandle filter_rows_dim = c->Dim(filter_shape, 1);
814   DimensionHandle filter_cols_dim = c->Dim(filter_shape, 2);
815   DimensionHandle filter_input_depth_dim = c->Dim(filter_shape, 3);
816   DimensionHandle output_depth_dim = c->Dim(filter_shape, 4);
817 
818   // Check that the input tensor and the filter tensor agree on the channel
819   // count.
820   if (c->ValueKnown(input_depth_dim) && c->ValueKnown(filter_input_depth_dim)) {
821     int64_t input_depth_value = c->Value(input_depth_dim),
822             filter_input_depth_value = c->Value(filter_input_depth_dim);
823     if (filter_input_depth_value == 0)
824       return errors::InvalidArgument("Depth of filter must not be 0");
825     if (input_depth_value % filter_input_depth_value != 0)
826       return errors::InvalidArgument(
827           "Depth of input (", input_depth_value,
828           ") is not a multiple of input depth of filter (",
829           filter_input_depth_value, ")");
830     if (input_depth_value != filter_input_depth_value) {
831       int64_t num_groups = input_depth_value / filter_input_depth_value;
832       if (c->ValueKnown(output_depth_dim)) {
833         int64_t output_depth_value = c->Value(output_depth_dim);
834         if (num_groups == 0)
835           return errors::InvalidArgument("Number of groups must not be 0");
836         if (output_depth_value % num_groups != 0)
837           return errors::InvalidArgument(
838               "Depth of output (", output_depth_value,
839               ") is not a multiple of the number of groups (", num_groups, ")");
840       }
841     }
842   }
843 
844   Padding padding;
845   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
846   DimensionHandle output_planes, output_rows, output_cols;
847 
848   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
849       c, in_planes_dim, filter_planes_dim, dilation_planes, stride_planes,
850       padding, -1, -1, &output_planes));
851   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
852       c, in_rows_dim, filter_rows_dim, dilation_rows, stride_rows, padding, -1,
853       -1, &output_rows));
854   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
855       c, in_cols_dim, filter_cols_dim, dilation_cols, stride_cols, padding, -1,
856       -1, &output_cols));
857 
858   ShapeHandle output_shape;
859   if (data_format == "NCDHW") {
860     output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
861                                  output_planes, output_rows, output_cols});
862   } else {
863     output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
864                                  output_cols, output_depth_dim});
865   }
866   c->set_output(0, output_shape);
867   return Status::OK();
868 }
869 
Conv2DBackpropInputShape(shape_inference::InferenceContext * c)870 Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c) {
871   string data_format_str;
872   if (!c->GetAttr("data_format", &data_format_str).ok()) {
873     data_format_str = "NHWC";
874   }
875   TensorFormat data_format;
876   if (!FormatFromString(data_format_str, &data_format)) {
877     return errors::InvalidArgument("Invalid data format string: ",
878                                    data_format_str);
879   }
880 
881   // For the rest of this function, output_grad_* describes out_backprop and
882   // input_grad_* describes in_backprop.
883   ShapeHandle output_grad_shape = c->input(2);
884   TF_RETURN_IF_ERROR(c->WithRank(output_grad_shape, 4, &output_grad_shape));
885   ShapeHandle filter_shape = c->input(1);
886   TF_RETURN_IF_ERROR(c->WithRank(filter_shape, 4, &filter_shape));
887 
888   DimensionHandle batch_size_dim;
889   DimensionHandle output_grad_depth_dim;
890   gtl::InlinedVector<DimensionHandle, 2> output_grad_spatial_dims(2);
891   TF_RETURN_IF_ERROR(DimensionsFromShape(
892       output_grad_shape, data_format, &batch_size_dim,
893       absl::MakeSpan(output_grad_spatial_dims), &output_grad_depth_dim, c));
894   DimensionHandle unused;
895   TF_RETURN_IF_ERROR(
896       c->Merge(output_grad_depth_dim, c->Dim(filter_shape, 3), &unused));
897 
898   ShapeHandle specified_input_grad_shape;
899   TF_RETURN_IF_ERROR(
900       c->MakeShapeFromShapeTensor(0, &specified_input_grad_shape));
901   if (c->Rank(specified_input_grad_shape) == InferenceContext::kUnknownRank) {
902     TF_RETURN_IF_ERROR(c->WithRank(specified_input_grad_shape, 4,
903                                    &specified_input_grad_shape));
904   }
905 
906   // input_grad_depth_dim doesn't equal c->Dim(filter_shape,2) when the number
907   // of groups is larger than 1. If input_sizes is a 4D shape, we collect
908   // input_grad_depth_dim from input_sizes; otherwise we compute it as
909   // c->Dim(filter_shape,2).
910   DimensionHandle input_grad_depth_dim;
911   gtl::InlinedVector<DimensionHandle, 2> specified_input_grad_spatial_dims(2);
912   int specified_input_grad_rank = c->Rank(specified_input_grad_shape);
913   if (specified_input_grad_rank == 4) {
914     DimensionHandle specified_batch_size_dim;
915     TF_RETURN_IF_ERROR(DimensionsFromShape(
916         specified_input_grad_shape, data_format, &specified_batch_size_dim,
917         absl::MakeSpan(specified_input_grad_spatial_dims),
918         &input_grad_depth_dim, c));
919     TF_RETURN_IF_ERROR(
920         c->Merge(specified_batch_size_dim, batch_size_dim, &unused));
921   } else if (specified_input_grad_rank == 2) {
922     specified_input_grad_spatial_dims[0] =
923         c->Dim(specified_input_grad_shape, 0);
924     specified_input_grad_spatial_dims[1] =
925         c->Dim(specified_input_grad_shape, 1);
926     input_grad_depth_dim = c->Dim(filter_shape, 2);
927   } else {
928     return errors::InvalidArgument(
929         "Conv2DBackpropInput requires input_sizes to contain 4 values or 2 "
930         "values, but got: ",
931         specified_input_grad_rank);
932   }
933 
934   ShapeHandle input_grad_shape;
935   TF_RETURN_IF_ERROR(ShapeFromDimensions(
936       batch_size_dim, specified_input_grad_spatial_dims, input_grad_depth_dim,
937       data_format, /*vect_size=*/absl::nullopt, c, &input_grad_shape));
938   c->set_output(0, input_grad_shape);
939   return Status::OK();
940 }
941 
Conv2DBackpropFilterWithBiasShape(shape_inference::InferenceContext * c)942 Status Conv2DBackpropFilterWithBiasShape(shape_inference::InferenceContext* c) {
943   ShapeHandle input_shape;
944   // Fetch the data_format attribute, which may not exist.
945   string data_format;
946   Status s = c->GetAttr("data_format", &data_format);
947 
948   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
949   if (s.ok() && data_format == "NCHW") {
950     c->set_output(1, c->Vector(c->Dim(input_shape, -3)));
951   } else {
952     c->set_output(1, c->Vector(c->Dim(input_shape, -1)));
953   }
954   ShapeHandle sh;
955   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &sh));
956   TF_RETURN_IF_ERROR(c->WithRank(sh, 4, &sh));
957   c->set_output(0, sh);
958   return Status::OK();
959 }
960 
961 namespace {
962 
DepthwiseConv2DNativeShapeImpl(shape_inference::InferenceContext * c,bool supports_explicit_padding)963 Status DepthwiseConv2DNativeShapeImpl(shape_inference::InferenceContext* c,
964                                       bool supports_explicit_padding) {
965   ShapeHandle input_shape;
966   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
967   ShapeHandle filter_shape;
968   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape));
969 
970   std::vector<int32> strides;
971   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
972 
973   if (strides.size() != 4) {
974     return errors::InvalidArgument(
975         "DepthwiseConv2D requires the stride attribute to contain 4 values, "
976         "but got: ",
977         strides.size());
978   }
979 
980   std::vector<int32> dilations;
981   if (!c->GetAttr("dilations", &dilations).ok()) {
982     dilations.resize(4, 1);
983   }
984 
985   if (dilations.size() != 4) {
986     return errors::InvalidArgument(
987         "DepthwiseConv2D requires the dilations attribute to contain 4 values, "
988         "but got: ",
989         dilations.size());
990   }
991 
992   string data_format_str;
993   Status s = c->GetAttr("data_format", &data_format_str);
994   TensorFormat data_format;
995   if (!s.ok() || !FormatFromString(data_format_str, &data_format)) {
996     data_format = FORMAT_NHWC;
997   }
998   int32_t stride_rows;
999   int32_t stride_cols;
1000   int32_t dilation_rows;
1001   int32_t dilation_cols;
1002   if (data_format == FORMAT_NCHW) {
1003     // Canonicalize input shape to NHWC so the shape inference code below can
1004     // process it.
1005     input_shape =
1006         c->MakeShape({{c->Dim(input_shape, 0), c->Dim(input_shape, 2),
1007                        c->Dim(input_shape, 3), c->Dim(input_shape, 1)}});
1008     stride_rows = strides[2];
1009     stride_cols = strides[3];
1010     dilation_rows = dilations[2];
1011     dilation_cols = dilations[3];
1012   } else {
1013     stride_rows = strides[1];
1014     stride_cols = strides[2];
1015     dilation_rows = dilations[1];
1016     dilation_cols = dilations[2];
1017   }
1018 
1019   DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
1020   DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
1021   DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
1022 
1023   DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
1024   DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
1025   DimensionHandle input_depth = c->Dim(filter_shape, 2);
1026   DimensionHandle depth_multiplier = c->Dim(filter_shape, 3);
1027 
1028   // Check that the input depths are compatible.
1029   TF_RETURN_IF_ERROR(
1030       c->Merge(c->Dim(input_shape, 3), input_depth, &input_depth));
1031 
1032   DimensionHandle output_depth;
1033   TF_RETURN_IF_ERROR(c->Multiply(input_depth, depth_multiplier, &output_depth));
1034 
1035   Padding padding;
1036   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
1037 
1038   std::vector<int64> explicit_paddings;
1039   if (supports_explicit_padding) {
1040     Status status = c->GetAttr("explicit_paddings", &explicit_paddings);
1041     // Use the default value, which is an empty list, if the attribute is not
1042     // found. Otherwise return the error to the caller.
1043     if (!status.ok() && !errors::IsNotFound(status)) {
1044       return status;
1045     }
1046     TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
1047                                          /*num_dims=*/4, data_format));
1048   } else {
1049     DCHECK(padding != Padding::EXPLICIT);
1050   }
1051 
1052   // TODO(mrry,shlens): Raise an error if the stride would cause
1053   // information in the input to be ignored. This will require a change
1054   // in the kernel implementation.
1055   DimensionHandle output_rows, output_cols;
1056   int64_t pad_rows_before = -1, pad_rows_after = -1;
1057   int64_t pad_cols_before = -1, pad_cols_after = -1;
1058   if (padding == Padding::EXPLICIT) {
1059     GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
1060                              &pad_rows_before, &pad_rows_after);
1061     GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
1062                              &pad_cols_before, &pad_cols_after);
1063   }
1064   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
1065       c, in_rows_dim, filter_rows_dim, dilation_rows, stride_rows, padding,
1066       pad_rows_before, pad_rows_after, &output_rows));
1067   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
1068       c, in_cols_dim, filter_cols_dim, dilation_cols, stride_cols, padding,
1069       pad_cols_before, pad_cols_after, &output_cols));
1070 
1071   ShapeHandle output_shape;
1072   if (data_format == FORMAT_NCHW) {
1073     output_shape =
1074         c->MakeShape({batch_size_dim, output_depth, output_rows, output_cols});
1075   } else {
1076     output_shape =
1077         c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth});
1078   }
1079   c->set_output(0, output_shape);
1080   return Status::OK();
1081 }
1082 
1083 };  // namespace
1084 
DepthwiseConv2DNativeShape(shape_inference::InferenceContext * c)1085 Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
1086   return DepthwiseConv2DNativeShapeImpl(c, false);
1087 }
1088 
DepthwiseConv2DNativeShapeWithExplicitPadding(shape_inference::InferenceContext * c)1089 Status DepthwiseConv2DNativeShapeWithExplicitPadding(
1090     shape_inference::InferenceContext* c) {
1091   return DepthwiseConv2DNativeShapeImpl(c, true);
1092 }
1093 
AvgPoolShape(shape_inference::InferenceContext * c)1094 Status AvgPoolShape(shape_inference::InferenceContext* c) {
1095   string data_format_str;
1096   TensorFormat data_format;
1097   Status s = c->GetAttr("data_format", &data_format_str);
1098   if (s.ok()) {
1099     FormatFromString(data_format_str, &data_format);
1100   } else {
1101     data_format = FORMAT_NHWC;
1102   }
1103 
1104   const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
1105   ShapeHandle input_shape;
1106   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
1107 
1108   TF_RETURN_IF_ERROR(
1109       CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));
1110 
1111   std::vector<int32> strides;
1112   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
1113   if (strides.size() != 4) {
1114     return errors::InvalidArgument(
1115         "AvgPool requires the stride attribute to contain 4 values, but got: ",
1116         strides.size());
1117   }
1118 
1119   std::vector<int32> kernel_sizes;
1120   TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
1121   if (kernel_sizes.size() != 4) {
1122     return errors::InvalidArgument(
1123         "AvgPool requires the ksize attribute to contain 4 values, but got: ",
1124         kernel_sizes.size());
1125   }
1126 
1127   int32_t stride_rows = GetTensorDim(strides, data_format, 'H');
1128   int32_t stride_cols = GetTensorDim(strides, data_format, 'W');
1129   int32_t kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
1130   int32_t kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');
1131 
1132   constexpr int num_spatial_dims = 2;
1133   DimensionHandle batch_size_dim = c->Dim(
1134       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
1135   DimensionHandle in_rows_dim = c->Dim(
1136       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
1137   DimensionHandle in_cols_dim = c->Dim(
1138       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
1139   DimensionHandle depth_dim = c->Dim(
1140       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
1141 
1142   Padding padding;
1143   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
1144 
1145   // TODO(mrry,shlens): Raise an error if the stride would cause
1146   // information in the input to be ignored. This will require a change
1147   // in the kernel implementation.
1148 
1149   DimensionHandle output_rows, output_cols;
1150   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1151       c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
1152   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1153       c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
1154 
1155   ShapeHandle output_shape;
1156   TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
1157                                          {output_rows, output_cols}, depth_dim,
1158                                          &output_shape, c));
1159   c->set_output(0, output_shape);
1160   return Status::OK();
1161 }
1162 
AvgPoolGradShape(shape_inference::InferenceContext * c)1163 Status AvgPoolGradShape(shape_inference::InferenceContext* c) {
1164   ShapeHandle s;
1165   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
1166   TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
1167   c->set_output(0, s);
1168   return Status::OK();
1169 }
1170 
FusedBatchNormShape(shape_inference::InferenceContext * c)1171 Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
1172   string data_format_str;
1173   TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
1174   TensorFormat data_format;
1175   if (!FormatFromString(data_format_str, &data_format)) {
1176     return errors::InvalidArgument("Invalid data format string: ",
1177                                    data_format_str);
1178   }
1179   const int rank =
1180       (data_format_str == "NDHWC" || data_format_str == "NCDHW") ? 5 : 4;
1181   ShapeHandle x;
1182   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &x));
1183 
1184   bool is_training;
1185   TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
1186   float exponential_avg_factor;
1187   if (!c->GetAttr("exponential_avg_factor", &exponential_avg_factor).ok()) {
1188     exponential_avg_factor = 1.0f;  // default value
1189   }
1190   int number_inputs = (is_training && exponential_avg_factor == 1.0f) ? 3 : 5;
1191 
1192   int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
1193   DimensionHandle channel_dim = c->Dim(x, channel_dim_index);
1194 
1195   // covers scale, offset, and if is_training is false, mean, variance
1196   for (int i = 1; i < number_inputs; ++i) {
1197     ShapeHandle vec;
1198     TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
1199     TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
1200   }
1201 
1202   ShapeHandle y;
1203   TF_RETURN_IF_ERROR(c->ReplaceDim(x, channel_dim_index, channel_dim, &y));
1204   c->set_output(0, y);
1205   ShapeHandle vector_shape = c->Vector(channel_dim);
1206   c->set_output(1, vector_shape);
1207   c->set_output(2, vector_shape);
1208   c->set_output(3, vector_shape);
1209   c->set_output(4, vector_shape);
1210   return Status::OK();
1211 }
1212 
FusedBatchNormV3Shape(shape_inference::InferenceContext * c)1213 Status FusedBatchNormV3Shape(shape_inference::InferenceContext* c) {
1214   TF_RETURN_IF_ERROR(FusedBatchNormShape(c));
1215   c->set_output(5, c->UnknownShape());
1216   return Status::OK();
1217 }
1218 
FusedBatchNormExShape(shape_inference::InferenceContext * c)1219 Status FusedBatchNormExShape(shape_inference::InferenceContext* c) {
1220   TF_RETURN_IF_ERROR(FusedBatchNormV3Shape(c));
1221 
1222   string data_format_str;
1223   TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
1224   TensorFormat data_format;
1225   if (!FormatFromString(data_format_str, &data_format)) {
1226     return errors::InvalidArgument("Invalid data format string: ",
1227                                    data_format_str);
1228   }
1229   ShapeHandle x;
1230   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));
1231 
1232   int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
1233   DimensionHandle channel_dim = c->Dim(x, channel_dim_index);
1234 
1235   // This is a cuDNN implementation constraint.
1236   if (c->ValueKnown(channel_dim) && c->Value(channel_dim) % 4 != 0) {
1237     return errors::InvalidArgument(
1238         "_FusedBatchNormEx channel dimension must be divisible by 4.");
1239   }
1240 
1241   return Status::OK();
1242 }
1243 
FusedBatchNormGradShape(shape_inference::InferenceContext * c)1244 Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
1245   string data_format_str;
1246   TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
1247   TensorFormat data_format;
1248   if (!FormatFromString(data_format_str, &data_format)) {
1249     return errors::InvalidArgument("Invalid data format string: ",
1250                                    data_format_str);
1251   }
1252   const int rank =
1253       (data_format_str == "NDHWC" || data_format_str == "NCDHW") ? 5 : 4;
1254   ShapeHandle y_backprop;
1255   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &y_backprop));
1256   ShapeHandle x;
1257   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &x));
1258 
1259   bool is_training;
1260   TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
1261 
1262   int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
1263   DimensionHandle channel_dim = c->Dim(y_backprop, channel_dim_index);
1264   TF_RETURN_IF_ERROR(
1265       c->Merge(channel_dim, c->Dim(x, channel_dim_index), &channel_dim));
1266 
1267   // covers scale, mean (reserve_space_1), variance (reserve_space_2)
1268   for (int i = 2; i < 5; ++i) {
1269     ShapeHandle vec;
1270     TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
1271     TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
1272   }
1273 
1274   ShapeHandle x_backprop;
1275   TF_RETURN_IF_ERROR(
1276       c->ReplaceDim(y_backprop, channel_dim_index, channel_dim, &x_backprop));
1277   c->set_output(0, x_backprop);
1278   c->set_output(1, c->Vector(channel_dim));
1279   c->set_output(2, c->Vector(channel_dim));
1280   c->set_output(3, c->Vector(0));
1281   c->set_output(4, c->Vector(0));
1282   return Status::OK();
1283 }
1284 
ReadDiagIndex(InferenceContext * c,const Tensor * diag_index_tensor,int32 * lower_diag_index,int32 * upper_diag_index)1285 Status ReadDiagIndex(InferenceContext* c, const Tensor* diag_index_tensor,
1286                      int32* lower_diag_index, int32* upper_diag_index) {
1287   // This function assumes that the shape of diag_index_tensor is fully defined.
1288   if (diag_index_tensor->dims() == 0) {
1289     *lower_diag_index = diag_index_tensor->scalar<int32>()();
1290     *upper_diag_index = *lower_diag_index;
1291   } else {
1292     int32_t num_elements = diag_index_tensor->dim_size(0);
1293     if (num_elements == 1) {
1294       *lower_diag_index = diag_index_tensor->vec<int32>()(0);
1295       *upper_diag_index = *lower_diag_index;
1296     } else if (num_elements == 2) {
1297       *lower_diag_index = diag_index_tensor->vec<int32>()(0);
1298       *upper_diag_index = diag_index_tensor->vec<int32>()(1);
1299     } else {
1300       return errors::InvalidArgument(
1301           "diag_index must be a vector with one or two elements. It has ",
1302           num_elements, " elements.");
1303     }
1304   }
1305   return Status::OK();
1306 }
1307 
MatrixDiagPartV2Shape(shape_inference::InferenceContext * c)1308 Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c) {
1309   ShapeHandle input_shape, diag_index_shape, unused_shape;
1310   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
1311   TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape));
1312   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape));
1313 
1314   const Tensor* diag_index_tensor = c->input_tensor(1);
1315   if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) ||
1316       diag_index_tensor == nullptr) {
1317     c->set_output(0, c->UnknownShape());
1318     return Status::OK();
1319   }
1320   int32_t lower_diag_index = 0;
1321   int32_t upper_diag_index = 0;
1322   TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
1323                                    &upper_diag_index));
1324   if (lower_diag_index > upper_diag_index) {
1325     return errors::InvalidArgument(
1326         "lower_diag_index is greater than upper_diag_index");
1327   }
1328 
1329   // Validates lower_diag_index and upper_diag_index.
1330   const int32_t input_rank = c->Rank(input_shape);
1331   const int32_t num_rows = c->Value(c->Dim(input_shape, input_rank - 2));
1332   const int32_t num_cols = c->Value(c->Dim(input_shape, input_rank - 1));
1333   int32_t max_diag_len = InferenceContext::kUnknownDim;
1334   if (num_rows != InferenceContext::kUnknownDim &&
1335       num_cols != InferenceContext::kUnknownDim) {
1336     if (lower_diag_index != 0 &&  // For when num_rows or num_cols == 0.
1337         (-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) {
1338       return errors::InvalidArgument("lower_diag_index is out of bound.");
1339     }
1340     if (upper_diag_index != 0 &&  // For when num_rows or num_cols == 0.
1341         (-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) {
1342       return errors::InvalidArgument("upper_diag_index is out of bound.");
1343     }
1344     max_diag_len = std::min(num_rows + std::min(upper_diag_index, 0),
1345                             num_cols - std::max(lower_diag_index, 0));
1346   }
1347 
1348   std::vector<DimensionHandle> dims;
1349   dims.reserve(input_rank - 2);
1350   for (int i = 0; i < input_rank - 2; ++i) {
1351     dims.push_back(c->Dim(input_shape, i));
1352   }
1353   if (lower_diag_index < upper_diag_index) {
1354     dims.push_back(c->MakeDim(upper_diag_index - lower_diag_index + 1));
1355   }
1356   dims.push_back(c->MakeDim(max_diag_len));
1357   c->set_output(0, c->MakeShape(dims));
1358   return Status::OK();
1359 }
1360 
MatrixDiagV2Shape(shape_inference::InferenceContext * c)1361 Status MatrixDiagV2Shape(shape_inference::InferenceContext* c) {
1362   // Checks input ranks.
1363   ShapeHandle input_shape, diag_index_shape, unused_shape;
1364   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input_shape));
1365   TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape));
1366   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape));
1367   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
1368   TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
1369 
1370   // Reads the diagonal indices.
1371   const Tensor* diag_index_tensor = c->input_tensor(1);
1372   if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) ||
1373       diag_index_tensor == nullptr) {
1374     c->set_output(0, c->UnknownShape());
1375     return Status::OK();
1376   }
1377   int32_t lower_diag_index = 0;
1378   int32_t upper_diag_index = 0;
1379   TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
1380                                    &upper_diag_index));
1381   if (lower_diag_index > upper_diag_index) {
1382     return errors::InvalidArgument(
1383         "lower_diag_index is greater than upper_diag_index");
1384   }
1385 
1386   // Checks if the number of diagonals provided matches what we imply from
1387   // lower_diag_index and upper_diag_index.
1388   const int32_t input_rank = c->Rank(input_shape);
1389   if (lower_diag_index < upper_diag_index) {
1390     const int32_t num_diags = c->Value(c->Dim(input_shape, input_rank - 2));
1391     const int32_t other_dim = c->Value(c->Dim(input_shape, input_rank - 1));
1392 
1393     if (num_diags != (upper_diag_index - lower_diag_index + 1)) {
1394       return errors::InvalidArgument(
1395           "The number of rows of `diagonal` doesn't match the number of "
1396           "diagonals implied from `d_lower` and `d_upper`.\n",
1397           "num_diags = ", num_diags, ", d_lower = ", lower_diag_index,
1398           ", d_upper = ", upper_diag_index, " ", input_rank, " ", other_dim);
1399     }
1400   }
1401 
1402   // Reads num_rows and num_cols.
1403   const Tensor* num_rows_tensor = c->input_tensor(2);
1404   const Tensor* num_cols_tensor = c->input_tensor(3);
1405   int64_t num_rows = -1;
1406   int64_t num_cols = -1;
1407   if (num_rows_tensor != nullptr) {
1408     TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_rows_tensor, &num_rows));
1409   }
1410   if (num_cols_tensor != nullptr) {
1411     TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_cols_tensor, &num_cols));
1412   }
1413 
1414   // Infers the missing num_rows or num_cols: If both are missing, assume
1415   // output is square. Otherwise, use the smallest possible value. Also
1416   // validates the provided values.
1417   const int32_t max_diag_len = c->Value(c->Dim(input_shape, input_rank - 1));
1418   const int32_t min_num_rows = max_diag_len - std::min(upper_diag_index, 0);
1419   const int32_t min_num_cols = max_diag_len + std::max(lower_diag_index, 0);
1420   if (num_rows == -1 && num_cols == -1) {  // Special case.
1421     num_rows = std::max(min_num_rows, min_num_cols);
1422     num_cols = num_rows;
1423   }
1424   if (num_rows == -1) {
1425     num_rows = min_num_rows;
1426   } else if (num_rows < min_num_rows) {
1427     return errors::InvalidArgument("num_rows is too small");
1428   }
1429   if (num_cols == -1) {
1430     num_cols = min_num_cols;
1431   } else if (num_cols < min_num_cols) {
1432     return errors::InvalidArgument("num_cols is too small.");
1433   }
1434   // At least one of them must match the minimum length.
1435   if (num_rows != min_num_rows && num_cols != min_num_cols) {
1436     return errors::InvalidArgument(
1437         "num_rows and num_cols are not consistent with lower_diag_index, "
1438         "upper_diag_index, and the length of the given diagonals.\n",
1439         "num_rows = ", num_rows, " != min_num_rows = ", min_num_rows,
1440         ", num_cols = ", num_cols, " != min_num_cols = ", min_num_cols);
1441   }
1442 
1443   // Sets output shape.
1444   ShapeHandle output_shape;
1445   const DimensionHandle output_row_dim = c->MakeDim(num_rows);
1446   const DimensionHandle output_col_dim = c->MakeDim(num_cols);
1447   if (lower_diag_index == upper_diag_index) {
1448     TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 1,
1449                                      output_row_dim, &output_shape));
1450     TF_RETURN_IF_ERROR(
1451         c->Concatenate(output_shape, c->Vector(output_col_dim), &output_shape));
1452   } else {
1453     TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 2,
1454                                      output_row_dim, &output_shape));
1455     TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape, input_rank - 1,
1456                                      output_col_dim, &output_shape));
1457   }
1458   c->set_output(0, output_shape);
1459   return Status::OK();
1460 }
1461 
MatrixSetDiagV2Shape(shape_inference::InferenceContext * c)1462 Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c) {
1463   ShapeHandle input_shape, diag_shape, diag_index_shape;
1464   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
1465   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag_shape));
1466   TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &diag_index_shape));
1467 
1468   int32_t lower_diag_index = 0;
1469   int32_t upper_diag_index = 0;
1470   bool diag_index_known = false;
1471   const Tensor* diag_index_tensor = c->input_tensor(2);
1472   if (diag_index_tensor != nullptr && c->FullyDefined(diag_index_shape)) {
1473     diag_index_known = true;
1474     TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
1475                                      &upper_diag_index));
1476     if (lower_diag_index > upper_diag_index) {
1477       return errors::InvalidArgument(
1478           "lower_diag_index is greater than upper_diag_index");
1479     }
1480   }
1481 
1482   // Do more checks when input rank is known.
1483   if (c->RankKnown(input_shape)) {
1484     int32_t input_rank = c->Rank(input_shape);
1485 
1486     // If diag_index is set, we know the exact rank of diagonal.
1487     if (diag_index_known) {
1488       TF_RETURN_IF_ERROR(c->WithRank(
1489           c->input(1),
1490           (lower_diag_index == upper_diag_index) ? input_rank - 1 : input_rank,
1491           &diag_shape));
1492     } else {
1493       TF_RETURN_IF_ERROR(
1494           c->WithRankAtLeast(c->input(1), input_rank - 1, &diag_shape));
1495       TF_RETURN_IF_ERROR(
1496           c->WithRankAtMost(c->input(1), input_rank, &diag_shape));
1497     }
1498 
1499     // Validates lower_diag_index and upper_diag_index.
1500     const int32_t num_rows = c->Value(c->Dim(input_shape, input_rank - 2));
1501     const int32_t num_cols = c->Value(c->Dim(input_shape, input_rank - 1));
1502     if (num_rows != InferenceContext::kUnknownDim &&
1503         num_cols != InferenceContext::kUnknownDim) {
1504       if (lower_diag_index != 0 &&  // For when num_rows or num_cols == 0.
1505           (-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) {
1506         return errors::InvalidArgument("lower_diag_index is out of bound.");
1507       }
1508       if (upper_diag_index != 0 &&  // For when num_rows or num_cols == 0.
1509           (-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) {
1510         return errors::InvalidArgument("upper_diag_index is out of bound.");
1511       }
1512     }
1513   }
1514 
1515   ShapeHandle output_shape = input_shape;
1516   if (c->RankKnown(diag_shape) && !c->FullyDefined(input_shape)) {
1517     // Try to infer parts of shape from diag.
1518     ShapeHandle diag_prefix;
1519     TF_RETURN_IF_ERROR(c->Subshape(
1520         diag_shape, 0, (lower_diag_index == upper_diag_index) ? -1 : -2,
1521         &diag_prefix));
1522 
1523     // The inner matrices can be rectangular, so we can't pinpoint their
1524     // exact height and width by just lower_diag_index, upper_diag_index,
1525     // and the longest length of given diagonals.
1526     TF_RETURN_IF_ERROR(
1527         c->Concatenate(diag_prefix, c->UnknownShapeOfRank(2), &diag_shape));
1528     TF_RETURN_IF_ERROR(c->Merge(input_shape, diag_shape, &output_shape));
1529   }
1530   c->set_output(0, output_shape);
1531   return Status::OK();
1532 }
1533 
MaxPoolShapeImpl(shape_inference::InferenceContext * c,bool supports_explicit_padding)1534 Status MaxPoolShapeImpl(shape_inference::InferenceContext* c,
1535                         bool supports_explicit_padding) {
1536   string data_format_str;
1537   TensorFormat data_format;
1538   Status s = c->GetAttr("data_format", &data_format_str);
1539   if (s.ok()) {
1540     FormatFromString(data_format_str, &data_format);
1541   } else {
1542     data_format = FORMAT_NHWC;
1543   }
1544 
1545   const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
1546   ShapeHandle input_shape;
1547   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
1548 
1549   TF_RETURN_IF_ERROR(
1550       CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));
1551 
1552   std::vector<int32> strides;
1553   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
1554   if (strides.size() != 4) {
1555     return errors::InvalidArgument(
1556         "MaxPool requires the stride attribute to contain 4 values, but got: ",
1557         strides.size());
1558   }
1559 
1560   std::vector<int32> kernel_sizes;
1561   TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
1562   if (kernel_sizes.size() != 4) {
1563     return errors::InvalidArgument(
1564         "MaxPool requires the ksize attribute to contain 4 values, but got: ",
1565         kernel_sizes.size());
1566   }
1567 
1568   int32_t stride_depth = GetTensorDim(strides, data_format, 'C');
1569   int32_t stride_rows = GetTensorDim(strides, data_format, 'H');
1570   int32_t stride_cols = GetTensorDim(strides, data_format, 'W');
1571   int32_t kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
1572   int32_t kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
1573   int32_t kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');
1574 
1575   constexpr int num_spatial_dims = 2;
1576   DimensionHandle batch_size_dim = c->Dim(
1577       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
1578   DimensionHandle in_rows_dim = c->Dim(
1579       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
1580   DimensionHandle in_cols_dim = c->Dim(
1581       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
1582   DimensionHandle in_depth_dim = c->Dim(
1583       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
1584 
1585   Padding padding;
1586   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
1587 
1588   std::vector<int64> explicit_paddings;
1589   if (supports_explicit_padding) {
1590     Status status = c->GetAttr("explicit_paddings", &explicit_paddings);
1591     // Use the default value, which is an empty list, if the attribute is not
1592     // found. Otherwise return the error to the caller.
1593     if (!status.ok() && !errors::IsNotFound(status)) {
1594       return status;
1595     }
1596     TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
1597                                          /*num_dims=*/4, data_format));
1598   } else {
1599     DCHECK(padding != Padding::EXPLICIT);
1600   }
1601 
1602   ShapeHandle output_shape;
1603   DimensionHandle output_rows, output_cols, output_depth;
1604   int64_t pad_rows_before = -1, pad_rows_after = -1;
1605   int64_t pad_cols_before = -1, pad_cols_after = -1;
1606   if (padding == Padding::EXPLICIT) {
1607     GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
1608                              &pad_rows_before, &pad_rows_after);
1609     GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
1610                              &pad_cols_before, &pad_cols_after);
1611   }
1612   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
1613       c, in_rows_dim, kernel_rows, /*dilation_rate=*/1, stride_rows, padding,
1614       pad_rows_before, pad_rows_after, &output_rows));
1615   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
1616       c, in_cols_dim, kernel_cols, /*dilation_rate=*/1, stride_cols, padding,
1617       pad_cols_before, pad_cols_after, &output_cols));
1618   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
1619       c, in_depth_dim, kernel_depth, /*dilation_rate=*/1, stride_depth, padding,
1620       /*pad_before*/ 0, /*pad_after*/ 0, &output_depth));
1621 
1622   TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
1623                                          {output_rows, output_cols},
1624                                          output_depth, &output_shape, c));
1625 
1626   c->set_output(0, output_shape);
1627   return Status::OK();
1628 }
1629 
MaxPoolShape(shape_inference::InferenceContext * c)1630 Status MaxPoolShape(shape_inference::InferenceContext* c) {
1631   return MaxPoolShapeImpl(c, /*supports_explicit_padding=*/false);
1632 }
1633 
MaxPoolGradShape(shape_inference::InferenceContext * c)1634 Status MaxPoolGradShape(shape_inference::InferenceContext* c) {
1635   return UnchangedShapeWithRank(c, 4);
1636 }
1637 
MaxPoolShapeWithExplicitPadding(shape_inference::InferenceContext * c)1638 Status MaxPoolShapeWithExplicitPadding(shape_inference::InferenceContext* c) {
1639   return MaxPoolShapeImpl(c, /*supports_explicit_padding=*/true);
1640 }
1641 
MaxPoolV2Shape(shape_inference::InferenceContext * c,int num_inputs)1642 Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) {
1643   string data_format_str;
1644   TensorFormat data_format;
1645   Status s = c->GetAttr("data_format", &data_format_str);
1646   if (s.ok()) {
1647     FormatFromString(data_format_str, &data_format);
1648   } else {
1649     data_format = FORMAT_NHWC;
1650   }
1651 
1652   const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
1653   ShapeHandle input_shape;
1654   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
1655 
1656   TF_RETURN_IF_ERROR(
1657       CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));
1658 
1659   std::vector<int32> kernel_sizes;
1660   std::vector<int32> strides;
1661 
1662   if (c->num_inputs() + 2 == num_inputs) {
1663     TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
1664 
1665     TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
1666   } else {
1667     // Verify shape of ksize and strides input.
1668     ShapeHandle size;
1669     DimensionHandle unused;
1670     TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 2), 1, &size));
1671     TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));
1672     TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &size));
1673     TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));
1674 
1675     const Tensor* kernel_sizes_tensor = c->input_tensor(c->num_inputs() - 2);
1676     if (kernel_sizes_tensor == nullptr) {
1677       c->set_output(0, c->UnknownShape());
1678       return Status::OK();
1679     }
1680     kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements());
1681     auto kernel_sizes_vec = kernel_sizes_tensor->flat<int32>();
1682     std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(),
1683                 kernel_sizes.begin());
1684 
1685     const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1);
1686     if (strides_tensor == nullptr) {
1687       c->set_output(0, c->UnknownShape());
1688       return Status::OK();
1689     }
1690     strides.resize(strides_tensor->shape().num_elements());
1691     auto strides_vec = strides_tensor->flat<int32>();
1692     std::copy_n(&strides_vec(0), strides.size(), strides.begin());
1693   }
1694 
1695   if (strides.size() != 4) {
1696     return errors::InvalidArgument(
1697         "MaxPool requires the stride attribute to contain 4 values, but "
1698         "got: ",
1699         strides.size());
1700   }
1701   if (kernel_sizes.size() != 4) {
1702     return errors::InvalidArgument(
1703         "MaxPool requires the ksize attribute to contain 4 values, but got: ",
1704         kernel_sizes.size());
1705   }
1706 
1707   int32_t stride_depth = GetTensorDim(strides, data_format, 'C');
1708   int32_t stride_rows = GetTensorDim(strides, data_format, 'H');
1709   int32_t stride_cols = GetTensorDim(strides, data_format, 'W');
1710   int32_t kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
1711   int32_t kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
1712   int32_t kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');
1713 
1714   constexpr int num_spatial_dims = 2;
1715   DimensionHandle batch_size_dim = c->Dim(
1716       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
1717   DimensionHandle in_rows_dim = c->Dim(
1718       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
1719   DimensionHandle in_cols_dim = c->Dim(
1720       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
1721   DimensionHandle in_depth_dim = c->Dim(
1722       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
1723 
1724   Padding padding;
1725   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
1726 
1727   ShapeHandle output_shape;
1728   DimensionHandle output_rows, output_cols, output_depth;
1729   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1730       c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
1731   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1732       c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
1733   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1734       c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth));
1735 
1736   TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
1737                                          {output_rows, output_cols},
1738                                          output_depth, &output_shape, c));
1739 
1740   c->set_output(0, output_shape);
1741   return Status::OK();
1742 }
1743 
Pool3DShape(shape_inference::InferenceContext * c)1744 Status Pool3DShape(shape_inference::InferenceContext* c) {
1745   ShapeHandle input_shape;
1746   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
1747 
1748   string data_format;
1749   Status s = c->GetAttr("data_format", &data_format);
1750 
1751   std::vector<int32> strides;
1752   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
1753   if (strides.size() != 5) {
1754     return errors::InvalidArgument(
1755         "Pool3D ops require the stride attribute to contain 5 values, but "
1756         "got: ",
1757         strides.size());
1758   }
1759 
1760   std::vector<int32> kernel_sizes;
1761   TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
1762   if (kernel_sizes.size() != 5) {
1763     return errors::InvalidArgument(
1764         "Pool3D requires the ksize attribute to contain 5 values, but got: ",
1765         kernel_sizes.size());
1766   }
1767 
1768   int32_t stride_planes, stride_rows, stride_cols;
1769   int32_t kernel_planes, kernel_rows, kernel_cols;
1770 
1771   if (s.ok() && data_format == "NCDHW") {
1772     // Convert input_shape to NDHWC.
1773     auto dim = [&](char dimension) {
1774       return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
1775     };
1776     input_shape =
1777         c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
1778     stride_planes = strides[2];
1779     stride_rows = strides[3];
1780     stride_cols = strides[4];
1781     kernel_planes = kernel_sizes[2];
1782     kernel_rows = kernel_sizes[3];
1783     kernel_cols = kernel_sizes[4];
1784   } else {
1785     stride_planes = strides[1];
1786     stride_rows = strides[2];
1787     stride_cols = strides[3];
1788     kernel_planes = kernel_sizes[1];
1789     kernel_rows = kernel_sizes[2];
1790     kernel_cols = kernel_sizes[3];
1791   }
1792 
1793   DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
1794   DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
1795   DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
1796   DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
1797   DimensionHandle output_depth_dim = c->Dim(input_shape, 4);
1798 
1799   Padding padding;
1800   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
1801 
1802   // TODO(mrry,shlens): Raise an error if the stride would cause
1803   // information in the input to be ignored. This will require a change
1804   // in the kernel implementation.
1805   DimensionHandle output_planes, output_rows, output_cols;
1806   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1807       c, in_planes_dim, kernel_planes, stride_planes, padding, &output_planes));
1808   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1809       c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
1810   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1811       c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
1812 
1813   ShapeHandle output_shape;
1814   if (data_format == "NCDHW") {
1815     output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
1816                                  output_planes, output_rows, output_cols});
1817   } else {
1818     output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
1819                                  output_cols, output_depth_dim});
1820   }
1821 
1822   c->set_output(0, output_shape);
1823   return Status::OK();
1824 }
1825 
MaxPool3DGradShape(shape_inference::InferenceContext * c)1826 Status MaxPool3DGradShape(shape_inference::InferenceContext* c) {
1827   return UnchangedShapeWithRank(c, 5);
1828 }
1829 
AvgPool3DGradShape(shape_inference::InferenceContext * c)1830 Status AvgPool3DGradShape(shape_inference::InferenceContext* c) {
1831   ShapeHandle s;
1832   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
1833   TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
1834   c->set_output(0, s);
1835   return Status::OK();
1836 }
1837 
UnknownShape(shape_inference::InferenceContext * c)1838 Status UnknownShape(shape_inference::InferenceContext* c) {
1839   for (int i = 0; i < c->num_outputs(); ++i) {
1840     c->set_output(i, c->UnknownShape());
1841   }
1842   return Status::OK();
1843 }
1844 
1845 template <typename T>
ReductionShapeHelper(const Tensor * reduction_indices_t,const int32_t input_rank,std::set<int64> * true_indices)1846 Status ReductionShapeHelper(const Tensor* reduction_indices_t,
1847                             const int32_t input_rank,
1848                             std::set<int64>* true_indices) {
1849   auto reduction_indices = reduction_indices_t->flat<T>();
1850   for (int i = 0; i < reduction_indices_t->NumElements(); ++i) {
1851     const T reduction_index = reduction_indices(i);
1852     if (reduction_index < -input_rank || reduction_index >= input_rank) {
1853       return errors::InvalidArgument("Invalid reduction dimension ",
1854                                      reduction_index, " for input with ",
1855                                      input_rank, " dimensions.");
1856     }
1857 
1858     auto wrapped_index = reduction_index;
1859     if (wrapped_index < 0) {
1860       wrapped_index += input_rank;
1861     }
1862 
1863     true_indices->insert(wrapped_index);
1864   }
1865   return Status::OK();
1866 }
1867 
ReductionShape(InferenceContext * c)1868 Status ReductionShape(InferenceContext* c) {
1869   ShapeHandle input = c->input(0);
1870 
1871   ShapeHandle indices;
1872   // Older versions of TensorFlow accidentally allowed higher rank tensors like
1873   // [[1,2]] or [[1],[2]] to represent axis=[1,2].
1874   if (c->graph_def_version() < 21) {
1875     indices = c->input(1);
1876   } else {
1877     TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices));
1878   }
1879 
1880   bool keep_dims;
1881   TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));
1882 
1883   const Tensor* reduction_indices_t = c->input_tensor(1);
1884   if (reduction_indices_t == nullptr || !c->RankKnown(input)) {
1885     // If we do not have the reduction values at runtime, or the
1886     // rank of the input, we don't know the output shape.
1887 
1888     if (keep_dims && c->RankKnown(input)) {
1889       // output rank matches input input if <keep_dims>.
1890       c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
1891       return Status::OK();
1892     } else {
1893       return shape_inference::UnknownShape(c);
1894     }
1895   }
1896 
1897   const int32_t input_rank = c->Rank(input);
1898   std::set<int64> true_indices;
1899   if (reduction_indices_t->dtype() == DataType::DT_INT32) {
1900     TF_RETURN_IF_ERROR(ReductionShapeHelper<int32>(reduction_indices_t,
1901                                                    input_rank, &true_indices));
1902   } else if (reduction_indices_t->dtype() == DataType::DT_INT64) {
1903     TF_RETURN_IF_ERROR(ReductionShapeHelper<int64>(reduction_indices_t,
1904                                                    input_rank, &true_indices));
1905   } else {
1906     return errors::InvalidArgument(
1907         "reduction_indices can only be int32 or int64");
1908   }
1909 
1910   std::vector<DimensionHandle> dims;
1911   for (int i = 0; i < input_rank; ++i) {
1912     if (true_indices.count(i) > 0) {
1913       if (keep_dims) {
1914         dims.emplace_back(c->MakeDim(1));
1915       }
1916     } else {
1917       dims.emplace_back(c->Dim(input, i));
1918     }
1919   }
1920 
1921   c->set_output(0, c->MakeShape(dims));
1922   return Status::OK();
1923 }
1924 
ConcatShapeHelper(InferenceContext * c,int start_value_index,int end_value_index,int dim_index)1925 Status ConcatShapeHelper(InferenceContext* c, int start_value_index,
1926                          int end_value_index, int dim_index) {
1927   ShapeHandle unused;
1928   TF_RETURN_IF_ERROR(c->WithRank(c->input(dim_index), 0, &unused));
1929   const Tensor* concat_dim_t = c->input_tensor(dim_index);
1930   if (concat_dim_t == nullptr) {
1931     // Return an unknown shape with same rank as inputs, or an unknown rank
1932     // if no input's rank is known.
1933 
1934     // Find rank.
1935     int32_t rank = InferenceContext::kUnknownRank;
1936     for (int i = start_value_index; i < end_value_index; ++i) {
1937       if (rank == InferenceContext::kUnknownRank) rank = c->Rank(c->input(i));
1938       if (rank != InferenceContext::kUnknownRank) {
1939         break;
1940       }
1941     }
1942     if (rank == InferenceContext::kUnknownRank) {
1943       c->set_output(0, c->UnknownShape());
1944       return Status::OK();
1945     } else if (rank == 0) {
1946       return errors::InvalidArgument(
1947           "Can't concatenate scalars (use tf.stack instead)");
1948     } else {
1949       for (int i = start_value_index; i < end_value_index; ++i) {
1950         // Check that all the inputs are of the correct rank.
1951         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), rank, &unused));
1952       }
1953     }
1954     // Build result of <rank> different unknown dims.
1955     std::vector<DimensionHandle> dims;
1956     dims.reserve(rank);
1957     for (int i = 0; i < rank; ++i) dims.push_back(c->UnknownDim());
1958     c->set_output(0, c->MakeShape(dims));
1959     return Status::OK();
1960   }
1961 
1962   // Merge all the non-concat dims, and sum the concat dim to make an output
1963   // shape.
1964   int64_t concat_dim;
1965   if (concat_dim_t->dtype() == DT_INT32) {
1966     concat_dim = static_cast<int64>(concat_dim_t->flat<int32>()(0));
1967   } else {
1968     concat_dim = concat_dim_t->flat<int64>()(0);
1969   }
1970 
1971   // Minimum required number of dimensions.
1972   const int min_rank = concat_dim < 0 ? -concat_dim : concat_dim + 1;
1973 
1974   ShapeHandle output_before;
1975   ShapeHandle output_after;
1976 
1977   ShapeHandle input = c->input(end_value_index - 1);
1978   TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
1979   TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &output_before));
1980   DimensionHandle output_middle = c->Dim(input, concat_dim);
1981   if (concat_dim == -1) {
1982     output_after = c->Scalar();  // no dimensions.
1983   } else {
1984     TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &output_after));
1985   }
1986 
1987   for (int i = end_value_index - 2; i >= start_value_index; --i) {
1988     ShapeHandle before;
1989     ShapeHandle after;
1990     input = c->input(i);
1991     TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
1992     TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &before));
1993     DimensionHandle middle = c->Dim(input, concat_dim);
1994     if (concat_dim == -1) {
1995       after = c->Scalar();
1996     } else {
1997       TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &after));
1998     }
1999 
2000     TF_RETURN_IF_ERROR(c->Merge(before, output_before, &output_before));
2001     TF_RETURN_IF_ERROR(c->Add(output_middle, middle, &output_middle));
2002     TF_RETURN_IF_ERROR(c->Merge(after, output_after, &output_after));
2003   }
2004 
2005   ShapeHandle s;
2006   TF_RETURN_IF_ERROR(
2007       c->Concatenate(output_before, c->Vector(output_middle), &s));
2008   TF_RETURN_IF_ERROR(c->Concatenate(s, output_after, &s));
2009   c->set_output(0, s);
2010   return Status::OK();
2011 }
2012 
ConcatShape(InferenceContext * c,int num_inputs_to_concat)2013 Status ConcatShape(InferenceContext* c, int num_inputs_to_concat) {
2014   return ConcatShapeHelper(c, 1 /* start_value_index */,
2015                            1 + num_inputs_to_concat /* end_value_index */,
2016                            0 /* dim_index */);
2017 }
2018 
ConcatV2Shape(InferenceContext * c)2019 Status ConcatV2Shape(InferenceContext* c) {
2020   return ConcatShapeHelper(c, 0 /* start_value_index */,
2021                            c->num_inputs() - 1 /* end_value_index */,
2022                            c->num_inputs() - 1 /* dim_index */);
2023 }
2024 
QuantizedConcatV2Shape(InferenceContext * c,int num_inputs_to_concat)2025 Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat) {
2026   return ConcatShapeHelper(c, 0 /* start_value_index */,
2027                            num_inputs_to_concat /* end_value_index */,
2028                            num_inputs_to_concat /* dim_index */);
2029 }
2030 
BroadcastBinaryOpOutputShapeFnHelper(InferenceContext * c,ShapeHandle shape_x,ShapeHandle shape_y,bool incompatible_shape_error,ShapeHandle * out)2031 Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
2032                                             ShapeHandle shape_x,
2033                                             ShapeHandle shape_y,
2034                                             bool incompatible_shape_error,
2035                                             ShapeHandle* out) {
2036   CHECK_NOTNULL(out);
2037   if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
2038     *out = c->UnknownShape();
2039     return Status::OK();
2040   }
2041   const int32_t rank_x = c->Rank(shape_x);
2042   const int32_t rank_y = c->Rank(shape_y);
2043   const int32_t rank_out = std::max(rank_x, rank_y);
2044 
2045   // To compute the broadcast dimensions, we zip together shape_x and shape_y
2046   // and
2047   // pad with 1 to make them the same length.
2048   std::vector<DimensionHandle> dims;
2049   DimensionHandle dim_one;
2050   if (rank_x != rank_y) dim_one = c->MakeDim(1);
2051   for (int i = 0; i < rank_out; ++i) {
2052     const auto dim_x = i < (rank_out - rank_x)
2053                            ? dim_one
2054                            : c->Dim(shape_x, i - (rank_out - rank_x));
2055     const bool dim_y_is_one = (i < (rank_out - rank_y));
2056     const auto dim_y =
2057         dim_y_is_one ? dim_one : c->Dim(shape_y, i - (rank_out - rank_y));
2058     if (!c->ValueKnown(dim_x) || !c->ValueKnown(dim_y)) {
2059       // One or both dimensions is unknown.
2060       //
2061       // - If either dimension is greater than 1, we assume that the program is
2062       // correct, and the other dimension will be broadcast to match it.
2063       // TODO(cwhipkey): For shape inference, if we eliminate the shape checks
2064       // in C++ op code, we must still assert that the unknown dim is either 1
2065       // or the same as the known dim.
2066       // - If either dimension is 1, the other dimension is the output.
2067       // - If both are unknown then dimension is unknown
2068       if (c->Value(dim_x) > 1) {
2069         if (!incompatible_shape_error) {
2070           *out = c->UnknownShape();
2071           return Status::OK();
2072         }
2073         dims.push_back(dim_x);
2074       } else if (c->Value(dim_y) > 1) {
2075         if (!incompatible_shape_error) {
2076           *out = c->UnknownShape();
2077           return Status::OK();
2078         }
2079         dims.push_back(dim_y);
2080       } else if (c->Value(dim_x) == 1) {
2081         dims.push_back(dim_y);
2082       } else if (c->Value(dim_y) == 1) {
2083         dims.push_back(dim_x);
2084       } else if (dim_y.SameHandle(dim_x)) {
2085         dims.push_back(dim_x);
2086       } else if (!c->ValueKnown(dim_x) && !c->ValueKnown(dim_y)) {
2087         dims.push_back(c->UnknownDim());
2088       } else {
2089         if (!incompatible_shape_error) {
2090           *out = c->UnknownShape();
2091           return Status::OK();
2092         }
2093         dims.push_back(c->UnknownDim());
2094       }
2095     } else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) {
2096       if (c->Value(dim_x) == 1 && !dim_y_is_one) {
2097         // We will broadcast dim_x to dim_y.
2098         dims.push_back(dim_y);
2099       } else {
2100         DCHECK_EQ(c->Value(dim_y), 1);
2101         // We will broadcast dim_y to dim_x.
2102         dims.push_back(dim_x);
2103       }
2104     } else {
2105       DimensionHandle dim;
2106       Status s = c->Merge(dim_x, dim_y, &dim);
2107       if (!s.ok()) {
2108         if (!incompatible_shape_error) {
2109           *out = c->MakeShape({});
2110           return Status::OK();
2111         }
2112         return s;
2113       }
2114       dims.push_back(dim);
2115     }
2116   }
2117 
2118   *out = c->MakeShape(dims);
2119   return Status::OK();
2120 }
2121 
RandomShape(shape_inference::InferenceContext * c)2122 Status RandomShape(shape_inference::InferenceContext* c) {
2123   shape_inference::ShapeHandle out;
2124   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
2125   c->set_output(0, out);
2126   return Status::OK();
2127 }
2128 
UnsortedSegmentReductionShapeFn(InferenceContext * c)2129 Status UnsortedSegmentReductionShapeFn(InferenceContext* c) {
2130   ShapeHandle s_data = c->input(0);
2131   ShapeHandle s_segment_ids = c->input(1);
2132   ShapeHandle s_num_segments = c->input(2);
2133   TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments));
2134 
2135   ShapeHandle out;
2136 
2137   // Leading dimensions of data must be compatible with dimensions of
2138   // <s_segment_ids>.
2139   if (c->RankKnown(s_segment_ids)) {
2140     TF_RETURN_IF_ERROR(
2141         c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids));
2142 
2143     // Get the value of the num_segments input tensor.
2144     DimensionHandle num_segments_dim;
2145     TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim));
2146 
2147     // Output is {segment_id_rank} + s_data[segment_id_rank:].
2148     ShapeHandle s_data_suffix;
2149     TF_RETURN_IF_ERROR(
2150         c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix));
2151     TF_RETURN_IF_ERROR(
2152         c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out));
2153   } else {
2154     out = c->UnknownShape();
2155   }
2156   c->set_output(0, out);
2157   return Status::OK();
2158 }
2159 
2160 namespace {
2161 
2162 // This SliceHelper processes the output shape of the `slice`
2163 // when the tensor of `sizes` is available.
2164 template <typename T>
SliceHelper(InferenceContext * c,ShapeHandle begin_value,const Tensor * sizes_value,std::vector<DimensionHandle> * dims)2165 Status SliceHelper(InferenceContext* c, ShapeHandle begin_value,
2166                    const Tensor* sizes_value,
2167                    std::vector<DimensionHandle>* dims) {
2168   auto sizes_vec = sizes_value->vec<T>();
2169   for (int i = 0; i < sizes_value->NumElements(); ++i) {
2170     DimensionHandle dim = c->Dim(c->input(0), i);
2171     if (sizes_vec(i) != -1) {
2172       auto dim_val = c->Value(dim);
2173       if (sizes_vec(i) < 0) {
2174         return errors::InvalidArgument(
2175             "Out of bounds slicing on dimension ", i, " of length ", dim_val,
2176             ": sizes vector cannot be < -1, but was ", sizes_vec(i));
2177       }
2178 
2179       dims->emplace_back(c->MakeDim(sizes_vec(i)));
2180     } else {
2181       DimensionHandle result;
2182       TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result));
2183       dims->emplace_back(result);
2184     }
2185   }
2186 
2187   return Status::OK();
2188 }
2189 }  // namespace
2190 
SliceShape(InferenceContext * c)2191 Status SliceShape(InferenceContext* c) {
2192   ShapeHandle input = c->input(0);
2193   ShapeHandle begin_shape;
2194   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
2195   ShapeHandle sizes_shape;
2196   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape));
2197 
2198   // Merge to check compatibility of begin and sizes tensors.
2199   TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape));
2200 
2201   DimensionHandle ndims = c->Dim(begin_shape, 0);
2202   if (c->ValueKnown(ndims)) {
2203     TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input));
2204   }
2205 
2206   // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known
2207   // values, even though the `begin` value does not represent a shape.
2208   ShapeHandle begin_value;
2209   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value));
2210 
2211   // We check the tensor value here and will only use
2212   // `MakeShapeFromShapeTensor` when `sizes_value` is null.
2213   // The reason is that `sizes` might contain -1, which can't
2214   // be represented (-1 in the ShapeHandle would mean "unknown").
2215   const Tensor* sizes_value = c->input_tensor(2);
2216 
2217   if (sizes_value != nullptr) {
2218     TF_RETURN_IF_ERROR(
2219         c->WithRank(begin_value, sizes_value->NumElements(), &begin_value));
2220     std::vector<DimensionHandle> dims;
2221     // If the begin and sizes tensors are available, then
2222     // we can be precise about the shape of the output.
2223     if (sizes_value->dtype() == DT_INT64) {
2224       TF_RETURN_IF_ERROR(
2225           SliceHelper<int64>(c, begin_value, sizes_value, &dims));
2226     } else {
2227       TF_RETURN_IF_ERROR(
2228           SliceHelper<int32>(c, begin_value, sizes_value, &dims));
2229     }
2230     c->set_output(0, c->MakeShape(dims));
2231     return Status::OK();
2232   } else {
2233     // In case `sizes` is not available (`sizes_value` is null),
2234     // we could try to use `MakeShapeFromShapeTensor` here.
2235     // If sizes contain -1, we will simply consider it as `Unknown`.
2236     // This is less than ideal but still an improvement of shape inference.
2237     // The following is an example that returns [None, 1, None] with this
2238     // code path:
2239     //   z = tf.zeros((1, 2, 3))
2240     //   m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1])
2241     //   m.get_shape().as_list()
2242     ShapeHandle sizes_value;
2243     TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value));
2244     if (c->RankKnown(sizes_value)) {
2245       TF_RETURN_IF_ERROR(
2246           c->WithRank(begin_value, c->Rank(sizes_value), &begin_value));
2247       std::vector<DimensionHandle> dims;
2248       dims.reserve(c->Rank(sizes_value));
2249       for (int i = 0; i < c->Rank(sizes_value); ++i) {
2250         dims.emplace_back(c->Dim(sizes_value, i));
2251       }
2252       c->set_output(0, c->MakeShape(dims));
2253       return Status::OK();
2254     }
2255     // We might know the rank of the input.
2256     if (c->RankKnown(input)) {
2257       c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
2258       return Status::OK();
2259     } else {
2260       return shape_inference::UnknownShape(c);
2261     }
2262   }
2263 
2264   return Status::OK();
2265 }
2266 
ValidateSparseTensor(InferenceContext * c,ShapeHandle indices_shape,ShapeHandle values_shape,ShapeHandle shape_shape)2267 Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
2268                             ShapeHandle values_shape, ShapeHandle shape_shape) {
2269   // Validate ranks.
2270   ShapeHandle unused_shape;
2271   TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &unused_shape));
2272   TF_RETURN_IF_ERROR(c->WithRank(values_shape, 1, &unused_shape));
2273   TF_RETURN_IF_ERROR(c->WithRank(shape_shape, 1, &unused_shape));
2274 
2275   // Number of elements in indices and values must match.
2276   DimensionHandle num_index_elements_dim = c->Dim(indices_shape, 0);
2277   if (c->ValueKnown(num_index_elements_dim)) {
2278     DimensionHandle num_values_elements_dim = c->Dim(values_shape, 0);
2279     if (c->ValueKnown(num_values_elements_dim)) {
2280       int64_t num_index_elements = c->Value(num_index_elements_dim);
2281       int64_t num_values_elements = c->Value(num_values_elements_dim);
2282       if (num_index_elements != num_values_elements) {
2283         return errors::InvalidArgument("Number of elements in index (",
2284                                        num_index_elements, ") and values (",
2285                                        num_values_elements, ") do not match.");
2286       }
2287     }
2288   }
2289 
2290   // Rank embedded in indices must match shape.
2291   DimensionHandle index_rank_dim = c->Dim(indices_shape, 1);
2292   if (c->ValueKnown(index_rank_dim)) {
2293     DimensionHandle shape_rank_dim = c->Dim(shape_shape, 0);
2294     if (c->ValueKnown(shape_rank_dim)) {
2295       int64_t index_rank = c->Value(index_rank_dim);
2296       int32_t shape_rank = c->Value(shape_rank_dim);
2297       if (index_rank != shape_rank) {
2298         return errors::InvalidArgument("Index rank (", index_rank,
2299                                        ") and shape rank (", shape_rank,
2300                                        ") do not match.");
2301       }
2302     }
2303   }
2304 
2305   return Status::OK();
2306 }
2307 
ValidateVariableResourceHandle(InferenceContext * c,std::vector<ShapeAndType> * shape_and_type)2308 Status ValidateVariableResourceHandle(
2309     InferenceContext* c, std::vector<ShapeAndType>* shape_and_type) {
2310   auto* handle_data = c->input_handle_shapes_and_types(0);
2311   if (handle_data == nullptr || handle_data->empty()) {
2312     shape_and_type->emplace_back(c->UnknownShape(), DT_INVALID);
2313   } else {
2314     *shape_and_type = *handle_data;
2315     DataType value_dtype;
2316     TF_RETURN_IF_ERROR(c->GetAttr("dtype", &value_dtype));
2317     if (shape_and_type->at(0).dtype != value_dtype) {
2318       return errors::InvalidArgument(
2319           "Trying to read variable with wrong dtype. "
2320           "Expected ",
2321           DataTypeString(shape_and_type->at(0).dtype), " got ",
2322           DataTypeString(value_dtype));
2323     }
2324   }
2325   return Status::OK();
2326 }
2327 
GatherNdShape(InferenceContext * c)2328 Status GatherNdShape(InferenceContext* c) {
2329   ShapeHandle params;
2330   std::vector<ShapeAndType> handle_shape_and_type;
2331   if (c->input_handle_shapes_and_types(0) != nullptr) {
2332     TF_RETURN_IF_ERROR(
2333         ValidateVariableResourceHandle(c, &handle_shape_and_type));
2334     params = handle_shape_and_type[0].shape;
2335   } else {
2336     params = c->input(0);
2337   }
2338   ShapeHandle indices;
2339   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices));
2340   DimensionHandle r_dim = c->Dim(indices, -1);
2341 
2342   if (!c->RankKnown(params) || !c->ValueKnown(r_dim)) {
2343     c->set_output(0, c->UnknownShape());
2344     return Status::OK();
2345   }
2346 
2347   if (c->Value(r_dim) > c->Rank(params)) {
2348     return errors::InvalidArgument(
2349         "indices.shape[-1] must be <= params.rank, but saw indices shape: ",
2350         c->DebugString(indices), " and params shape: ", c->DebugString(params));
2351   }
2352 
2353   // Remove r_dim from indices to get output.
2354   ShapeHandle indices_slice;
2355   ShapeHandle params_slice;
2356   TF_RETURN_IF_ERROR(c->Subshape(indices, 0, -1, &indices_slice));
2357   TF_RETURN_IF_ERROR(c->Subshape(params, c->Value(r_dim), &params_slice));
2358   ShapeHandle out;
2359   TF_RETURN_IF_ERROR(c->Concatenate(indices_slice, params_slice, &out));
2360   c->set_output(0, out);
2361   return Status::OK();
2362 }
2363 
ScatterNdShapeHelper(InferenceContext * c,ShapeHandle indices_shape,ShapeHandle updates_shape,ShapeHandle input_shape)2364 Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape,
2365                             ShapeHandle updates_shape,
2366                             ShapeHandle input_shape) {
2367   if (c->Value(c->NumElements(input_shape)) == 0 &&
2368       (c->Value(c->NumElements(indices_shape)) > 0 ||
2369        c->Value(c->NumElements(updates_shape)) > 0)) {
2370     return errors::InvalidArgument(
2371         "Indices and updates specified for empty input");
2372   }
2373 
2374   if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) {
2375     const int64_t outer_dims = c->Rank(indices_shape) - 1;
2376     const DimensionHandle ixdim = c->Dim(indices_shape, -1);
2377 
2378     // We can only do more validation if the last dimension of indices
2379     // is a known value.
2380     if (c->ValueKnown(ixdim)) {
2381       int64_t ix = c->Value(ixdim);
2382       ShapeHandle unused;
2383       ShapeHandle prefix_indices;
2384       TF_RETURN_IF_ERROR(
2385           c->Subshape(indices_shape, 0, outer_dims, &prefix_indices));
2386       ShapeHandle prefix_updates;
2387       TF_RETURN_IF_ERROR(
2388           c->Subshape(updates_shape, 0, outer_dims, &prefix_updates));
2389 
2390       Status s = c->Merge(prefix_indices, prefix_updates, &unused);
2391       if (!s.ok()) {
2392         return errors::InvalidArgument(
2393             "Dimensions [0,", outer_dims,
2394             ") of indices[shape=", c->DebugString(indices_shape),
2395             "] = ", c->DebugString(prefix_indices),
2396             " must match dimensions [0,", outer_dims,
2397             ") of updates[shape=", c->DebugString(updates_shape),
2398             "] = ", c->DebugString(prefix_updates), ": ", s.error_message());
2399       }
2400 
2401       ShapeHandle suffix_output;
2402       TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &suffix_output));
2403       ShapeHandle suffix_updates;
2404       TF_RETURN_IF_ERROR(
2405           c->Subshape(updates_shape, outer_dims, &suffix_updates));
2406       s = c->Merge(suffix_output, suffix_updates, &unused);
2407       if (!s.ok()) {
2408         return errors::InvalidArgument(
2409             "Dimensions [", ix, ",", c->Rank(input_shape),
2410             ") of input[shape=", c->DebugString(input_shape),
2411             "] = ", c->DebugString(suffix_output), " must match dimensions [",
2412             outer_dims, ",", c->Rank(updates_shape),
2413             ") of updates[shape=", c->DebugString(updates_shape),
2414             "] = ", c->DebugString(suffix_updates), ": ", s.error_message());
2415       }
2416     }
2417   }
2418 
2419   if (c->input_handle_shapes_and_types(0) == nullptr && c->num_outputs() > 0) {
2420     // This is called for tf.scatter_nd; output is a tensor with this shape.
2421     c->set_output(0, input_shape);
2422   }
2423   return Status::OK();
2424 }
2425 
ExplicitShape(InferenceContext * c)2426 Status ExplicitShape(InferenceContext* c) {
2427   PartialTensorShape shape;
2428   TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
2429   ShapeHandle output_shape;
2430   TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &output_shape));
2431   c->set_output(0, output_shape);
2432   return Status::OK();
2433 }
2434 
ExplicitShapes(InferenceContext * c)2435 Status ExplicitShapes(InferenceContext* c) {
2436   std::vector<PartialTensorShape> shapes;
2437   TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
2438   if (shapes.empty()) {
2439     return errors::Internal("shapes attribute is empty");
2440   }
2441   for (int i = 0, end = shapes.size(); i < end; ++i) {
2442     ShapeHandle output_shape;
2443     TF_RETURN_IF_ERROR(
2444         c->MakeShapeFromPartialTensorShape(shapes[i], &output_shape));
2445     c->set_output(i, output_shape);
2446   }
2447   return Status::OK();
2448 }
2449 
SparseReduceShapeFn(InferenceContext * c)2450 Status SparseReduceShapeFn(InferenceContext* c) {
2451   // Input 0: input_indices
2452   // Input 1: input_values
2453   // Input 2: input_shape
2454   // Input 3: reduction_axes
2455   // Attr: keep_dims
2456   bool keep_dims = false;
2457   TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));
2458 
2459   const Tensor* shape_tensor = c->input_tensor(2);
2460   const Tensor* axes_tensor = c->input_tensor(3);
2461   if (shape_tensor != nullptr && axes_tensor != nullptr) {
2462     auto shape_vec = shape_tensor->flat<int64>();
2463     auto axes_vec = axes_tensor->flat<int32>();
2464 
2465     int64_t ndims = shape_vec.size();
2466     absl::flat_hash_set<int64> axes;
2467     if (ndims == 0)
2468       return errors::InvalidArgument(
2469           "Number of dims in shape tensor must not be 0");
2470     for (int i = 0; i < axes_vec.size(); i++) {
2471       axes.insert((axes_vec(i) + ndims) % ndims);
2472     }
2473 
2474     std::vector<DimensionHandle> dims;
2475     if (keep_dims) {
2476       dims.reserve(ndims);
2477       for (int d = 0; d < ndims; ++d) {
2478         if (axes.find(d) == axes.end()) {
2479           dims.push_back(c->MakeDim(shape_vec(d)));
2480         } else {
2481           dims.push_back(c->MakeDim(1));
2482         }
2483       }
2484     } else {
2485       for (int d = 0; d < ndims; ++d) {
2486         if (axes.find(d) == axes.end()) {
2487           dims.push_back(c->MakeDim(shape_vec(d)));
2488         }
2489       }
2490     }
2491 
2492     c->set_output(0, c->MakeShape(dims));
2493     return Status::OK();
2494   }
2495   return UnknownShape(c);
2496 }
2497 
QuantizedConv2DShape(InferenceContext * c)2498 Status QuantizedConv2DShape(InferenceContext* c) {
2499   TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2500   ShapeHandle unused;
2501   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2502   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2503   TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2504   TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2505   c->set_output(1, c->Scalar());
2506   c->set_output(2, c->Scalar());
2507   return Status::OK();
2508 }
2509 
QuantizedAvgPoolShape(InferenceContext * c)2510 Status QuantizedAvgPoolShape(InferenceContext* c) {
2511   TF_RETURN_IF_ERROR(shape_inference::AvgPoolShape(c));
2512   ShapeHandle unused;
2513   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
2514   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2515   c->set_output(1, c->Scalar());
2516   c->set_output(2, c->Scalar());
2517   return Status::OK();
2518 }
2519 
QuantizeV2Shape(InferenceContext * c)2520 Status QuantizeV2Shape(InferenceContext* c) {
2521   int axis = -1;
2522   Status s = c->GetAttr("axis", &axis);
2523   if (!s.ok() && s.code() != error::NOT_FOUND) {
2524     return s;
2525   }
2526   const int minmax_rank = (axis == -1) ? 0 : 1;
2527   TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
2528   ShapeHandle minmax;
2529   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax));
2530   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), minmax_rank, &minmax));
2531   if (axis != -1) {
2532     ShapeHandle input;
2533     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
2534     DimensionHandle depth;
2535     TF_RETURN_IF_ERROR(
2536         c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
2537   }
2538   c->set_output(1, minmax);
2539   c->set_output(2, minmax);
2540   return Status::OK();
2541 }
2542 
2543 }  // namespace shape_inference
2544 
2545 }  // namespace tensorflow
2546