• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/common_shape_fns.h"
16 
17 #include "absl/container/flat_hash_map.h"
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/strings/match.h"
20 #include "absl/strings/str_split.h"
21 #include "absl/strings/string_view.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/shape_inference.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/gtl/inlined_vector.h"
26 #include "tensorflow/core/util/einsum_op_util.h"
27 
28 namespace tensorflow {
29 
30 namespace shape_inference {
31 
32 // The V2 version computes windowed output size with arbitrary dilation_rate and
33 // explicit padding, while the original version only handles the cases where
34 // dilation_rates equal to 1 and the padding is SAME or VALID.
GetWindowedOutputSizeFromDimsV2(shape_inference::InferenceContext * c,shape_inference::DimensionHandle input_size,shape_inference::DimensionOrConstant filter_size,int64_t dilation_rate,int64_t stride,Padding padding_type,int64_t padding_before,int64_t padding_after,shape_inference::DimensionHandle * output_size)35 Status GetWindowedOutputSizeFromDimsV2(
36     shape_inference::InferenceContext* c,
37     shape_inference::DimensionHandle input_size,
38     shape_inference::DimensionOrConstant filter_size, int64_t dilation_rate,
39     int64_t stride, Padding padding_type, int64_t padding_before,
40     int64_t padding_after, shape_inference::DimensionHandle* output_size) {
41   if (stride <= 0) {
42     return errors::InvalidArgument("Stride must be > 0, but got ", stride);
43   }
44 
45   if (dilation_rate < 1) {
46     return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
47                                    dilation_rate);
48   }
49 
50   // See also the parallel implementation in GetWindowedOutputSizeVerbose.
51   switch (padding_type) {
52     case Padding::VALID:
53       padding_before = padding_after = 0;
54       TF_FALLTHROUGH_INTENDED;
55     case Padding::EXPLICIT:
56       TF_RETURN_IF_ERROR(
57           c->Add(input_size, padding_before + padding_after, &input_size));
58       if (dilation_rate > 1) {
59         DimensionHandle window_size;
60         TF_RETURN_IF_ERROR(
61             c->Subtract(c->MakeDim(filter_size), 1, &window_size));
62         TF_RETURN_IF_ERROR(
63             c->Multiply(window_size, dilation_rate, &window_size));
64         TF_RETURN_IF_ERROR(c->Add(window_size, 1, &window_size));
65         TF_RETURN_IF_ERROR(c->Subtract(input_size, window_size, output_size));
66       } else {
67         TF_RETURN_IF_ERROR(c->Subtract(input_size, filter_size, output_size));
68       }
69       TF_RETURN_IF_ERROR(c->Add(*output_size, stride, output_size));
70       TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
71                                    /*evenly_divisible=*/false, output_size));
72       break;
73     case Padding::SAME:
74       TF_RETURN_IF_ERROR(c->Add(input_size, stride - 1, output_size));
75       TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
76                                    /*evenly_divisible=*/false, output_size));
77       break;
78   }
79   return OkStatus();
80 }
81 
GetWindowedOutputSizeFromDims(shape_inference::InferenceContext * c,shape_inference::DimensionHandle input_size,shape_inference::DimensionOrConstant filter_size,int64_t stride,Padding padding_type,shape_inference::DimensionHandle * output_size)82 Status GetWindowedOutputSizeFromDims(
83     shape_inference::InferenceContext* c,
84     shape_inference::DimensionHandle input_size,
85     shape_inference::DimensionOrConstant filter_size, int64_t stride,
86     Padding padding_type, shape_inference::DimensionHandle* output_size) {
87   if (padding_type == Padding::EXPLICIT) {
88     return errors::Internal(
89         "GetWindowedOutputSizeFromDims does not handle EXPLICIT padding; call "
90         "GetWindowedOutputSizeFromDimsV2 instead");
91   }
92   return GetWindowedOutputSizeFromDimsV2(c, input_size, filter_size,
93                                          /*dilation_rate=*/1, stride,
94                                          padding_type,
95                                          // Give dummy values of -1 to
96                                          // padding_before and padding_after,
97                                          // since explicit padding is not used.
98                                          -1, -1, output_size);
99 }
100 
UnchangedShape(shape_inference::InferenceContext * c)101 Status UnchangedShape(shape_inference::InferenceContext* c) {
102   c->set_output(0, c->input(0));
103   auto* handle_data = c->input_handle_shapes_and_types(0);
104   if (handle_data != nullptr) {
105     c->set_output_handle_shapes_and_types(0, *handle_data);
106   }
107   return OkStatus();
108 }
109 
MatMulShape(shape_inference::InferenceContext * c)110 Status MatMulShape(shape_inference::InferenceContext* c) {
111   ShapeHandle a;
112   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a));
113 
114   ShapeHandle b;
115   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b));
116 
117   bool transpose_a, transpose_b;
118   TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a));
119   TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b));
120   DimensionHandle output_rows = transpose_a ? c->Dim(a, 1) : c->Dim(a, 0);
121   DimensionHandle output_cols = transpose_b ? c->Dim(b, 0) : c->Dim(b, 1);
122 
123   // Validate that the inner shapes are compatible.
124   DimensionHandle inner_a = transpose_a ? c->Dim(a, 0) : c->Dim(a, 1);
125   DimensionHandle inner_b = transpose_b ? c->Dim(b, 1) : c->Dim(b, 0);
126   DimensionHandle merged;
127   TF_RETURN_IF_ERROR(c->Merge(inner_a, inner_b, &merged));
128 
129   c->set_output(0, c->Matrix(output_rows, output_cols));
130   return OkStatus();
131 }
132 
133 namespace {
134 
135 // Validate that an Einsum subscript contains exactly one or zero ellipsis; and
136 // that periods (.) occur only within an ellipses (...).
ValidateEinsumEllipsis(absl::string_view subscript,bool * found_ellipsis)137 Status ValidateEinsumEllipsis(absl::string_view subscript,
138                               bool* found_ellipsis) {
139   const int num_periods = absl::c_count(subscript, '.');
140   if (num_periods != 0 && num_periods != 3) {
141     return errors::InvalidArgument(
142         "Expected at most one ellipsis (...), but found ", num_periods,
143         " periods (.) in the input subscript: ", subscript);
144   }
145   if (num_periods == 3 && !absl::StrContains(subscript, "...")) {
146     return errors::InvalidArgument(
147         "Periods found outside of ellipsis in subscript: ", subscript);
148   }
149   *found_ellipsis = num_periods > 0;
150   return OkStatus();
151 }
152 
153 }  // namespace
154 
// Shape function for Einsum: parses the `equation` attribute, walks the
// subscript labels of each input to build a label -> dimension mapping
// (merging dimensions that share a label) and per-input broadcast shapes for
// any ellipsis, then assembles the output shape from the output subscripts.
Status EinsumShape(shape_inference::InferenceContext* c) {
  // We assume that the equation has a valid format. Either (x),(y)->(z)
  // or (x)->(z), where each of (x), (y) and (z) are concatenation of zero or
  // more latin alphabets and contains at most one ellipsis ('...').
  string equation;
  TF_RETURN_IF_ERROR(c->GetAttr("equation", &equation));
  gtl::InlinedVector<string, 2> input_labels;
  string output_labels;
  TF_RETURN_IF_ERROR(
      ValidateEinsumEquation(equation, &input_labels, &output_labels));

  // Einsum supports unary and binary forms only.
  if (c->num_inputs() == 0 || c->num_inputs() > 2) {
    return errors::InvalidArgument("Expected either 1 or 2 inputs but got: ",
                                   c->num_inputs());
  }
  const int input_labels_size = input_labels.size();
  if (c->num_inputs() != input_labels_size) {
    return errors::InvalidArgument("Expected ", input_labels.size(),
                                   " inputs for equation ", equation,
                                   " but got: ", c->num_inputs());
  }

  // Validate input subscripts, build the label to dimension mapping and obtain
  // the broadcast shapes that map to ellipsis.
  absl::flat_hash_map<char, DimensionHandle> label_to_dimension;
  gtl::InlinedVector<ShapeHandle, 2> input_bcast_shapes(c->num_inputs());
  for (int i = 0, end = c->num_inputs(); i < end; ++i) {
    bool has_ellipsis = false;
    TF_RETURN_IF_ERROR(ValidateEinsumEllipsis(input_labels[i], &has_ellipsis));
    ShapeHandle input_shape = c->input(i);
    // Validate that the input rank is sufficient for the given number of named
    // labels.
    if (c->RankKnown(input_shape)) {
      if (has_ellipsis) {
        // The ellipsis occupies 3 characters but matches >= 0 dimensions, so
        // the rank only needs to cover the named (non-ellipsis) labels.
        const int num_named_labels =
            static_cast<int>(input_labels[i].size()) - 3;
        TF_RETURN_WITH_CONTEXT_IF_ERROR(
            c->WithRankAtLeast(input_shape, num_named_labels, &input_shape),
            " for ", i, "th input and equation: ", equation);
      } else {
        const int num_named_labels = static_cast<int>(input_labels[i].size());
        TF_RETURN_WITH_CONTEXT_IF_ERROR(
            c->WithRank(input_shape, num_named_labels, &input_shape), " for ",
            i, "th input and equation: ", equation);
      }
    }

    bool seen_ellipsis = false;
    input_bcast_shapes[i] = c->Scalar();
    // Run through the input labels; populate label_to_dimension mapping and
    // compute the broadcast shapes corresponding to the ellipsis (if present).
    for (int label_idx = 0, end = input_labels[i].size(); label_idx < end;
         ++label_idx) {
      const char label = input_labels[i][label_idx];
      // Calculate the input axis that the current label is referring to. After
      // the ellipsis, the axis may be found by using negative indices; i.e the
      // (rank - k)th dimension corresponds to the (num_labels - k)th label.
      const int64_t axis_before_ellipsis = label_idx;
      const int64_t axis_after_ellipsis =
          c->RankKnown(input_shape)
              ? label_idx + c->Rank(input_shape) - input_labels[i].size()
              : -1;

      // Populate the input broadcast shape when we encounter an ellipsis (...).
      if (label == '.') {
        if (!c->RankKnown(input_shape)) {
          // Unknown rank means the span matched by the ellipsis is unknown.
          input_bcast_shapes[i] = c->UnknownShape();
        } else {
          // The broadcast shape runs till the named label right after the
          // ellipsis, the label with index (label_idx + 3).
          TF_RETURN_IF_ERROR(c->Subshape(input_shape, axis_before_ellipsis,
                                         axis_after_ellipsis + 3,
                                         &input_bcast_shapes[i]));
        }
        label_idx += 2;  // Skip the rest of the ellipsis.
        seen_ellipsis = true;
        continue;
      }
      // Obtain the dimension that the current label corresponds to.
      int64_t axis = seen_ellipsis ? axis_after_ellipsis : axis_before_ellipsis;
      DimensionHandle new_dim = c->RankKnown(input_shape)
                                    ? c->Dim(input_shape, axis)
                                    : c->UnknownDim();
      // If we've seen this label before, make sure previous and current
      // dimensions are compatible.
      if (label_to_dimension.contains(label)) {
        DimensionHandle merged;
        TF_RETURN_IF_ERROR(
            c->Merge(label_to_dimension[label], new_dim, &merged));
        label_to_dimension[label] = merged;
      } else {
        label_to_dimension[label] = new_dim;
      }
    }
  }

  // For two inputs, broadcast the two input broadcast shapes to create the
  // output broadcast shape. For one input, just copy the single broadcast
  // shape.
  ShapeHandle output_bcast_shape;
  if (input_bcast_shapes.size() == 1) {
    output_bcast_shape = input_bcast_shapes[0];
  } else if (input_bcast_shapes.size() == 2) {
    TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
        c, input_bcast_shapes[0], input_bcast_shapes[1], true,
        &output_bcast_shape));
  }

  bool output_has_ellipsis = false;
  TF_RETURN_IF_ERROR(
      ValidateEinsumEllipsis(output_labels, &output_has_ellipsis));
  if (output_has_ellipsis) {
    // If the output subscript has ellipsis and the output broadcast rank is
    // unknown, then the output shape should have unknown rank.
    if (!c->RankKnown(output_bcast_shape)) {
      c->set_output(0, c->UnknownShape());
      return OkStatus();
    }
  } else {
    // If the output subscripts don't have ellipsis then make sure the output
    // broadcasting shape is empty.
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        c->WithRankAtMost(output_bcast_shape, 0, &output_bcast_shape),
        " for einsum equation '", equation,
        "' without ellipsis (...) in the output subscripts where input(s) have "
        "non-empty broadcasting shape");
    output_bcast_shape = c->Scalar();
  }

  // Create the output shape from output labels and label_to_dimension mapping.
  std::vector<DimensionHandle> output_dims;
  for (int label_idx = 0, end = output_labels.size(); label_idx < end;
       ++label_idx) {
    const char label = output_labels[label_idx];
    // Append the output_bcast_shape when the ellipsis is encountered.
    if (label == '.') {
      for (int k = 0; k < c->Rank(output_bcast_shape); ++k) {
        output_dims.push_back(c->Dim(output_bcast_shape, k));
      }
      label_idx += 2;  // Skip the rest of the ellipsis.
      continue;
    }
    auto dimension_it = label_to_dimension.find(label);
    if (dimension_it == label_to_dimension.end()) {
      return errors::InvalidArgument(
          "Einsum output subscripts for equation '", equation, "' has label '",
          label, "' which is not present in the input subscripts");
    }
    output_dims.push_back(dimension_it->second);
  }
  c->set_output(0, c->MakeShape(output_dims));
  return OkStatus();
}
308 
BatchMatMulV2Shape(shape_inference::InferenceContext * c)309 Status BatchMatMulV2Shape(shape_inference::InferenceContext* c) {
310   ShapeHandle a_shape;
311   ShapeHandle b_shape;
312   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &a_shape));
313   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape));
314 
315   // Determine output rows and columns.
316   bool adj_x;
317   bool adj_y;
318   TF_RETURN_IF_ERROR(c->GetAttr("adj_x", &adj_x));
319   TF_RETURN_IF_ERROR(c->GetAttr("adj_y", &adj_y));
320   DimensionHandle output_rows = c->Dim(a_shape, adj_x ? -1 : -2);
321   DimensionHandle output_cols = c->Dim(b_shape, adj_y ? -2 : -1);
322 
323   // Inner dimensions should be compatible.
324   DimensionHandle inner_merged;
325   TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1),
326                               c->Dim(b_shape, adj_y ? -1 : -2), &inner_merged));
327 
328   // Batch dimensions should broadcast with each other.
329   ShapeHandle a_batch_shape;
330   ShapeHandle b_batch_shape;
331   ShapeHandle output_batch_shape;
332   TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_shape));
333   TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_shape));
334 
335   TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
336       c, a_batch_shape, b_batch_shape, true, &output_batch_shape));
337 
338   ShapeHandle output_shape;
339   TF_RETURN_IF_ERROR(c->Concatenate(
340       output_batch_shape, c->Matrix(output_rows, output_cols), &output_shape));
341 
342   c->set_output(0, output_shape);
343   return OkStatus();
344 }
345 
BatchMatMulShape(shape_inference::InferenceContext * c)346 Status BatchMatMulShape(shape_inference::InferenceContext* c) {
347   ShapeHandle a_shape;
348   ShapeHandle b_shape;
349   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &a_shape));
350   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape));
351 
352   // Determine output rows and cols.
353   bool adj_x;
354   bool adj_y;
355   TF_RETURN_IF_ERROR(c->GetAttr("adj_x", &adj_x));
356   TF_RETURN_IF_ERROR(c->GetAttr("adj_y", &adj_y));
357   DimensionHandle output_rows = c->Dim(a_shape, adj_x ? -1 : -2);
358   DimensionHandle output_cols = c->Dim(b_shape, adj_y ? -2 : -1);
359 
360   // Batch dims match between inputs.
361   ShapeHandle a_batch_dims;
362   ShapeHandle b_batch_dims;
363   ShapeHandle batch_dims;
364   TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims));
365   TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims));
366   TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims));
367 
368   // Assert inner dims match.
369   DimensionHandle unused;
370   TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1),
371                               c->Dim(b_shape, adj_y ? -1 : -2), &unused));
372 
373   ShapeHandle out;
374   TF_RETURN_IF_ERROR(
375       c->Concatenate(batch_dims, c->Matrix(output_rows, output_cols), &out));
376   c->set_output(0, out);
377   return OkStatus();
378 }
379 
380 // --------------------------------------------------------------------------
381 
BiasAddShape(shape_inference::InferenceContext * c)382 Status BiasAddShape(shape_inference::InferenceContext* c) {
383   ShapeHandle input_shape;
384 
385   // Fetch the data_format attribute, which may not exist.
386   string data_format;
387   Status s = c->GetAttr("data_format", &data_format);
388 
389   if (s.ok() && data_format == "NCHW") {
390     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
391   } else {
392     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
393   }
394 
395   ShapeHandle bias_shape;
396   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &bias_shape));
397   DimensionHandle bias_dim = c->Dim(bias_shape, 0);
398 
399   // If rank unknown, return unknown shape.
400   if (!c->RankKnown(input_shape)) {
401     c->set_output(0, c->UnknownShape());
402     return OkStatus();
403   }
404 
405   // Output has the same shape as the input, and matches the length of
406   // the bias in its bias dimension.
407   ShapeHandle output_shape;
408   if (s.ok() && data_format == "NCHW") {
409     // Merge the length of bias_shape into the third to last dimension
410     ShapeHandle first;
411     TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, 1, &first));
412 
413     ShapeHandle last;
414     TF_RETURN_IF_ERROR(c->Subshape(input_shape, 2, &last));
415 
416     DimensionHandle input_bias_dim = c->Dim(input_shape, 1);
417     DimensionHandle merged_bias_dim;
418     TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));
419     ShapeHandle merged_bias = c->Vector(merged_bias_dim);
420 
421     ShapeHandle temp;
422     TF_RETURN_IF_ERROR(c->Concatenate(first, merged_bias, &temp));
423     TF_RETURN_IF_ERROR(c->Concatenate(temp, last, &output_shape));
424   } else {
425     ShapeHandle all_but_bias;
426     TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -1, &all_but_bias));
427 
428     DimensionHandle input_bias_dim = c->Dim(input_shape, -1);
429     DimensionHandle merged_bias_dim;
430     TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));
431 
432     ShapeHandle merged_bias = c->Vector(merged_bias_dim);
433     TF_RETURN_IF_ERROR(
434         c->Concatenate(all_but_bias, merged_bias, &output_shape));
435   }
436 
437   c->set_output(0, output_shape);
438   return OkStatus();
439 }
440 
BiasAddGradShape(shape_inference::InferenceContext * c)441 Status BiasAddGradShape(shape_inference::InferenceContext* c) {
442   ShapeHandle input_shape;
443   // Fetch the data_format attribute, which may not exist.
444   string data_format;
445   Status s = c->GetAttr("data_format", &data_format);
446 
447   if (s.ok() && data_format == "NCHW") {
448     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
449     c->set_output(0, c->Vector(c->Dim(input_shape, 1)));
450   } else {
451     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
452     c->set_output(0, c->Vector(c->Dim(input_shape, -1)));
453   }
454 
455   return OkStatus();
456 }
457 
CheckFormatConstraintsOnShape(const TensorFormat tensor_format,const ShapeHandle shape_handle,const string & tensor_name,shape_inference::InferenceContext * c)458 Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format,
459                                      const ShapeHandle shape_handle,
460                                      const string& tensor_name,
461                                      shape_inference::InferenceContext* c) {
462   if (tensor_format == FORMAT_NCHW_VECT_C) {
463     // Check that the vect dim has size 4 or 32.
464     const int num_dims = c->Rank(shape_handle);
465     DimensionHandle vect_dim = c->Dim(
466         shape_handle, GetTensorInnerFeatureDimIndex(num_dims, tensor_format));
467     int64_t vect_dim_val = c->Value(vect_dim);
468     if (vect_dim_val != 4 && vect_dim_val != 32) {
469       return errors::InvalidArgument(
470           "VECT_C dimension must be 4 or 32, but is ", vect_dim_val);
471     }
472   }
473 
474   return OkStatus();
475 }
476 
DatasetIteratorShape(shape_inference::InferenceContext * c)477 Status DatasetIteratorShape(shape_inference::InferenceContext* c) {
478   shape_inference::ShapeHandle unused;
479   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
480   std::vector<PartialTensorShape> output_shapes;
481   TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
482   const int output_shapes_size = output_shapes.size();
483   if (output_shapes_size != c->num_outputs()) {
484     return errors::InvalidArgument(
485         "`output_shapes` must be the same length as `output_types` (",
486         output_shapes.size(), " vs. ", c->num_outputs());
487   }
488   for (size_t i = 0; i < output_shapes.size(); ++i) {
489     shape_inference::ShapeHandle output_shape_handle;
490     TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
491         output_shapes[i], &output_shape_handle));
492     c->set_output(static_cast<int>(i), output_shape_handle);
493   }
494   return OkStatus();
495 }
496 
MakeShapeFromFormat(TensorFormat format,DimensionOrConstant N,const std::vector<DimensionOrConstant> & spatial,DimensionOrConstant C,ShapeHandle * out,shape_inference::InferenceContext * context)497 Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
498                            const std::vector<DimensionOrConstant>& spatial,
499                            DimensionOrConstant C, ShapeHandle* out,
500                            shape_inference::InferenceContext* context) {
501   const int num_dims = GetTensorDimsFromSpatialDims(spatial.size(), format);
502   std::vector<DimensionHandle> dims_actual(num_dims);
503   dims_actual[GetTensorBatchDimIndex(num_dims, format)] = context->MakeDim(N);
504   int outer_c_index = GetTensorFeatureDimIndex(num_dims, format);
505   dims_actual[outer_c_index] = context->MakeDim(C);
506   if (format == FORMAT_NCHW_VECT_C) {
507     dims_actual[GetTensorInnerFeatureDimIndex(num_dims, format)] =
508         context->MakeDim(4);
509   } else if (format == FORMAT_NHWC_VECT_W) {
510     dims_actual[GetTensorInnerWidthDimIndex(num_dims, format)] =
511         context->MakeDim(4);
512   }
513   for (int spatial_dim = 0, end = spatial.size(); spatial_dim < end;
514        spatial_dim++) {
515     dims_actual[GetTensorSpatialDimIndex(num_dims, format, spatial_dim)] =
516         context->MakeDim(spatial[spatial_dim]);
517   }
518   *out = context->MakeShape(dims_actual);
519   return OkStatus();
520 }
521 
DimensionsFromShape(ShapeHandle shape,TensorFormat format,DimensionHandle * batch_dim,gtl::MutableArraySlice<DimensionHandle> spatial_dims,DimensionHandle * filter_dim,InferenceContext * context)522 Status DimensionsFromShape(ShapeHandle shape, TensorFormat format,
523                            DimensionHandle* batch_dim,
524                            gtl::MutableArraySlice<DimensionHandle> spatial_dims,
525                            DimensionHandle* filter_dim,
526                            InferenceContext* context) {
527   const int32_t rank =
528       GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
529   // Batch.
530   *batch_dim = context->Dim(shape, GetTensorBatchDimIndex(rank, format));
531   // Spatial.
532   for (int spatial_dim_index = 0, end = spatial_dims.size();
533        spatial_dim_index < end; ++spatial_dim_index) {
534     spatial_dims[spatial_dim_index] = context->Dim(
535         shape, GetTensorSpatialDimIndex(rank, format, spatial_dim_index));
536   }
537   // Channel.
538   *filter_dim = context->Dim(shape, GetTensorFeatureDimIndex(rank, format));
539   if (format == FORMAT_NCHW_VECT_C) {
540     TF_RETURN_IF_ERROR(context->Multiply(
541         *filter_dim,
542         context->Dim(shape, GetTensorInnerFeatureDimIndex(rank, format)),
543         filter_dim));
544   }
545   return OkStatus();
546 }
547 
548 // vect_size must be provided if format is NCHW_VECT_C.
ShapeFromDimensions(DimensionHandle batch_dim,gtl::ArraySlice<DimensionHandle> spatial_dims,DimensionHandle filter_dim,TensorFormat format,absl::optional<DimensionHandle> vect_size,InferenceContext * context,ShapeHandle * shape)549 Status ShapeFromDimensions(DimensionHandle batch_dim,
550                            gtl::ArraySlice<DimensionHandle> spatial_dims,
551                            DimensionHandle filter_dim, TensorFormat format,
552                            absl::optional<DimensionHandle> vect_size,
553                            InferenceContext* context, ShapeHandle* shape) {
554   const int32_t rank =
555       GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
556   std::vector<DimensionHandle> out_dims(rank);
557 
558   // Batch.
559   out_dims[tensorflow::GetTensorBatchDimIndex(rank, format)] = batch_dim;
560   // Spatial.
561   for (int spatial_dim_index = 0, end = spatial_dims.size();
562        spatial_dim_index < end; ++spatial_dim_index) {
563     out_dims[tensorflow::GetTensorSpatialDimIndex(
564         rank, format, spatial_dim_index)] = spatial_dims[spatial_dim_index];
565   }
566   // Channel.
567   if (format == tensorflow::FORMAT_NCHW_VECT_C) {
568     // When format is NCHW_VECT_C, factor the feature map count into the outer
569     // feature count and the inner feature count (4 or 32).
570     CHECK(vect_size.has_value());  // Crash ok.
571     TF_RETURN_IF_ERROR(context->Divide(
572         filter_dim, *vect_size, /*evenly_divisible=*/true,
573         &out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)]));
574     out_dims[GetTensorInnerFeatureDimIndex(rank, format)] = *vect_size;
575   } else {
576     out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)] = filter_dim;
577   }
578 
579   *shape = context->MakeShape(out_dims);
580   return OkStatus();
581 }
582 
583 namespace {
584 
Conv2DShapeImpl(shape_inference::InferenceContext * c,bool supports_explicit_padding)585 Status Conv2DShapeImpl(shape_inference::InferenceContext* c,
586                        bool supports_explicit_padding) {
587   string data_format_str, filter_format_str;
588   if (!c->GetAttr("data_format", &data_format_str).ok()) {
589     data_format_str = "NHWC";
590   }
591   if (!c->GetAttr("filter_format", &filter_format_str).ok()) {
592     filter_format_str = "HWIO";
593   }
594 
595   TensorFormat data_format;
596   if (!FormatFromString(data_format_str, &data_format)) {
597     return errors::InvalidArgument("Invalid data format string: ",
598                                    data_format_str);
599   }
600   FilterTensorFormat filter_format;
601   if (!FilterFormatFromString(filter_format_str, &filter_format)) {
602     return errors::InvalidArgument("Invalid filter format string: ",
603                                    filter_format_str);
604   }
605 
606   constexpr int num_spatial_dims = 2;
607   const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
608   ShapeHandle conv_input_shape;
609   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &conv_input_shape));
610   TF_RETURN_IF_ERROR(CheckFormatConstraintsOnShape(
611       data_format, conv_input_shape, "conv_input", c));
612 
613   // The filter rank should match the input (4 for NCHW, 5 for NCHW_VECT_C).
614   ShapeHandle filter_shape;
615   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape));
616   TF_RETURN_IF_ERROR(
617       CheckFormatConstraintsOnShape(data_format, filter_shape, "filter", c));
618 
619   std::vector<int32> dilations;
620   TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations));
621 
622   if (dilations.size() != 4) {
623     return errors::InvalidArgument(
624         "Conv2D requires the dilation attribute to contain 4 values, but got: ",
625         dilations.size());
626   }
627 
628   std::vector<int32> strides;
629   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
630 
631   // strides.size() should be 4 (NCHW) even if the input is 5 (NCHW_VECT_C).
632   if (strides.size() != 4) {
633     return errors::InvalidArgument("Conv2D on data format ", data_format_str,
634                                    " requires the stride attribute to contain"
635                                    " 4 values, but got: ",
636                                    strides.size());
637   }
638 
639   const int32_t stride_rows = GetTensorDim(strides, data_format, 'H');
640   const int32_t stride_cols = GetTensorDim(strides, data_format, 'W');
641   const int32_t dilation_rows = GetTensorDim(dilations, data_format, 'H');
642   const int32_t dilation_cols = GetTensorDim(dilations, data_format, 'W');
643 
644   DimensionHandle batch_size_dim;
645   DimensionHandle input_depth_dim;
646   gtl::InlinedVector<DimensionHandle, 2> input_spatial_dims(2);
647   TF_RETURN_IF_ERROR(DimensionsFromShape(
648       conv_input_shape, data_format, &batch_size_dim,
649       absl::MakeSpan(input_spatial_dims), &input_depth_dim, c));
650 
651   DimensionHandle output_depth_dim = c->Dim(
652       filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
653   DimensionHandle filter_rows_dim = c->Dim(
654       filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'H'));
655   DimensionHandle filter_cols_dim = c->Dim(
656       filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'W'));
657   DimensionHandle filter_input_depth_dim;
658   if (filter_format == FORMAT_OIHW_VECT_I) {
659     TF_RETURN_IF_ERROR(c->Multiply(
660         c->Dim(filter_shape,
661                GetFilterDimIndex<num_spatial_dims>(filter_format, 'I')),
662         c->Dim(filter_shape,
663                GetFilterTensorInnerInputChannelsDimIndex(rank, filter_format)),
664         &filter_input_depth_dim));
665   } else {
666     filter_input_depth_dim = c->Dim(
667         filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'I'));
668   }
669 
670   // Check that the input tensor and the filter tensor agree on the channel
671   // count.
672   if (c->ValueKnown(input_depth_dim) && c->ValueKnown(filter_input_depth_dim)) {
673     int64_t input_depth_value = c->Value(input_depth_dim),
674             filter_input_depth_value = c->Value(filter_input_depth_dim);
675     if (filter_input_depth_value == 0)
676       return errors::InvalidArgument("Depth of filter must not be 0");
677     if (input_depth_value % filter_input_depth_value != 0)
678       return errors::InvalidArgument(
679           "Depth of input (", input_depth_value,
680           ") is not a multiple of input depth of filter (",
681           filter_input_depth_value, ")");
682     if (input_depth_value != filter_input_depth_value) {
683       int64_t num_groups = input_depth_value / filter_input_depth_value;
684       if (c->ValueKnown(output_depth_dim)) {
685         int64_t output_depth_value = c->Value(output_depth_dim);
686         if (num_groups == 0)
687           return errors::InvalidArgument("Number of groups must not be 0");
688         if (output_depth_value % num_groups != 0)
689           return errors::InvalidArgument(
690               "Depth of output (", output_depth_value,
691               ") is not a multiple of the number of groups (", num_groups, ")");
692       }
693     }
694   }
695 
696   Padding padding;
697   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
698   std::vector<int64_t> explicit_paddings;
699   if (supports_explicit_padding) {
700     Status s = c->GetAttr("explicit_paddings", &explicit_paddings);
701     // Use the default value, which is an empty list, if the attribute is not
702     // found. Otherwise return the error to the caller.
703     if (!s.ok() && !errors::IsNotFound(s)) {
704       return s;
705     }
706     TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
707                                          /*num_dims=*/4, data_format));
708   } else {
709     if (padding == Padding::EXPLICIT) {
710       return errors::InvalidArgument(
711           "Expected non-explicit padding but got explicit padding");
712     }
713     std::vector<int64_t> p_list;
714     // `padding_list` attribute is used by Fused int8 convolutions to support
715     // explicit paddings.
716     Status s_p_list = c->GetAttr("padding_list", &p_list);
717     if (!s_p_list.ok() && !errors::IsNotFound(s_p_list)) {
718       return s_p_list;
719     }
720     if (s_p_list.ok() && !p_list.empty()) {
721       padding = Padding::EXPLICIT;
722       explicit_paddings = p_list;
723       TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
724                                            /*num_dims=*/4, data_format));
725     }
726   }
727 
728   DimensionHandle output_rows, output_cols;
729   int64_t pad_rows_before = -1, pad_rows_after = -1;
730   int64_t pad_cols_before = -1, pad_cols_after = -1;
731   if (padding == Padding::EXPLICIT) {
732     GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
733                              &pad_rows_before, &pad_rows_after);
734     GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
735                              &pad_cols_before, &pad_cols_after);
736   }
737   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
738       c, input_spatial_dims[0], filter_rows_dim, dilation_rows, stride_rows,
739       padding, pad_rows_before, pad_rows_after, &output_rows));
740   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
741       c, input_spatial_dims[1], filter_cols_dim, dilation_cols, stride_cols,
742       padding, pad_cols_before, pad_cols_after, &output_cols));
743 
744   absl::optional<DimensionHandle> vect_size;
745   if (data_format == FORMAT_NCHW_VECT_C) {
746     vect_size.emplace(c->Dim(conv_input_shape,
747                              GetTensorInnerFeatureDimIndex(rank, data_format)));
748   }
749   ShapeHandle output_shape;
750   TF_RETURN_IF_ERROR(ShapeFromDimensions(
751       batch_size_dim, {output_rows, output_cols}, output_depth_dim, data_format,
752       vect_size, c, &output_shape));
753   c->set_output(0, output_shape);
754   return OkStatus();
755 }
756 
757 }  // namespace
758 
// Shape function for Conv2D-like operations that support explicit padding.
// Delegates to Conv2DShapeImpl with explicit-padding support enabled, so the
// op's "explicit_paddings" attribute is read and validated there.
Status Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext* c) {
  return Conv2DShapeImpl(c, true);
}
763 
// Shape function for Conv2D-like operations that do not support explicit
// padding.
// Delegates to Conv2DShapeImpl with explicit-padding support disabled; the
// implementation rejects Padding::EXPLICIT unless a non-empty "padding_list"
// attribute (used by fused int8 convolutions) supplies the paddings.
Status Conv2DShape(shape_inference::InferenceContext* c) {
  return Conv2DShapeImpl(c, false);
}
769 
770 // TODO(mjanusz): Unify all conv/pooling shape functions.
Conv3DShape(shape_inference::InferenceContext * c)771 Status Conv3DShape(shape_inference::InferenceContext* c) {
772   ShapeHandle input_shape;
773   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
774   ShapeHandle filter_shape;
775   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &filter_shape));
776 
777   string data_format;
778   Status s = c->GetAttr("data_format", &data_format);
779 
780   std::vector<int32> dilations;
781   TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations));
782 
783   if (dilations.size() != 5) {
784     return errors::InvalidArgument(
785         "Conv3D requires the dilation attribute to contain 5 values, but got: ",
786         dilations.size());
787   }
788 
789   std::vector<int32> strides;
790   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
791   if (strides.size() != 5) {
792     return errors::InvalidArgument(
793         "Conv3D requires the stride attribute to contain 5 values, but got: ",
794         strides.size());
795   }
796 
797   int32_t stride_planes, stride_rows, stride_cols;
798   int32_t dilation_planes, dilation_rows, dilation_cols;
799   if (s.ok() && data_format == "NCDHW") {
800     // Convert input_shape to NDHWC.
801     auto dim = [&](char dimension) {
802       return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
803     };
804     input_shape =
805         c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
806     stride_planes = strides[2];
807     stride_rows = strides[3];
808     stride_cols = strides[4];
809     dilation_planes = dilations[2];
810     dilation_cols = dilations[3];
811     dilation_rows = dilations[4];
812   } else {
813     stride_planes = strides[1];
814     stride_rows = strides[2];
815     stride_cols = strides[3];
816     dilation_planes = dilations[1];
817     dilation_cols = dilations[2];
818     dilation_rows = dilations[3];
819   }
820 
821   DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
822   DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
823   DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
824   DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
825   DimensionHandle input_depth_dim = c->Dim(input_shape, 4);
826 
827   DimensionHandle filter_planes_dim = c->Dim(filter_shape, 0);
828   DimensionHandle filter_rows_dim = c->Dim(filter_shape, 1);
829   DimensionHandle filter_cols_dim = c->Dim(filter_shape, 2);
830   DimensionHandle filter_input_depth_dim = c->Dim(filter_shape, 3);
831   DimensionHandle output_depth_dim = c->Dim(filter_shape, 4);
832 
833   // Check that the input tensor and the filter tensor agree on the channel
834   // count.
835   if (c->ValueKnown(input_depth_dim) && c->ValueKnown(filter_input_depth_dim)) {
836     int64_t input_depth_value = c->Value(input_depth_dim),
837             filter_input_depth_value = c->Value(filter_input_depth_dim);
838     if (filter_input_depth_value == 0)
839       return errors::InvalidArgument("Depth of filter must not be 0");
840     if (input_depth_value % filter_input_depth_value != 0)
841       return errors::InvalidArgument(
842           "Depth of input (", input_depth_value,
843           ") is not a multiple of input depth of filter (",
844           filter_input_depth_value, ")");
845     if (input_depth_value != filter_input_depth_value) {
846       int64_t num_groups = input_depth_value / filter_input_depth_value;
847       if (c->ValueKnown(output_depth_dim)) {
848         int64_t output_depth_value = c->Value(output_depth_dim);
849         if (num_groups == 0)
850           return errors::InvalidArgument("Number of groups must not be 0");
851         if (output_depth_value % num_groups != 0)
852           return errors::InvalidArgument(
853               "Depth of output (", output_depth_value,
854               ") is not a multiple of the number of groups (", num_groups, ")");
855       }
856     }
857   }
858 
859   Padding padding;
860   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
861   DimensionHandle output_planes, output_rows, output_cols;
862 
863   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
864       c, in_planes_dim, filter_planes_dim, dilation_planes, stride_planes,
865       padding, -1, -1, &output_planes));
866   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
867       c, in_rows_dim, filter_rows_dim, dilation_rows, stride_rows, padding, -1,
868       -1, &output_rows));
869   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
870       c, in_cols_dim, filter_cols_dim, dilation_cols, stride_cols, padding, -1,
871       -1, &output_cols));
872 
873   ShapeHandle output_shape;
874   if (data_format == "NCDHW") {
875     output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
876                                  output_planes, output_rows, output_cols});
877   } else {
878     output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
879                                  output_cols, output_depth_dim});
880   }
881   c->set_output(0, output_shape);
882   return OkStatus();
883 }
884 
// Shape function for Conv2DBackpropInput. Infers the shape of in_backprop
// (output 0) from input_sizes (input 0, a shape tensor), the filter shape
// (input 1) and the out_backprop shape (input 2).
Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c) {
  string data_format_str;
  if (!c->GetAttr("data_format", &data_format_str).ok()) {
    // The attribute is optional; default to NHWC.
    data_format_str = "NHWC";
  }
  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }

  // For the rest of this function, output_grad_* describes out_backprop and
  // input_grad_* describes in_backprop.
  ShapeHandle output_grad_shape = c->input(2);
  TF_RETURN_IF_ERROR(c->WithRank(output_grad_shape, 4, &output_grad_shape));
  ShapeHandle filter_shape = c->input(1);
  TF_RETURN_IF_ERROR(c->WithRank(filter_shape, 4, &filter_shape));

  DimensionHandle batch_size_dim;
  DimensionHandle output_grad_depth_dim;
  gtl::InlinedVector<DimensionHandle, 2> output_grad_spatial_dims(2);
  TF_RETURN_IF_ERROR(DimensionsFromShape(
      output_grad_shape, data_format, &batch_size_dim,
      absl::MakeSpan(output_grad_spatial_dims), &output_grad_depth_dim, c));
  DimensionHandle unused;
  // out_backprop's channel count must agree with the filter's output-channel
  // dimension (filter dim 3 in HWIO layout).
  TF_RETURN_IF_ERROR(
      c->Merge(output_grad_depth_dim, c->Dim(filter_shape, 3), &unused));

  ShapeHandle specified_input_grad_shape;
  TF_RETURN_IF_ERROR(
      c->MakeShapeFromShapeTensor(0, &specified_input_grad_shape));
  if (c->Rank(specified_input_grad_shape) == InferenceContext::kUnknownRank) {
    // input_sizes is not a constant tensor here; constrain the inferred
    // result to rank 4 so downstream consumers still get a rank.
    TF_RETURN_IF_ERROR(c->WithRank(specified_input_grad_shape, 4,
                                   &specified_input_grad_shape));
  }

  // input_grad_depth_dim doesn't equal c->Dim(filter_shape,2) when the number
  // of groups is larger than 1. If input_sizes is a 4D shape, we collect
  // input_grad_depth_dim from input_sizes; otherwise we compute it as
  // c->Dim(filter_shape,2).
  DimensionHandle input_grad_depth_dim;
  gtl::InlinedVector<DimensionHandle, 2> specified_input_grad_spatial_dims(2);
  int specified_input_grad_rank = c->Rank(specified_input_grad_shape);
  if (specified_input_grad_rank == 4) {
    DimensionHandle specified_batch_size_dim;
    TF_RETURN_IF_ERROR(DimensionsFromShape(
        specified_input_grad_shape, data_format, &specified_batch_size_dim,
        absl::MakeSpan(specified_input_grad_spatial_dims),
        &input_grad_depth_dim, c));
    // Batch sizes of input_sizes and out_backprop must be compatible.
    TF_RETURN_IF_ERROR(
        c->Merge(specified_batch_size_dim, batch_size_dim, &unused));
  } else if (specified_input_grad_rank == 2) {
    // A 2-element input_sizes holds only the two spatial dimensions.
    specified_input_grad_spatial_dims[0] =
        c->Dim(specified_input_grad_shape, 0);
    specified_input_grad_spatial_dims[1] =
        c->Dim(specified_input_grad_shape, 1);
    input_grad_depth_dim = c->Dim(filter_shape, 2);
  } else {
    return errors::InvalidArgument(
        "Conv2DBackpropInput requires input_sizes to contain 4 values or 2 "
        "values, but got: ",
        specified_input_grad_rank);
  }

  ShapeHandle input_grad_shape;
  TF_RETURN_IF_ERROR(ShapeFromDimensions(
      batch_size_dim, specified_input_grad_spatial_dims, input_grad_depth_dim,
      data_format, /*vect_size=*/absl::nullopt, c, &input_grad_shape));
  c->set_output(0, input_grad_shape);
  return OkStatus();
}
956 
Conv2DBackpropFilterWithBiasShape(shape_inference::InferenceContext * c)957 Status Conv2DBackpropFilterWithBiasShape(shape_inference::InferenceContext* c) {
958   ShapeHandle input_shape;
959   // Fetch the data_format attribute, which may not exist.
960   string data_format;
961   Status s = c->GetAttr("data_format", &data_format);
962 
963   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
964   if (s.ok() && data_format == "NCHW") {
965     c->set_output(1, c->Vector(c->Dim(input_shape, -3)));
966   } else {
967     c->set_output(1, c->Vector(c->Dim(input_shape, -1)));
968   }
969   ShapeHandle sh;
970   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &sh));
971   TF_RETURN_IF_ERROR(c->WithRank(sh, 4, &sh));
972   c->set_output(0, sh);
973   return OkStatus();
974 }
975 
976 namespace {
977 
// Shared shape function for DepthwiseConv2dNative-style ops. Infers the
// output shape from the 4-D input (input 0) and filter (input 1) shapes, the
// "strides"/"dilations" attributes and the padding mode. The filter layout is
// [filter_rows, filter_cols, in_channels, depth_multiplier]; the output
// channel count is in_channels * depth_multiplier.
Status DepthwiseConv2DNativeShapeImpl(shape_inference::InferenceContext* c,
                                      bool supports_explicit_padding) {
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
  ShapeHandle filter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape));

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));

  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "DepthwiseConv2D requires the stride attribute to contain 4 values, "
        "but got: ",
        strides.size());
  }

  std::vector<int32> dilations;
  if (!c->GetAttr("dilations", &dilations).ok()) {
    // "dilations" is optional; default to a dilation rate of 1 everywhere.
    dilations.resize(4, 1);
  }

  if (dilations.size() != 4) {
    return errors::InvalidArgument(
        "DepthwiseConv2D requires the dilations attribute to contain 4 values, "
        "but got: ",
        dilations.size());
  }

  string data_format_str;
  Status s = c->GetAttr("data_format", &data_format_str);
  TensorFormat data_format;
  if (!s.ok() || !FormatFromString(data_format_str, &data_format)) {
    // Missing or unparseable "data_format" falls back to NHWC.
    data_format = FORMAT_NHWC;
  }
  int32_t stride_rows;
  int32_t stride_cols;
  int32_t dilation_rows;
  int32_t dilation_cols;
  if (data_format == FORMAT_NCHW) {
    // Canonicalize input shape to NHWC so the shape inference code below can
    // process it.
    input_shape =
        c->MakeShape({{c->Dim(input_shape, 0), c->Dim(input_shape, 2),
                       c->Dim(input_shape, 3), c->Dim(input_shape, 1)}});
    stride_rows = strides[2];
    stride_cols = strides[3];
    dilation_rows = dilations[2];
    dilation_cols = dilations[3];
  } else {
    stride_rows = strides[1];
    stride_cols = strides[2];
    dilation_rows = dilations[1];
    dilation_cols = dilations[2];
  }

  // input_shape is NHWC from here on.
  DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
  DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
  DimensionHandle in_cols_dim = c->Dim(input_shape, 2);

  // Filter layout: [rows, cols, in_channels, depth_multiplier].
  DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
  DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
  DimensionHandle input_depth = c->Dim(filter_shape, 2);
  DimensionHandle depth_multiplier = c->Dim(filter_shape, 3);

  // Check that the input depths are compatible.
  TF_RETURN_IF_ERROR(
      c->Merge(c->Dim(input_shape, 3), input_depth, &input_depth));

  // Output channels = in_channels * depth_multiplier.
  DimensionHandle output_depth;
  TF_RETURN_IF_ERROR(c->Multiply(input_depth, depth_multiplier, &output_depth));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  std::vector<int64_t> explicit_paddings;
  if (supports_explicit_padding) {
    Status status = c->GetAttr("explicit_paddings", &explicit_paddings);
    // Use the default value, which is an empty list, if the attribute is not
    // found. Otherwise return the error to the caller.
    if (!status.ok() && !errors::IsNotFound(status)) {
      return status;
    }
    TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
                                         /*num_dims=*/4, data_format));
  } else {
    // Ops registered without explicit-padding support should never carry
    // EXPLICIT padding; debug-only sanity check.
    DCHECK(padding != Padding::EXPLICIT);
  }

  // TODO(mrry,shlens): Raise an error if the stride would cause
  // information in the input to be ignored. This will require a change
  // in the kernel implementation.
  DimensionHandle output_rows, output_cols;
  // -1 means "not explicitly specified"; only filled for EXPLICIT padding.
  int64_t pad_rows_before = -1, pad_rows_after = -1;
  int64_t pad_cols_before = -1, pad_cols_after = -1;
  if (padding == Padding::EXPLICIT) {
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
                             &pad_rows_before, &pad_rows_after);
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
                             &pad_cols_before, &pad_cols_after);
  }
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_rows_dim, filter_rows_dim, dilation_rows, stride_rows, padding,
      pad_rows_before, pad_rows_after, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_cols_dim, filter_cols_dim, dilation_cols, stride_cols, padding,
      pad_cols_before, pad_cols_after, &output_cols));

  // Emit the output in the caller's requested layout.
  ShapeHandle output_shape;
  if (data_format == FORMAT_NCHW) {
    output_shape =
        c->MakeShape({batch_size_dim, output_depth, output_rows, output_cols});
  } else {
    output_shape =
        c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth});
  }
  c->set_output(0, output_shape);
  return OkStatus();
}
1097 
1098 };  // namespace
1099 
// Shape function for DepthwiseConv2dNative without explicit-padding support.
Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
  return DepthwiseConv2DNativeShapeImpl(c, false);
}
1103 
// Shape function for DepthwiseConv2dNative variants that accept an
// "explicit_paddings" attribute (validated in the shared impl).
Status DepthwiseConv2DNativeShapeWithExplicitPadding(
    shape_inference::InferenceContext* c) {
  return DepthwiseConv2DNativeShapeImpl(c, true);
}
1108 
AvgPoolShape(shape_inference::InferenceContext * c)1109 Status AvgPoolShape(shape_inference::InferenceContext* c) {
1110   string data_format_str;
1111   TensorFormat data_format;
1112   Status s = c->GetAttr("data_format", &data_format_str);
1113   if (s.ok()) {
1114     FormatFromString(data_format_str, &data_format);
1115   } else {
1116     data_format = FORMAT_NHWC;
1117   }
1118 
1119   const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
1120   ShapeHandle input_shape;
1121   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
1122 
1123   TF_RETURN_IF_ERROR(
1124       CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));
1125 
1126   std::vector<int32> strides;
1127   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
1128   if (strides.size() != 4) {
1129     return errors::InvalidArgument(
1130         "AvgPool requires the stride attribute to contain 4 values, but got: ",
1131         strides.size());
1132   }
1133 
1134   std::vector<int32> kernel_sizes;
1135   TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
1136   if (kernel_sizes.size() != 4) {
1137     return errors::InvalidArgument(
1138         "AvgPool requires the ksize attribute to contain 4 values, but got: ",
1139         kernel_sizes.size());
1140   }
1141 
1142   int32_t stride_rows = GetTensorDim(strides, data_format, 'H');
1143   int32_t stride_cols = GetTensorDim(strides, data_format, 'W');
1144   int32_t kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
1145   int32_t kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');
1146 
1147   constexpr int num_spatial_dims = 2;
1148   DimensionHandle batch_size_dim = c->Dim(
1149       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
1150   DimensionHandle in_rows_dim = c->Dim(
1151       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
1152   DimensionHandle in_cols_dim = c->Dim(
1153       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
1154   DimensionHandle depth_dim = c->Dim(
1155       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
1156 
1157   Padding padding;
1158   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
1159 
1160   // TODO(mrry,shlens): Raise an error if the stride would cause
1161   // information in the input to be ignored. This will require a change
1162   // in the kernel implementation.
1163 
1164   DimensionHandle output_rows, output_cols;
1165   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1166       c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
1167   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1168       c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
1169 
1170   ShapeHandle output_shape;
1171   TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
1172                                          {output_rows, output_cols}, depth_dim,
1173                                          &output_shape, c));
1174   c->set_output(0, output_shape);
1175   return OkStatus();
1176 }
1177 
AvgPoolGradShape(shape_inference::InferenceContext * c)1178 Status AvgPoolGradShape(shape_inference::InferenceContext* c) {
1179   ShapeHandle s;
1180   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
1181   TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
1182   c->set_output(0, s);
1183   return OkStatus();
1184 }
1185 
// Shape function for FusedBatchNorm-style ops. Output 0 (y) has the shape of
// x with its channel dimension merged against the channel-sized vector
// inputs; outputs 1-4 are vectors over the channel dimension.
Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
  string data_format_str;
  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }
  // The 3-D variants (NDHWC/NCDHW) take rank-5 inputs; others rank-4.
  const int rank =
      (data_format_str == "NDHWC" || data_format_str == "NCDHW") ? 5 : 4;
  ShapeHandle x;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &x));

  bool is_training;
  TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
  float exponential_avg_factor;
  if (!c->GetAttr("exponential_avg_factor", &exponential_avg_factor).ok()) {
    exponential_avg_factor = 1.0f;  // default value
  }
  // Training with a full averaging factor consumes only scale and offset;
  // otherwise the mean and variance inputs are rank-checked as well.
  int number_inputs = (is_training && exponential_avg_factor == 1.0f) ? 3 : 5;

  int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
  DimensionHandle channel_dim = c->Dim(x, channel_dim_index);

  // covers scale, offset, and if is_training is false, mean, variance
  for (int i = 1; i < number_inputs; ++i) {
    ShapeHandle vec;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
    TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
  }

  ShapeHandle y;
  TF_RETURN_IF_ERROR(c->ReplaceDim(x, channel_dim_index, channel_dim, &y));
  c->set_output(0, y);
  // Outputs 1-4 are all channel-sized vectors.
  ShapeHandle vector_shape = c->Vector(channel_dim);
  c->set_output(1, vector_shape);
  c->set_output(2, vector_shape);
  c->set_output(3, vector_shape);
  c->set_output(4, vector_shape);
  return OkStatus();
}
1227 
// Shape function for FusedBatchNormV3: same as FusedBatchNorm plus a sixth
// output whose shape is implementation-defined and therefore left unknown.
Status FusedBatchNormV3Shape(shape_inference::InferenceContext* c) {
  TF_RETURN_IF_ERROR(FusedBatchNormShape(c));
  c->set_output(5, c->UnknownShape());
  return OkStatus();
}
1233 
// Shape function for _FusedBatchNormEx: FusedBatchNormV3 shapes plus an
// additional constraint on the (rank-4) channel dimension.
Status FusedBatchNormExShape(shape_inference::InferenceContext* c) {
  TF_RETURN_IF_ERROR(FusedBatchNormV3Shape(c));

  string data_format_str;
  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }
  ShapeHandle x;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));

  int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
  DimensionHandle channel_dim = c->Dim(x, channel_dim_index);

  // This is a cuDNN implementation constraint.
  if (c->ValueKnown(channel_dim) && c->Value(channel_dim) % 4 != 0) {
    return errors::InvalidArgument(
        "_FusedBatchNormEx channel dimension must be divisible by 4.");
  }

  return OkStatus();
}
1258 
// Shape function for FusedBatchNormGrad. Output 0 (x_backprop) takes the
// shape of y_backprop with a merged channel dimension; outputs 1-2 are
// channel-sized vectors and outputs 3-4 are empty vectors.
Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
  string data_format_str;
  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }
  // The 3-D variants (NDHWC/NCDHW) take rank-5 inputs; others rank-4.
  const int rank =
      (data_format_str == "NDHWC" || data_format_str == "NCDHW") ? 5 : 4;
  ShapeHandle y_backprop;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &y_backprop));
  ShapeHandle x;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &x));

  // Read but otherwise unused here; validates the attribute exists.
  bool is_training;
  TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));

  // Merge the channel dimension across y_backprop and x.
  int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
  DimensionHandle channel_dim = c->Dim(y_backprop, channel_dim_index);
  TF_RETURN_IF_ERROR(
      c->Merge(channel_dim, c->Dim(x, channel_dim_index), &channel_dim));

  // covers scale, mean (reserve_space_1), variance (reserve_space_2)
  for (int i = 2; i < 5; ++i) {
    ShapeHandle vec;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
    TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
  }

  ShapeHandle x_backprop;
  TF_RETURN_IF_ERROR(
      c->ReplaceDim(y_backprop, channel_dim_index, channel_dim, &x_backprop));
  c->set_output(0, x_backprop);
  c->set_output(1, c->Vector(channel_dim));
  c->set_output(2, c->Vector(channel_dim));
  // Outputs 3 and 4 are deliberately empty vectors.
  c->set_output(3, c->Vector(0));
  c->set_output(4, c->Vector(0));
  return OkStatus();
}
1299 
// Shape function for _FusedBatchNormGradEx: FusedBatchNormGrad shapes plus,
// when side inputs are present, a sixth output (side_input_backprop) shaped
// like y_backprop with the merged channel dimension.
Status FusedBatchNormGradExShape(shape_inference::InferenceContext* c) {
  TF_RETURN_IF_ERROR(FusedBatchNormGradShape(c));

  int num_side_inputs;
  TF_RETURN_IF_ERROR(c->GetAttr("num_side_inputs", &num_side_inputs));
  if (num_side_inputs == 0) {
    // No side input, so output 5 is not produced/constrained.
    return OkStatus();
  }

  string data_format_str;
  TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
  TensorFormat data_format;
  if (!FormatFromString(data_format_str, &data_format)) {
    return errors::InvalidArgument("Invalid data format string: ",
                                   data_format_str);
  }
  // The 3-D variants (NDHWC/NCDHW) take rank-5 inputs; others rank-4.
  const int rank =
      (data_format_str == "NDHWC" || data_format_str == "NCDHW") ? 5 : 4;
  ShapeHandle y_backprop;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &y_backprop));
  ShapeHandle x;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &x));

  // Merge the channel dimension across y_backprop and x.
  int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format);
  DimensionHandle channel_dim = c->Dim(y_backprop, channel_dim_index);
  TF_RETURN_IF_ERROR(
      c->Merge(channel_dim, c->Dim(x, channel_dim_index), &channel_dim));

  ShapeHandle side_input_backprop;
  TF_RETURN_IF_ERROR(c->ReplaceDim(y_backprop, channel_dim_index, channel_dim,
                                   &side_input_backprop));

  c->set_output(5, side_input_backprop);
  return OkStatus();
}
1335 
ReadDiagIndex(InferenceContext * c,const Tensor * diag_index_tensor,int32 * lower_diag_index,int32 * upper_diag_index)1336 Status ReadDiagIndex(InferenceContext* c, const Tensor* diag_index_tensor,
1337                      int32* lower_diag_index, int32* upper_diag_index) {
1338   // This function assumes that the shape of diag_index_tensor is fully defined.
1339   if (diag_index_tensor->dims() == 0) {
1340     *lower_diag_index = diag_index_tensor->scalar<int32>()();
1341     *upper_diag_index = *lower_diag_index;
1342   } else {
1343     int32_t num_elements = diag_index_tensor->dim_size(0);
1344     if (num_elements == 1) {
1345       *lower_diag_index = diag_index_tensor->vec<int32>()(0);
1346       *upper_diag_index = *lower_diag_index;
1347     } else if (num_elements == 2) {
1348       *lower_diag_index = diag_index_tensor->vec<int32>()(0);
1349       *upper_diag_index = diag_index_tensor->vec<int32>()(1);
1350     } else {
1351       return errors::InvalidArgument(
1352           "diag_index must be a vector with one or two elements. It has ",
1353           num_elements, " elements.");
1354     }
1355   }
1356   return OkStatus();
1357 }
1358 
// Shape function for MatrixDiagPartV2. Given an input of shape
// [..., num_rows, num_cols] and a diagonal band [lower, upper] (input 1, with
// a padding scalar in input 2), the output is the batch shape plus an
// optional band-count dimension (when lower < upper) and a max_diag_len
// dimension.
Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape, diag_index_shape, unused_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
  TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape));

  // Without a constant diag_index (or a known input rank) the output shape
  // cannot be computed.
  const Tensor* diag_index_tensor = c->input_tensor(1);
  if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) ||
      diag_index_tensor == nullptr) {
    c->set_output(0, c->UnknownShape());
    return OkStatus();
  }
  int32_t lower_diag_index = 0;
  int32_t upper_diag_index = 0;
  TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
                                   &upper_diag_index));
  if (lower_diag_index > upper_diag_index) {
    return errors::InvalidArgument(
        "lower_diag_index is greater than upper_diag_index");
  }

  // Validates lower_diag_index and upper_diag_index.
  const int32_t input_rank = c->Rank(input_shape);
  const int32_t num_rows = c->Value(c->Dim(input_shape, input_rank - 2));
  const int32_t num_cols = c->Value(c->Dim(input_shape, input_rank - 1));
  int32_t max_diag_len = InferenceContext::kUnknownDim;
  if (num_rows != InferenceContext::kUnknownDim &&
      num_cols != InferenceContext::kUnknownDim) {
    if (lower_diag_index != 0 &&  // For when num_rows or num_cols == 0.
        (-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) {
      return errors::InvalidArgument("lower_diag_index is out of bound.");
    }
    if (upper_diag_index != 0 &&  // For when num_rows or num_cols == 0.
        (-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) {
      return errors::InvalidArgument("upper_diag_index is out of bound.");
    }
    // Longest diagonal in the band; negative indices shorten from num_rows,
    // positive ones from num_cols.
    max_diag_len = std::min(num_rows + std::min(upper_diag_index, 0),
                            num_cols - std::max(lower_diag_index, 0));
  }

  // Output: batch dims + [num_diags (only if a band), max_diag_len].
  std::vector<DimensionHandle> dims;
  dims.reserve(input_rank - 2);
  for (int i = 0; i < input_rank - 2; ++i) {
    dims.push_back(c->Dim(input_shape, i));
  }
  if (lower_diag_index < upper_diag_index) {
    dims.push_back(c->MakeDim(upper_diag_index - lower_diag_index + 1));
  }
  dims.push_back(c->MakeDim(max_diag_len));
  c->set_output(0, c->MakeShape(dims));
  return OkStatus();
}
1411 
// Shape function for MatrixDiagV2: builds a batched matrix from a given
// diagonal (or band of diagonals).
// Inputs: 0 = diagonal (rank >= 1), 1 = k (scalar or vector of at most 2
// offsets), 2 = num_rows (scalar), 3 = num_cols (scalar),
// 4 = padding_value (scalar).
Status MatrixDiagV2Shape(shape_inference::InferenceContext* c) {
  // Checks input ranks.
  ShapeHandle input_shape, diag_index_shape, unused_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input_shape));
  TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));

  // Reads the diagonal indices. If k is not a compile-time constant (or the
  // diagonal's rank is unknown), nothing more can be inferred.
  const Tensor* diag_index_tensor = c->input_tensor(1);
  if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) ||
      diag_index_tensor == nullptr) {
    c->set_output(0, c->UnknownShape());
    return OkStatus();
  }
  int32_t lower_diag_index = 0;
  int32_t upper_diag_index = 0;
  TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
                                   &upper_diag_index));
  if (lower_diag_index > upper_diag_index) {
    return errors::InvalidArgument(
        "lower_diag_index is greater than upper_diag_index");
  }

  // Checks if the number of diagonals provided matches what we imply from
  // lower_diag_index and upper_diag_index.
  const int32_t input_rank = c->Rank(input_shape);
  if (lower_diag_index < upper_diag_index) {
    const int32_t num_diags = c->Value(c->Dim(input_shape, input_rank - 2));
    const int32_t other_dim = c->Value(c->Dim(input_shape, input_rank - 1));

    if (num_diags != (upper_diag_index - lower_diag_index + 1)) {
      return errors::InvalidArgument(
          "The number of rows of `diagonal` doesn't match the number of "
          "diagonals implied from `d_lower` and `d_upper`.\n",
          "num_diags = ", num_diags, ", d_lower = ", lower_diag_index,
          ", d_upper = ", upper_diag_index, " ", input_rank, " ", other_dim);
    }
  }

  // Reads num_rows and num_cols. -1 means "not provided at graph time".
  const Tensor* num_rows_tensor = c->input_tensor(2);
  const Tensor* num_cols_tensor = c->input_tensor(3);
  int64_t num_rows = -1;
  int64_t num_cols = -1;
  if (num_rows_tensor != nullptr) {
    TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_rows_tensor, &num_rows));
  }
  if (num_cols_tensor != nullptr) {
    TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_cols_tensor, &num_cols));
  }

  // Infers the missing num_rows or num_cols: If both are missing, assume
  // output is square. Otherwise, use the smallest possible value. Also
  // validates the provided values.
  const int32_t max_diag_len = c->Value(c->Dim(input_shape, input_rank - 1));
  const int32_t min_num_rows = max_diag_len - std::min(upper_diag_index, 0);
  const int32_t min_num_cols = max_diag_len + std::max(lower_diag_index, 0);
  if (num_rows == -1 && num_cols == -1) {  // Special case.
    num_rows = std::max(min_num_rows, min_num_cols);
    num_cols = num_rows;
  }
  if (num_rows == -1) {
    num_rows = min_num_rows;
  } else if (num_rows < min_num_rows) {
    return errors::InvalidArgument("num_rows is too small");
  }
  if (num_cols == -1) {
    num_cols = min_num_cols;
  } else if (num_cols < min_num_cols) {
    return errors::InvalidArgument("num_cols is too small.");
  }
  // At least one of them must match the minimum length.
  if (num_rows != min_num_rows && num_cols != min_num_cols) {
    return errors::InvalidArgument(
        "num_rows and num_cols are not consistent with lower_diag_index, "
        "upper_diag_index, and the length of the given diagonals.\n",
        "num_rows = ", num_rows, " != min_num_rows = ", min_num_rows,
        ", num_cols = ", num_cols, " != min_num_cols = ", min_num_cols);
  }

  // Sets output shape: the batch prefix of `diagonal` followed by
  // [num_rows, num_cols]. With a single diagonal the last dim of `diagonal`
  // is replaced; with a band the last two dims are replaced.
  ShapeHandle output_shape;
  const DimensionHandle output_row_dim = c->MakeDim(num_rows);
  const DimensionHandle output_col_dim = c->MakeDim(num_cols);
  if (lower_diag_index == upper_diag_index) {
    TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 1,
                                     output_row_dim, &output_shape));
    TF_RETURN_IF_ERROR(
        c->Concatenate(output_shape, c->Vector(output_col_dim), &output_shape));
  } else {
    TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 2,
                                     output_row_dim, &output_shape));
    TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape, input_rank - 1,
                                     output_col_dim, &output_shape));
  }
  c->set_output(0, output_shape);
  return OkStatus();
}
1512 
// Shape function for MatrixSetDiagV2: replaces a diagonal band of a batched
// matrix with given values. Output shape equals the input matrix's shape,
// possibly refined by information merged from the diagonal's shape.
// Inputs: 0 = input (rank >= 2), 1 = diagonal (rank >= 1), 2 = k (scalar or
// vector of at most 2 offsets).
Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape, diag_shape, diag_index_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag_shape));
  TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &diag_index_shape));

  // k may be unavailable at graph time; remember whether we actually read it.
  int32_t lower_diag_index = 0;
  int32_t upper_diag_index = 0;
  bool diag_index_known = false;
  const Tensor* diag_index_tensor = c->input_tensor(2);
  if (diag_index_tensor != nullptr && c->FullyDefined(diag_index_shape)) {
    diag_index_known = true;
    TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
                                     &upper_diag_index));
    if (lower_diag_index > upper_diag_index) {
      return errors::InvalidArgument(
          "lower_diag_index is greater than upper_diag_index");
    }
  }

  // Do more checks when input rank is known.
  if (c->RankKnown(input_shape)) {
    int32_t input_rank = c->Rank(input_shape);

    // If diag_index is set, we know the exact rank of diagonal.
    if (diag_index_known) {
      TF_RETURN_IF_ERROR(c->WithRank(
          c->input(1),
          (lower_diag_index == upper_diag_index) ? input_rank - 1 : input_rank,
          &diag_shape));
    } else {
      // Unknown k: diagonal is either rank - 1 (single diagonal) or
      // rank (band), so only bound it from both sides.
      TF_RETURN_IF_ERROR(
          c->WithRankAtLeast(c->input(1), input_rank - 1, &diag_shape));
      TF_RETURN_IF_ERROR(
          c->WithRankAtMost(c->input(1), input_rank, &diag_shape));
    }

    // Validates lower_diag_index and upper_diag_index.
    const int32_t num_rows = c->Value(c->Dim(input_shape, input_rank - 2));
    const int32_t num_cols = c->Value(c->Dim(input_shape, input_rank - 1));
    if (num_rows != InferenceContext::kUnknownDim &&
        num_cols != InferenceContext::kUnknownDim) {
      if (lower_diag_index != 0 &&  // For when num_rows or num_cols == 0.
          (-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) {
        return errors::InvalidArgument("lower_diag_index is out of bound.");
      }
      if (upper_diag_index != 0 &&  // For when num_rows or num_cols == 0.
          (-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) {
        return errors::InvalidArgument("upper_diag_index is out of bound.");
      }
    }
  }

  ShapeHandle output_shape = input_shape;
  if (c->RankKnown(diag_shape) && !c->FullyDefined(input_shape)) {
    // Try to infer parts of shape from diag.
    ShapeHandle diag_prefix;
    TF_RETURN_IF_ERROR(c->Subshape(
        diag_shape, 0, (lower_diag_index == upper_diag_index) ? -1 : -2,
        &diag_prefix));

    // The inner matrices can be rectangular, so we can't pinpoint their
    // exact height and width by just lower_diag_index, upper_diag_index,
    // and the longest length of given diagonals.
    TF_RETURN_IF_ERROR(
        c->Concatenate(diag_prefix, c->UnknownShapeOfRank(2), &diag_shape));
    TF_RETURN_IF_ERROR(c->Merge(input_shape, diag_shape, &output_shape));
  }
  c->set_output(0, output_shape);
  return OkStatus();
}
1584 
// Shared shape function for MaxPool ops that take ksize/strides as attrs.
// When supports_explicit_padding is true, the op may also carry an
// "explicit_paddings" attr (used when padding == EXPLICIT).
Status MaxPoolShapeImpl(shape_inference::InferenceContext* c,
                        bool supports_explicit_padding) {
  string data_format_str;
  TensorFormat data_format;
  Status s = c->GetAttr("data_format", &data_format_str);
  if (s.ok()) {
    FormatFromString(data_format_str, &data_format);
  } else {
    // The attr is optional; default to NHWC.
    data_format = FORMAT_NHWC;
  }

  // NCHW_VECT_C carries an extra inner vector dimension, hence rank 5.
  const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));

  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the stride attribute to contain 4 values, but got: ",
        strides.size());
  }

  std::vector<int32> kernel_sizes;
  TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
  if (kernel_sizes.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the ksize attribute to contain 4 values, but got: ",
        kernel_sizes.size());
  }

  // Pull the per-dimension window/stride values out in a layout-agnostic way.
  int32_t stride_depth = GetTensorDim(strides, data_format, 'C');
  int32_t stride_rows = GetTensorDim(strides, data_format, 'H');
  int32_t stride_cols = GetTensorDim(strides, data_format, 'W');
  int32_t kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
  int32_t kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
  int32_t kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');

  constexpr int num_spatial_dims = 2;
  DimensionHandle batch_size_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
  DimensionHandle in_rows_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
  DimensionHandle in_cols_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
  DimensionHandle in_depth_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  std::vector<int64_t> explicit_paddings;
  if (supports_explicit_padding) {
    Status status = c->GetAttr("explicit_paddings", &explicit_paddings);
    // Use the default value, which is an empty list, if the attribute is not
    // found. Otherwise return the error to the caller.
    if (!status.ok() && !errors::IsNotFound(status)) {
      return status;
    }
    TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
                                         /*num_dims=*/4, data_format));
  } else {
    // Ops without the explicit_paddings attr should never see EXPLICIT.
    DCHECK(padding != Padding::EXPLICIT);
  }

  ShapeHandle output_shape;
  DimensionHandle output_rows, output_cols, output_depth;
  // -1 marks "unused" when padding is not EXPLICIT.
  int64_t pad_rows_before = -1, pad_rows_after = -1;
  int64_t pad_cols_before = -1, pad_cols_after = -1;
  if (padding == Padding::EXPLICIT) {
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
                             &pad_rows_before, &pad_rows_after);
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
                             &pad_cols_before, &pad_cols_after);
  }
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_rows_dim, kernel_rows, /*dilation_rate=*/1, stride_rows, padding,
      pad_rows_before, pad_rows_after, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_cols_dim, kernel_cols, /*dilation_rate=*/1, stride_cols, padding,
      pad_cols_before, pad_cols_after, &output_cols));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
      c, in_depth_dim, kernel_depth, /*dilation_rate=*/1, stride_depth, padding,
      /*pad_before*/ 0, /*pad_after*/ 0, &output_depth));

  TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
                                         {output_rows, output_cols},
                                         output_depth, &output_shape, c));

  c->set_output(0, output_shape);
  return OkStatus();
}
1680 
// Shape function for MaxPool: attr-based ksize/strides, no explicit padding.
Status MaxPoolShape(shape_inference::InferenceContext* c) {
  return MaxPoolShapeImpl(c, /*supports_explicit_padding=*/false);
}
1684 
// Shape function for MaxPoolGrad: the gradient has the shape of the (rank-4)
// original input, which is input 0 of the gradient op.
Status MaxPoolGradShape(shape_inference::InferenceContext* c) {
  return UnchangedShapeWithRank(c, 4);
}
1688 
// Shape function for MaxPool variants that support the EXPLICIT padding mode
// (via the "explicit_paddings" attr).
Status MaxPoolShapeWithExplicitPadding(shape_inference::InferenceContext* c) {
  return MaxPoolShapeImpl(c, /*supports_explicit_padding=*/true);
}
1692 
// Shape function for MaxPoolV2-style ops, where ksize and strides may come
// either from attrs or from the last two inputs. num_inputs is the input
// count of the variant that carries ksize/strides as tensors; when the op's
// actual input count is two less, the values come from attrs instead.
Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) {
  string data_format_str;
  TensorFormat data_format;
  Status s = c->GetAttr("data_format", &data_format_str);
  if (s.ok()) {
    FormatFromString(data_format_str, &data_format);
  } else {
    // The attr is optional; default to NHWC.
    data_format = FORMAT_NHWC;
  }

  // NCHW_VECT_C carries an extra inner vector dimension, hence rank 5.
  const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));

  TF_RETURN_IF_ERROR(
      CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));

  std::vector<int32> kernel_sizes;
  std::vector<int32> strides;

  if (c->num_inputs() + 2 == num_inputs) {
    // ksize/strides are attrs on this variant.
    TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));

    TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  } else {
    // Verify shape of ksize and strides input.
    ShapeHandle size;
    DimensionHandle unused;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 2), 1, &size));
    TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));
    TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &size));
    TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));

    // If either tensor is not a graph-time constant, the output shape cannot
    // be inferred.
    const Tensor* kernel_sizes_tensor = c->input_tensor(c->num_inputs() - 2);
    if (kernel_sizes_tensor == nullptr) {
      c->set_output(0, c->UnknownShape());
      return OkStatus();
    }
    kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements());
    auto kernel_sizes_vec = kernel_sizes_tensor->flat<int32>();
    std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(),
                kernel_sizes.begin());

    const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1);
    if (strides_tensor == nullptr) {
      c->set_output(0, c->UnknownShape());
      return OkStatus();
    }
    strides.resize(strides_tensor->shape().num_elements());
    auto strides_vec = strides_tensor->flat<int32>();
    std::copy_n(&strides_vec(0), strides.size(), strides.begin());
  }

  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the stride attribute to contain 4 values, but "
        "got: ",
        strides.size());
  }
  if (kernel_sizes.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the ksize attribute to contain 4 values, but got: ",
        kernel_sizes.size());
  }

  // Pull the per-dimension window/stride values out in a layout-agnostic way.
  int32_t stride_depth = GetTensorDim(strides, data_format, 'C');
  int32_t stride_rows = GetTensorDim(strides, data_format, 'H');
  int32_t stride_cols = GetTensorDim(strides, data_format, 'W');
  int32_t kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
  int32_t kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
  int32_t kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');

  constexpr int num_spatial_dims = 2;
  DimensionHandle batch_size_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
  DimensionHandle in_rows_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
  DimensionHandle in_cols_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
  DimensionHandle in_depth_dim = c->Dim(
      input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  ShapeHandle output_shape;
  DimensionHandle output_rows, output_cols, output_depth;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth));

  TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
                                         {output_rows, output_cols},
                                         output_depth, &output_shape, c));

  c->set_output(0, output_shape);
  return OkStatus();
}
1794 
// Shape function for 3D pooling ops (AvgPool3D/MaxPool3D). Supports NDHWC
// (default) and NCDHW layouts; ksize and strides are rank-5 attrs.
Status Pool3DShape(shape_inference::InferenceContext* c) {
  ShapeHandle input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));

  // The attr is optional; if absent, data_format stays empty and the NDHWC
  // branches below are taken.
  string data_format;
  Status s = c->GetAttr("data_format", &data_format);

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 5) {
    return errors::InvalidArgument(
        "Pool3D ops require the stride attribute to contain 5 values, but "
        "got: ",
        strides.size());
  }

  std::vector<int32> kernel_sizes;
  TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
  if (kernel_sizes.size() != 5) {
    return errors::InvalidArgument(
        "Pool3D requires the ksize attribute to contain 5 values, but got: ",
        kernel_sizes.size());
  }

  int32_t stride_planes, stride_rows, stride_cols;
  int32_t kernel_planes, kernel_rows, kernel_cols;

  if (s.ok() && data_format == "NCDHW") {
    // Convert input_shape to NDHWC.
    auto dim = [&](char dimension) {
      return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
    };
    input_shape =
        c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
    stride_planes = strides[2];
    stride_rows = strides[3];
    stride_cols = strides[4];
    kernel_planes = kernel_sizes[2];
    kernel_rows = kernel_sizes[3];
    kernel_cols = kernel_sizes[4];
  } else {
    // NDHWC layout: spatial dims occupy positions 1..3.
    stride_planes = strides[1];
    stride_rows = strides[2];
    stride_cols = strides[3];
    kernel_planes = kernel_sizes[1];
    kernel_rows = kernel_sizes[2];
    kernel_cols = kernel_sizes[3];
  }

  // From here on, input_shape is in NDHWC order regardless of data_format.
  DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
  DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
  DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
  DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
  DimensionHandle output_depth_dim = c->Dim(input_shape, 4);

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  // TODO(mrry,shlens): Raise an error if the stride would cause
  // information in the input to be ignored. This will require a change
  // in the kernel implementation.
  DimensionHandle output_planes, output_rows, output_cols;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_planes_dim, kernel_planes, stride_planes, padding, &output_planes));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
      c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));

  // Emit the output in the caller's original layout.
  ShapeHandle output_shape;
  if (data_format == "NCDHW") {
    output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
                                 output_planes, output_rows, output_cols});
  } else {
    output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
                                 output_cols, output_depth_dim});
  }

  c->set_output(0, output_shape);
  return OkStatus();
}
1876 
// Shape function for MaxPool3DGrad: gradient shape equals the (rank-5)
// original input's shape, which is input 0 of the gradient op.
Status MaxPool3DGradShape(shape_inference::InferenceContext* c) {
  return UnchangedShapeWithRank(c, 5);
}
1880 
// Shape function for AvgPool3DGrad: input 0 holds the original input's shape
// as a tensor; the output takes that (rank-5) shape.
Status AvgPool3DGradShape(shape_inference::InferenceContext* c) {
  ShapeHandle s;
  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
  TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
  c->set_output(0, s);
  return OkStatus();
}
1888 
UnknownShape(shape_inference::InferenceContext * c)1889 Status UnknownShape(shape_inference::InferenceContext* c) {
1890   for (int i = 0; i < c->num_outputs(); ++i) {
1891     c->set_output(i, c->UnknownShape());
1892   }
1893   return OkStatus();
1894 }
1895 
1896 template <typename T>
ReductionShapeHelper(const Tensor * reduction_indices_t,const int32_t input_rank,std::set<int64_t> * true_indices)1897 Status ReductionShapeHelper(const Tensor* reduction_indices_t,
1898                             const int32_t input_rank,
1899                             std::set<int64_t>* true_indices) {
1900   auto reduction_indices = reduction_indices_t->flat<T>();
1901   for (int i = 0; i < reduction_indices_t->NumElements(); ++i) {
1902     const T reduction_index = reduction_indices(i);
1903     if (reduction_index < -input_rank || reduction_index >= input_rank) {
1904       return errors::InvalidArgument("Invalid reduction dimension ",
1905                                      reduction_index, " for input with ",
1906                                      input_rank, " dimensions.");
1907     }
1908 
1909     auto wrapped_index = reduction_index;
1910     if (wrapped_index < 0) {
1911       wrapped_index += input_rank;
1912     }
1913 
1914     true_indices->insert(wrapped_index);
1915   }
1916   return OkStatus();
1917 }
1918 
// Shape function for reduction ops (Sum, Mean, Max, ...). Input 0 is the
// data, input 1 is the reduction_indices tensor; the "keep_dims" attr
// decides whether reduced axes become size-1 dims or are dropped.
Status ReductionShape(InferenceContext* c) {
  ShapeHandle input = c->input(0);

  ShapeHandle indices;
  // Older versions of TensorFlow accidentally allowed higher rank tensors like
  // [[1,2]] or [[1],[2]] to represent axis=[1,2].
  if (c->graph_def_version() < 21) {
    indices = c->input(1);
  } else {
    TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices));
  }

  bool keep_dims;
  TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));

  const Tensor* reduction_indices_t = c->input_tensor(1);
  if (reduction_indices_t == nullptr || !c->RankKnown(input)) {
    // If we do not have the reduction values at runtime, or the
    // rank of the input, we don't know the output shape.

    if (keep_dims && c->RankKnown(input)) {
      // output rank matches input input if <keep_dims>.
      c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
      return OkStatus();
    } else {
      return shape_inference::UnknownShape(c);
    }
  }

  // Gather the wrapped (non-negative) reduction axes, dispatching on the
  // index dtype.
  const int32_t input_rank = c->Rank(input);
  std::set<int64_t> true_indices;
  if (reduction_indices_t->dtype() == DataType::DT_INT32) {
    TF_RETURN_IF_ERROR(ReductionShapeHelper<int32>(reduction_indices_t,
                                                   input_rank, &true_indices));
  } else if (reduction_indices_t->dtype() == DataType::DT_INT64) {
    TF_RETURN_IF_ERROR(ReductionShapeHelper<int64_t>(
        reduction_indices_t, input_rank, &true_indices));
  } else {
    return errors::InvalidArgument(
        "reduction_indices can only be int32 or int64");
  }

  // Build the output: reduced axes become 1 (keep_dims) or vanish; all other
  // axes pass through unchanged.
  std::vector<DimensionHandle> dims;
  for (int i = 0; i < input_rank; ++i) {
    if (true_indices.count(i) > 0) {
      if (keep_dims) {
        dims.emplace_back(c->MakeDim(1));
      }
    } else {
      dims.emplace_back(c->Dim(input, i));
    }
  }

  c->set_output(0, c->MakeShape(dims));
  return OkStatus();
}
1975 
ConcatShapeHelper(InferenceContext * c,int start_value_index,int end_value_index,int dim_index)1976 Status ConcatShapeHelper(InferenceContext* c, int start_value_index,
1977                          int end_value_index, int dim_index) {
1978   ShapeHandle unused;
1979   TF_RETURN_IF_ERROR(c->WithRank(c->input(dim_index), 0, &unused));
1980   const Tensor* concat_dim_t = c->input_tensor(dim_index);
1981   if (concat_dim_t == nullptr) {
1982     // Return an unknown shape with same rank as inputs, or an unknown rank
1983     // if no input's rank is known.
1984 
1985     // Find rank.
1986     int32_t rank = InferenceContext::kUnknownRank;
1987     for (int i = start_value_index; i < end_value_index; ++i) {
1988       if (rank == InferenceContext::kUnknownRank) rank = c->Rank(c->input(i));
1989       if (rank != InferenceContext::kUnknownRank) {
1990         break;
1991       }
1992     }
1993     if (rank == InferenceContext::kUnknownRank) {
1994       c->set_output(0, c->UnknownShape());
1995       return OkStatus();
1996     } else if (rank == 0) {
1997       return errors::InvalidArgument(
1998           "Can't concatenate scalars (use tf.stack instead)");
1999     } else {
2000       for (int i = start_value_index; i < end_value_index; ++i) {
2001         // Check that all the inputs are of the correct rank.
2002         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), rank, &unused));
2003       }
2004     }
2005     // Build result of <rank> different unknown dims.
2006     std::vector<DimensionHandle> dims;
2007     dims.reserve(rank);
2008     for (int i = 0; i < rank; ++i) dims.push_back(c->UnknownDim());
2009     c->set_output(0, c->MakeShape(dims));
2010     return OkStatus();
2011   }
2012 
2013   // Merge all the non-concat dims, and sum the concat dim to make an output
2014   // shape.
2015   int64_t concat_dim;
2016   if (concat_dim_t->dtype() == DT_INT32) {
2017     concat_dim = static_cast<int64_t>(concat_dim_t->flat<int32>()(0));
2018   } else {
2019     concat_dim = concat_dim_t->flat<int64_t>()(0);
2020   }
2021 
2022   // Minimum required number of dimensions.
2023   const int64 min_rank = concat_dim < 0 ? -concat_dim : concat_dim + 1;
2024 
2025   ShapeHandle output_before;
2026   ShapeHandle output_after;
2027 
2028   ShapeHandle input = c->input(end_value_index - 1);
2029   TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
2030   TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &output_before));
2031   DimensionHandle output_middle = c->Dim(input, concat_dim);
2032   if (concat_dim == -1) {
2033     output_after = c->Scalar();  // no dimensions.
2034   } else {
2035     TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &output_after));
2036   }
2037 
2038   for (int i = end_value_index - 2; i >= start_value_index; --i) {
2039     ShapeHandle before;
2040     ShapeHandle after;
2041     input = c->input(i);
2042     TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
2043     TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &before));
2044     DimensionHandle middle = c->Dim(input, concat_dim);
2045     if (concat_dim == -1) {
2046       after = c->Scalar();
2047     } else {
2048       TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &after));
2049     }
2050 
2051     TF_RETURN_IF_ERROR(c->Merge(before, output_before, &output_before));
2052     TF_RETURN_IF_ERROR(c->Add(output_middle, middle, &output_middle));
2053     TF_RETURN_IF_ERROR(c->Merge(after, output_after, &output_after));
2054   }
2055 
2056   ShapeHandle s;
2057   TF_RETURN_IF_ERROR(
2058       c->Concatenate(output_before, c->Vector(output_middle), &s));
2059   TF_RETURN_IF_ERROR(c->Concatenate(s, output_after, &s));
2060   c->set_output(0, s);
2061   return OkStatus();
2062 }
2063 
// Shape function for Concat: input 0 is the axis, inputs 1..N are the values.
Status ConcatShape(InferenceContext* c, int num_inputs_to_concat) {
  return ConcatShapeHelper(c, 1 /* start_value_index */,
                           1 + num_inputs_to_concat /* end_value_index */,
                           0 /* dim_index */);
}
2069 
// Shape function for ConcatV2: the axis is the last input; all preceding
// inputs are the values to concatenate.
Status ConcatV2Shape(InferenceContext* c) {
  return ConcatShapeHelper(c, 0 /* start_value_index */,
                           c->num_inputs() - 1 /* end_value_index */,
                           c->num_inputs() - 1 /* dim_index */);
}
2075 
// Shape function for QuantizedConcatV2: the first num_inputs_to_concat
// inputs are the values; the axis scalar follows them (the remaining inputs
// carry quantization ranges and do not affect the output shape).
Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat) {
  return ConcatShapeHelper(c, 0 /* start_value_index */,
                           num_inputs_to_concat /* end_value_index */,
                           num_inputs_to_concat /* dim_index */);
}
2081 
// Computes into `*out` the result of broadcasting `shape_x` against `shape_y`
// under numpy-style broadcasting rules (align trailing dims; a dimension of 1
// broadcasts to the other side's dimension).
//
// When `incompatible_shape_error` is false, a potential or actual mismatch
// produces an unknown (or, in the merge-failure path, scalar) output shape
// instead of an error, rather than failing shape inference.
Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
                                            ShapeHandle shape_x,
                                            ShapeHandle shape_y,
                                            bool incompatible_shape_error,
                                            ShapeHandle* out) {
  CHECK_NOTNULL(out);
  // Without both ranks there is nothing to zip; the result is fully unknown.
  if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
    *out = c->UnknownShape();
    return OkStatus();
  }
  const int32_t rank_x = c->Rank(shape_x);
  const int32_t rank_y = c->Rank(shape_y);
  const int32_t rank_out = std::max(rank_x, rank_y);

  // To compute the broadcast dimensions, we zip together shape_x and shape_y
  // and pad with 1 to make them the same length.
  std::vector<DimensionHandle> dims;
  DimensionHandle dim_one;
  if (rank_x != rank_y) dim_one = c->MakeDim(1);
  for (int i = 0; i < rank_out; ++i) {
    // Shapes are aligned at their trailing dimensions; a missing leading
    // dimension behaves as 1 (dim_one).
    const auto dim_x = i < (rank_out - rank_x)
                           ? dim_one
                           : c->Dim(shape_x, i - (rank_out - rank_x));
    const bool dim_y_is_one = (i < (rank_out - rank_y));
    const auto dim_y =
        dim_y_is_one ? dim_one : c->Dim(shape_y, i - (rank_out - rank_y));
    if (!c->ValueKnown(dim_x) || !c->ValueKnown(dim_y)) {
      // One or both dimensions is unknown.
      //
      // - If either dimension is greater than 1, we assume that the program is
      // correct, and the other dimension will be broadcast to match it.
      // TODO(cwhipkey): For shape inference, if we eliminate the shape checks
      // in C++ op code, we must still assert that the unknown dim is either 1
      // or the same as the known dim.
      // - If either dimension is 1, the other dimension is the output.
      // - If both are unknown then dimension is unknown
      if (c->Value(dim_x) > 1) {
        if (!incompatible_shape_error) {
          *out = c->UnknownShape();
          return OkStatus();
        }
        dims.push_back(dim_x);
      } else if (c->Value(dim_y) > 1) {
        if (!incompatible_shape_error) {
          *out = c->UnknownShape();
          return OkStatus();
        }
        dims.push_back(dim_y);
      } else if (c->Value(dim_x) == 1) {
        dims.push_back(dim_y);
      } else if (c->Value(dim_y) == 1) {
        dims.push_back(dim_x);
      } else if (dim_y.SameHandle(dim_x)) {
        // Same symbolic dimension on both sides: the output shares it even
        // though its value is unknown.
        dims.push_back(dim_x);
      } else if (!c->ValueKnown(dim_x) && !c->ValueKnown(dim_y)) {
        dims.push_back(c->UnknownDim());
      } else {
        // One side is a distinct unknown dim; compatibility can't be proven.
        if (!incompatible_shape_error) {
          *out = c->UnknownShape();
          return OkStatus();
        }
        dims.push_back(c->UnknownDim());
      }
    } else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) {
      // Both values known and at least one is 1: broadcast to the other side.
      if (c->Value(dim_x) == 1 && !dim_y_is_one) {
        // We will broadcast dim_x to dim_y.
        dims.push_back(dim_y);
      } else {
        DCHECK_EQ(c->Value(dim_y), 1);
        // We will broadcast dim_y to dim_x.
        dims.push_back(dim_x);
      }
    } else {
      // Both values known and neither is 1: they must merge (match exactly).
      DimensionHandle dim;
      Status s = c->Merge(dim_x, dim_y, &dim);
      if (!s.ok()) {
        if (!incompatible_shape_error) {
          // NOTE(review): this failure path yields a scalar shape, whereas the
          // unknown-dim paths above yield a fully unknown shape — presumably
          // intentional for ops like (Not)Equal; confirm before unifying.
          *out = c->MakeShape({});
          return OkStatus();
        }
        return s;
      }
      dims.push_back(dim);
    }
  }

  *out = c->MakeShape(dims);
  return OkStatus();
}
2172 
RandomShape(shape_inference::InferenceContext * c)2173 Status RandomShape(shape_inference::InferenceContext* c) {
2174   shape_inference::ShapeHandle out;
2175   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
2176   c->set_output(0, out);
2177   return OkStatus();
2178 }
2179 
UnsortedSegmentReductionShapeFn(InferenceContext * c)2180 Status UnsortedSegmentReductionShapeFn(InferenceContext* c) {
2181   ShapeHandle s_data = c->input(0);
2182   ShapeHandle s_segment_ids = c->input(1);
2183   ShapeHandle s_num_segments = c->input(2);
2184   TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments));
2185 
2186   ShapeHandle out;
2187 
2188   // Leading dimensions of data must be compatible with dimensions of
2189   // <s_segment_ids>.
2190   if (c->RankKnown(s_segment_ids)) {
2191     TF_RETURN_IF_ERROR(
2192         c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids));
2193 
2194     // Get the value of the num_segments input tensor.
2195     DimensionHandle num_segments_dim;
2196     TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim));
2197 
2198     // Output is {segment_id_rank} + s_data[segment_id_rank:].
2199     ShapeHandle s_data_suffix;
2200     TF_RETURN_IF_ERROR(
2201         c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix));
2202     TF_RETURN_IF_ERROR(
2203         c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out));
2204   } else {
2205     out = c->UnknownShape();
2206   }
2207   c->set_output(0, out);
2208   return OkStatus();
2209 }
2210 
2211 namespace {
2212 
2213 // This SliceHelper processes the output shape of the `slice`
2214 // when the tensor of `sizes` is available.
2215 template <typename T>
SliceHelper(InferenceContext * c,ShapeHandle begin_value,const Tensor * sizes_value,std::vector<DimensionHandle> * dims)2216 Status SliceHelper(InferenceContext* c, ShapeHandle begin_value,
2217                    const Tensor* sizes_value,
2218                    std::vector<DimensionHandle>* dims) {
2219   auto sizes_vec = sizes_value->vec<T>();
2220   for (int i = 0; i < sizes_value->NumElements(); ++i) {
2221     DimensionHandle dim = c->Dim(c->input(0), i);
2222     if (sizes_vec(i) != -1) {
2223       auto dim_val = c->Value(dim);
2224       if (sizes_vec(i) < 0) {
2225         return errors::InvalidArgument(
2226             "Out of bounds slicing on dimension ", i, " of length ", dim_val,
2227             ": sizes vector cannot be < -1, but was ", sizes_vec(i));
2228       }
2229 
2230       dims->emplace_back(c->MakeDim(sizes_vec(i)));
2231     } else {
2232       DimensionHandle result;
2233       TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result));
2234       dims->emplace_back(result);
2235     }
2236   }
2237 
2238   return OkStatus();
2239 }
2240 }  // namespace
2241 
// Shape function for Slice.
// Inputs: 0 = input tensor, 1 = begin (rank-1), 2 = sizes (rank-1).
// Uses the constant values of begin/sizes when available to produce the most
// precise output shape possible, with progressively weaker fallbacks.
Status SliceShape(InferenceContext* c) {
  ShapeHandle input = c->input(0);
  ShapeHandle begin_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
  ShapeHandle sizes_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape));

  // Merge to check compatibility of begin and sizes tensors.
  TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape));

  // begin/sizes have one entry per input dimension, so their length pins the
  // input rank when known.
  DimensionHandle ndims = c->Dim(begin_shape, 0);
  if (c->ValueKnown(ndims)) {
    TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input));
  }

  // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known
  // values, even though the `begin` value does not represent a shape.
  ShapeHandle begin_value;
  TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value));

  // We check the tensor value here and will only use
  // `MakeShapeFromShapeTensor` when `sizes_value` is null.
  // The reason is that `sizes` might contain -1, which can't
  // be represented (-1 in the ShapeHandle would mean "unknown").
  const Tensor* sizes_value = c->input_tensor(2);

  if (sizes_value != nullptr) {
    TF_RETURN_IF_ERROR(
        c->WithRank(begin_value, sizes_value->NumElements(), &begin_value));
    std::vector<DimensionHandle> dims;
    // If the begin and sizes tensors are available, then
    // we can be precise about the shape of the output.
    if (sizes_value->dtype() == DT_INT64) {
      TF_RETURN_IF_ERROR(
          SliceHelper<int64_t>(c, begin_value, sizes_value, &dims));
    } else {
      TF_RETURN_IF_ERROR(
          SliceHelper<int32>(c, begin_value, sizes_value, &dims));
    }
    c->set_output(0, c->MakeShape(dims));
    return OkStatus();
  } else {
    // In case `sizes` is not available (`sizes_value` is null),
    // we could try to use `MakeShapeFromShapeTensor` here.
    // If sizes contain -1, we will simply consider it as `Unknown`.
    // This is less than ideal but still an improvement of shape inference.
    // The following is an example that returns [None, 1, None] with this
    // code path:
    //   z = tf.zeros((1, 2, 3))
    //   m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1])
    //   m.get_shape().as_list()
    // (Note: this local ShapeHandle deliberately shadows the Tensor* above.)
    ShapeHandle sizes_value;
    TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value));
    if (c->RankKnown(sizes_value)) {
      TF_RETURN_IF_ERROR(
          c->WithRank(begin_value, c->Rank(sizes_value), &begin_value));
      // Each partially-known size element becomes the corresponding output
      // dimension (-1 entries surface as unknown dims).
      std::vector<DimensionHandle> dims;
      dims.reserve(c->Rank(sizes_value));
      for (int i = 0; i < c->Rank(sizes_value); ++i) {
        dims.emplace_back(c->Dim(sizes_value, i));
      }
      c->set_output(0, c->MakeShape(dims));
      return OkStatus();
    }
    // We might know the rank of the input.
    if (c->RankKnown(input)) {
      c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
      return OkStatus();
    } else {
      return shape_inference::UnknownShape(c);
    }
  }

  return OkStatus();
}
2317 
ValidateSparseTensor(InferenceContext * c,ShapeHandle indices_shape,ShapeHandle values_shape,ShapeHandle shape_shape)2318 Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
2319                             ShapeHandle values_shape, ShapeHandle shape_shape) {
2320   // Validate ranks.
2321   ShapeHandle unused_shape;
2322   TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &unused_shape));
2323   TF_RETURN_IF_ERROR(c->WithRank(values_shape, 1, &unused_shape));
2324   TF_RETURN_IF_ERROR(c->WithRank(shape_shape, 1, &unused_shape));
2325 
2326   // Number of elements in indices and values must match.
2327   DimensionHandle num_index_elements_dim = c->Dim(indices_shape, 0);
2328   if (c->ValueKnown(num_index_elements_dim)) {
2329     DimensionHandle num_values_elements_dim = c->Dim(values_shape, 0);
2330     if (c->ValueKnown(num_values_elements_dim)) {
2331       int64_t num_index_elements = c->Value(num_index_elements_dim);
2332       int64_t num_values_elements = c->Value(num_values_elements_dim);
2333       if (num_index_elements != num_values_elements) {
2334         return errors::InvalidArgument("Number of elements in index (",
2335                                        num_index_elements, ") and values (",
2336                                        num_values_elements, ") do not match.");
2337       }
2338     }
2339   }
2340 
2341   // Rank embedded in indices must match shape.
2342   DimensionHandle index_rank_dim = c->Dim(indices_shape, 1);
2343   if (c->ValueKnown(index_rank_dim)) {
2344     DimensionHandle shape_rank_dim = c->Dim(shape_shape, 0);
2345     if (c->ValueKnown(shape_rank_dim)) {
2346       int64_t index_rank = c->Value(index_rank_dim);
2347       int32_t shape_rank = c->Value(shape_rank_dim);
2348       if (index_rank != shape_rank) {
2349         return errors::InvalidArgument("Index rank (", index_rank,
2350                                        ") and shape rank (", shape_rank,
2351                                        ") do not match.");
2352       }
2353     }
2354   }
2355 
2356   return OkStatus();
2357 }
2358 
ValidateVariableResourceHandle(InferenceContext * c,std::vector<ShapeAndType> * shape_and_type)2359 Status ValidateVariableResourceHandle(
2360     InferenceContext* c, std::vector<ShapeAndType>* shape_and_type) {
2361   auto* handle_data = c->input_handle_shapes_and_types(0);
2362   if (handle_data == nullptr || handle_data->empty()) {
2363     shape_and_type->emplace_back(c->UnknownShape(), DT_INVALID);
2364   } else {
2365     *shape_and_type = *handle_data;
2366     DataType value_dtype;
2367     TF_RETURN_IF_ERROR(c->GetAttr("dtype", &value_dtype));
2368     if (shape_and_type->at(0).dtype != value_dtype) {
2369       return errors::InvalidArgument(
2370           "Trying to read variable with wrong dtype. "
2371           "Expected ",
2372           DataTypeString(shape_and_type->at(0).dtype), " got ",
2373           DataTypeString(value_dtype));
2374     }
2375   }
2376   return OkStatus();
2377 }
2378 
GatherNdShape(InferenceContext * c)2379 Status GatherNdShape(InferenceContext* c) {
2380   ShapeHandle params;
2381   std::vector<ShapeAndType> handle_shape_and_type;
2382   if (c->input_handle_shapes_and_types(0) != nullptr) {
2383     TF_RETURN_IF_ERROR(
2384         ValidateVariableResourceHandle(c, &handle_shape_and_type));
2385     params = handle_shape_and_type[0].shape;
2386   } else {
2387     params = c->input(0);
2388   }
2389   ShapeHandle indices;
2390   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices));
2391   DimensionHandle r_dim = c->Dim(indices, -1);
2392 
2393   if (!c->RankKnown(params) || !c->ValueKnown(r_dim)) {
2394     c->set_output(0, c->UnknownShape());
2395     return OkStatus();
2396   }
2397 
2398   if (c->Value(r_dim) > c->Rank(params)) {
2399     return errors::InvalidArgument(
2400         "indices.shape[-1] must be <= params.rank, but saw indices shape: ",
2401         c->DebugString(indices), " and params shape: ", c->DebugString(params));
2402   }
2403 
2404   // Remove r_dim from indices to get output.
2405   ShapeHandle indices_slice;
2406   ShapeHandle params_slice;
2407   TF_RETURN_IF_ERROR(c->Subshape(indices, 0, -1, &indices_slice));
2408   TF_RETURN_IF_ERROR(c->Subshape(params, c->Value(r_dim), &params_slice));
2409   ShapeHandle out;
2410   TF_RETURN_IF_ERROR(c->Concatenate(indices_slice, params_slice, &out));
2411   c->set_output(0, out);
2412   return OkStatus();
2413 }
2414 
// Validates the shapes for a scatter_nd-style op and (for tf.scatter_nd)
// sets the output shape.
//
// With ix = indices.shape[-1] and outer_dims = rank(indices) - 1, requires:
//   indices.shape[:outer_dims] == updates.shape[:outer_dims]
//   updates.shape[outer_dims:] == input.shape[ix:]
// Checks are only applied to the extent that ranks/dims are known.
Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape,
                            ShapeHandle updates_shape,
                            ShapeHandle input_shape) {
  // Scattering into an empty tensor is only valid when there is nothing to
  // scatter.
  if (c->Value(c->NumElements(input_shape)) == 0 &&
      (c->Value(c->NumElements(indices_shape)) > 0 ||
       c->Value(c->NumElements(updates_shape)) > 0)) {
    return errors::InvalidArgument(
        "Indices and updates specified for empty input");
  }

  if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape) &&
      c->Rank(updates_shape) != 0) {
    const int64_t outer_dims = c->Rank(indices_shape) - 1;
    const DimensionHandle ixdim = c->Dim(indices_shape, -1);

    // We can only do more validation if the last dimension of indices
    // is a known value.
    if (c->ValueKnown(ixdim)) {
      int64_t ix = c->Value(ixdim);
      ShapeHandle unused;
      // The batch prefixes of indices and updates must agree.
      ShapeHandle prefix_indices;
      TF_RETURN_IF_ERROR(
          c->Subshape(indices_shape, 0, outer_dims, &prefix_indices));
      ShapeHandle prefix_updates;
      TF_RETURN_IF_ERROR(
          c->Subshape(updates_shape, 0, outer_dims, &prefix_updates));

      Status s = c->Merge(prefix_indices, prefix_updates, &unused);
      if (!s.ok()) {
        return errors::InvalidArgument(
            "Dimensions [0,", outer_dims,
            ") of indices[shape=", c->DebugString(indices_shape),
            "] = ", c->DebugString(prefix_indices),
            " must match dimensions [0,", outer_dims,
            ") of updates[shape=", c->DebugString(updates_shape),
            "] = ", c->DebugString(prefix_updates), ": ", s.error_message());
      }

      // The per-element slice shapes of updates and input must agree.
      ShapeHandle suffix_output;
      TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &suffix_output));
      ShapeHandle suffix_updates;
      TF_RETURN_IF_ERROR(
          c->Subshape(updates_shape, outer_dims, &suffix_updates));
      s = c->Merge(suffix_output, suffix_updates, &unused);
      if (!s.ok()) {
        return errors::InvalidArgument(
            "Dimensions [", ix, ",", c->Rank(input_shape),
            ") of input[shape=", c->DebugString(input_shape),
            "] = ", c->DebugString(suffix_output), " must match dimensions [",
            outer_dims, ",", c->Rank(updates_shape),
            ") of updates[shape=", c->DebugString(updates_shape),
            "] = ", c->DebugString(suffix_updates), ": ", s.error_message());
      }
    }
  }

  if (c->input_handle_shapes_and_types(0) == nullptr && c->num_outputs() > 0) {
    // This is called for tf.scatter_nd; output is a tensor with this shape.
    c->set_output(0, input_shape);
  }
  return OkStatus();
}
2477 
ExplicitShape(InferenceContext * c)2478 Status ExplicitShape(InferenceContext* c) {
2479   PartialTensorShape shape;
2480   TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
2481   ShapeHandle output_shape;
2482   TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &output_shape));
2483   c->set_output(0, output_shape);
2484   return OkStatus();
2485 }
2486 
ExplicitShapes(InferenceContext * c)2487 Status ExplicitShapes(InferenceContext* c) {
2488   std::vector<PartialTensorShape> shapes;
2489   TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
2490   if (shapes.empty()) {
2491     return errors::Internal("shapes attribute is empty");
2492   }
2493   for (int i = 0, end = shapes.size(); i < end; ++i) {
2494     ShapeHandle output_shape;
2495     TF_RETURN_IF_ERROR(
2496         c->MakeShapeFromPartialTensorShape(shapes[i], &output_shape));
2497     c->set_output(i, output_shape);
2498   }
2499   return OkStatus();
2500 }
2501 
// Shape function for sparse reduction ops. When both the dense shape tensor
// and the reduction axes are constant, computes the reduced output shape
// (keeping or dropping the reduced dims per `keep_dims`); otherwise the
// output is unknown.
Status SparseReduceShapeFn(InferenceContext* c) {
  // Input 0: input_indices
  // Input 1: input_values
  // Input 2: input_shape
  // Input 3: reduction_axes
  // Attr: keep_dims
  bool keep_dims = false;
  TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));

  const Tensor* shape_tensor = c->input_tensor(2);
  const Tensor* axes_tensor = c->input_tensor(3);
  if (shape_tensor != nullptr && axes_tensor != nullptr) {
    auto shape_vec = shape_tensor->flat<int64_t>();
    auto axes_vec = axes_tensor->flat<int32>();

    int64_t ndims = shape_vec.size();
    absl::flat_hash_set<int64_t> axes;
    // Guard the modulo below against division by zero.
    if (ndims == 0)
      return errors::InvalidArgument(
          "Number of dims in shape tensor must not be 0");
    // Normalize negative axes to [0, ndims).
    // NOTE(review): axes < -ndims or >= ndims wrap silently here rather than
    // erroring — presumably validated by the kernel; confirm before relying
    // on it.
    for (int i = 0; i < axes_vec.size(); i++) {
      axes.insert((axes_vec(i) + ndims) % ndims);
    }

    std::vector<DimensionHandle> dims;
    if (keep_dims) {
      // Reduced dimensions are kept with size 1.
      dims.reserve(ndims);
      for (int d = 0; d < ndims; ++d) {
        if (axes.find(d) == axes.end()) {
          dims.push_back(c->MakeDim(shape_vec(d)));
        } else {
          dims.push_back(c->MakeDim(1));
        }
      }
    } else {
      // Reduced dimensions are dropped entirely.
      for (int d = 0; d < ndims; ++d) {
        if (axes.find(d) == axes.end()) {
          dims.push_back(c->MakeDim(shape_vec(d)));
        }
      }
    }

    c->set_output(0, c->MakeShape(dims));
    return OkStatus();
  }
  return UnknownShape(c);
}
2549 
QuantizedConv2DShape(InferenceContext * c)2550 Status QuantizedConv2DShape(InferenceContext* c) {
2551   TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2552   ShapeHandle unused;
2553   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2554   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2555   TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2556   TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2557   c->set_output(1, c->Scalar());
2558   c->set_output(2, c->Scalar());
2559   return OkStatus();
2560 }
2561 
QuantizedAvgPoolShape(InferenceContext * c)2562 Status QuantizedAvgPoolShape(InferenceContext* c) {
2563   TF_RETURN_IF_ERROR(shape_inference::AvgPoolShape(c));
2564   ShapeHandle unused;
2565   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
2566   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2567   c->set_output(1, c->Scalar());
2568   c->set_output(2, c->Scalar());
2569   return OkStatus();
2570 }
2571 
// Shape function for QuantizeV2. Output 0 matches the input shape. With
// axis == -1 (per-tensor quantization) the min/max inputs are scalars; with
// axis >= 0 (per-channel) they are vectors whose length must match the input
// dimension at `axis`. Outputs 1 and 2 mirror the min/max shape.
Status QuantizeV2Shape(InferenceContext* c) {
  int axis = -1;
  Status s = c->GetAttr("axis", &axis);
  // The "axis" attr is optional: NOT_FOUND keeps the default of -1; any
  // other failure is propagated.
  if (!s.ok() && s.code() != error::NOT_FOUND) {
    return s;
  }
  if (axis < -1) {
    return errors::InvalidArgument("axis should be at least -1, got ", axis);
  }
  const int minmax_rank = (axis == -1) ? 0 : 1;
  TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
  ShapeHandle minmax;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), minmax_rank, &minmax));
  if (axis != -1) {
    // Per-channel: the input must have at least axis + 1 dims and its size at
    // `axis` must agree with the length of the min/max vectors.
    ShapeHandle input;
    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input));
    DimensionHandle depth;
    TF_RETURN_IF_ERROR(
        c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth));
  }
  c->set_output(1, minmax);
  c->set_output(2, minmax);
  return OkStatus();
}
2597 
// Shape function for ReduceScatter. The output matches the input shape except
// that the dimension at `scatter_dimension` (input 2) is divided evenly by
// the group partition size, i.e. group_assignment.shape[1] (input 1).
Status ReduceScatterShape(shape_inference::InferenceContext* c) {
  shape_inference::ShapeHandle in = c->input(0);
  if (!c->RankKnown(in)) {
    // Input shape unknown, so set unknown output shape.
    c->set_output(0, in);
    return OkStatus();
  }

  shape_inference::ShapeHandle group_assignment_shape = c->input(1);
  if (c->Rank(group_assignment_shape) != 2)
    return errors::InvalidArgument(
        "ReduceScatter group_assignment should be rank 2");

  // The scatter dimension must be a compile-time constant to refine the
  // output shape.
  const Tensor* scatter_dimension = c->input_tensor(2);
  if (!scatter_dimension) {
    c->set_output(0, c->UnknownShape());
    return OkStatus();
  }
  int64_t scatter_dim;
  TF_RETURN_IF_ERROR(c->GetScalarFromTensor(scatter_dimension, &scatter_dim));

  // NOTE(review): scatter_dim is not range-checked against [0, rank); an
  // out-of-range value silently leaves every dimension untouched — confirm
  // whether the kernel validates it.
  std::vector<shape_inference::DimensionHandle> out_dims;
  out_dims.reserve(c->Rank(in));
  for (int i = 0; i < c->Rank(in); ++i) {
    // If the dimension is the scatter_dimension, then divide the dimension
    // by the partition size in the group_assignment.
    if (i == scatter_dim) {
      shape_inference::DimensionHandle dim = c->Dim(in, i);
      shape_inference::DimensionHandle out_dim;
      TF_RETURN_IF_ERROR(c->Divide(dim, c->Dim(group_assignment_shape, 1),
                                   /*evenly_divisible=*/true, &out_dim));
      out_dims.push_back(out_dim);
    } else {
      out_dims.emplace_back(c->Dim(in, i));
    }
  }
  c->set_output(0, c->MakeShape(out_dims));
  return OkStatus();
}
2637 
2638 }  // namespace shape_inference
2639 
2640 }  // namespace tensorflow
2641