• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/common_shape_fns.h"
16 #include "tensorflow/core/framework/attr_value.pb.h"
17 #include "tensorflow/core/lib/core/errors.h"
18 
19 namespace tensorflow {
20 
GetWindowedOutputSizeVerboseV2(int64 input_size,int64 filter_size,int64 dilation_rate,int64 stride,Padding padding_type,int64 * output_size,int64 * padding_before,int64 * padding_after)21 Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size,
22                                       int64 dilation_rate, int64 stride,
23                                       Padding padding_type, int64* output_size,
24                                       int64* padding_before,
25                                       int64* padding_after) {
26   if (stride <= 0) {
27     return errors::InvalidArgument("Stride must be > 0, but got ", stride);
28   }
29   if (dilation_rate < 1) {
30     return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
31                                    dilation_rate);
32   }
33 
34   // See also the parallel implementation in GetWindowedOutputSizeFromDimsV2.
35   int64 effective_filter_size = (filter_size - 1) * dilation_rate + 1;
36   switch (padding_type) {
37     case Padding::VALID:
38       *output_size = (input_size - effective_filter_size + stride) / stride;
39       *padding_before = *padding_after = 0;
40       break;
41     case Padding::EXPLICIT:
42       *output_size = (input_size + *padding_before + *padding_after -
43                       effective_filter_size + stride) /
44                      stride;
45       break;
46     case Padding::SAME:
47       *output_size = (input_size + stride - 1) / stride;
48       const int64 padding_needed =
49           std::max(int64{0}, (*output_size - 1) * stride +
50                                  effective_filter_size - input_size);
51       // For odd values of total padding, add more padding at the 'right'
52       // side of the given dimension.
53       *padding_before = padding_needed / 2;
54       *padding_after = padding_needed - *padding_before;
55       break;
56   }
57   if (*output_size < 0) {
58     return errors::InvalidArgument(
59         "Computed output size would be negative: ", *output_size,
60         " [input_size: ", input_size,
61         ", effective_filter_size: ", effective_filter_size,
62         ", stride: ", stride, "]");
63   }
64   return Status::OK();
65 }
66 
GetWindowedOutputSizeVerbose(int64 input_size,int64 filter_size,int64 stride,Padding padding_type,int64 * output_size,int64 * padding_before,int64 * padding_after)67 Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size,
68                                     int64 stride, Padding padding_type,
69                                     int64* output_size, int64* padding_before,
70                                     int64* padding_after) {
71   return GetWindowedOutputSizeVerboseV2(input_size, filter_size,
72                                         /*dilation_rate=*/1, stride,
73                                         padding_type, output_size,
74                                         padding_before, padding_after);
75 }
76 
GetWindowedOutputSize(int64 input_size,int64 filter_size,int64 stride,Padding padding_type,int64 * output_size,int64 * padding_size)77 Status GetWindowedOutputSize(int64 input_size, int64 filter_size, int64 stride,
78                              Padding padding_type, int64* output_size,
79                              int64* padding_size) {
80   if (padding_type == Padding::EXPLICIT) {
81     return errors::Internal(
82         "GetWindowedOutputSize does not handle EXPLICIT padding; call "
83         "GetWindowedOutputSizeVerbose instead");
84   }
85   int64 padding_after_unused;
86   return GetWindowedOutputSizeVerbose(input_size, filter_size, stride,
87                                       padding_type, output_size, padding_size,
88                                       &padding_after_unused);
89 }
90 
GetWindowedOutputSizeV2(int64 input_size,int64 filter_size,int64 dilation_rate,int64 stride,Padding padding_type,int64 * output_size,int64 * padding_size)91 Status GetWindowedOutputSizeV2(int64 input_size, int64 filter_size,
92                                int64 dilation_rate, int64 stride,
93                                Padding padding_type, int64* output_size,
94                                int64* padding_size) {
95   if (padding_type == Padding::EXPLICIT) {
96     return errors::Internal(
97         "GetWindowedOutputSizeV2 does not handle EXPLICIT padding; call "
98         "GetWindowedOutputSizeVerboseV2 instead");
99   }
100   int64 padding_after_unused;
101   return GetWindowedOutputSizeVerboseV2(input_size, filter_size, dilation_rate,
102                                         stride, padding_type, output_size,
103                                         padding_size, &padding_after_unused);
104 }
105 
Get3dOutputSize(const std::array<int64,3> & input,const std::array<int64,3> & window,const std::array<int64,3> & strides,Padding padding_type,std::array<int64,3> * output_ptr,std::array<int64,3> * padding_ptr)106 Status Get3dOutputSize(const std::array<int64, 3>& input,
107                        const std::array<int64, 3>& window,
108                        const std::array<int64, 3>& strides,
109                        Padding padding_type, std::array<int64, 3>* output_ptr,
110                        std::array<int64, 3>* padding_ptr) {
111   for (size_t i = 0; i < input.size(); ++i) {
112     TF_RETURN_IF_ERROR(GetWindowedOutputSize(input[i], window[i], strides[i],
113                                              padding_type, &(*output_ptr)[i],
114                                              &(*padding_ptr)[i]));
115   }
116   return Status::OK();
117 }
118 
Get3dOutputSizeV2(const std::array<int64,3> & input,const std::array<int64,3> & window,const std::array<int64,3> & dilations,const std::array<int64,3> & strides,Padding padding_type,std::array<int64,3> * output_ptr,std::array<int64,3> * padding_ptr)119 Status Get3dOutputSizeV2(const std::array<int64, 3>& input,
120                          const std::array<int64, 3>& window,
121                          const std::array<int64, 3>& dilations,
122                          const std::array<int64, 3>& strides,
123                          Padding padding_type, std::array<int64, 3>* output_ptr,
124                          std::array<int64, 3>* padding_ptr) {
125   for (size_t i = 0; i < input.size(); ++i) {
126     TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2(
127         input[i], window[i], dilations[i], strides[i], padding_type,
128         &(*output_ptr)[i], &(*padding_ptr)[i]));
129   }
130   return Status::OK();
131 }
132 
133 namespace shape_inference {
134 
135 // The V2 version computes windowed output size with arbitrary dilation_rate,
136 // while the original version only handles the cases where dilation_rates equal
137 // to 1.
GetWindowedOutputSizeFromDimsV2(shape_inference::InferenceContext * c,shape_inference::DimensionHandle input_size,shape_inference::DimensionOrConstant filter_size,int64 dilation_rate,int64 stride,Padding padding_type,int64 padding_before,int64 padding_after,shape_inference::DimensionHandle * output_size)138 Status GetWindowedOutputSizeFromDimsV2(
139     shape_inference::InferenceContext* c,
140     shape_inference::DimensionHandle input_size,
141     shape_inference::DimensionOrConstant filter_size, int64 dilation_rate,
142     int64 stride, Padding padding_type, int64 padding_before,
143     int64 padding_after, shape_inference::DimensionHandle* output_size) {
144   if (stride <= 0) {
145     return errors::InvalidArgument("Stride must be > 0, but got ", stride);
146   }
147 
148   if (dilation_rate < 1) {
149     return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
150                                    dilation_rate);
151   }
152 
153   // See also the parallel implementation in GetWindowedOutputSizeVerbose.
154   switch (padding_type) {
155     case Padding::VALID:
156       padding_before = padding_after = 0;
157       TF_FALLTHROUGH_INTENDED;
158     case Padding::EXPLICIT:
159       TF_RETURN_IF_ERROR(
160           c->Add(input_size, padding_before + padding_after, &input_size));
161       if (dilation_rate > 1) {
162         DimensionHandle window_size;
163         TF_RETURN_IF_ERROR(
164             c->Subtract(c->MakeDim(filter_size), 1, &window_size));
165         TF_RETURN_IF_ERROR(
166             c->Multiply(window_size, dilation_rate, &window_size));
167         TF_RETURN_IF_ERROR(c->Add(window_size, 1, &window_size));
168         TF_RETURN_IF_ERROR(c->Subtract(input_size, window_size, output_size));
169       } else {
170         TF_RETURN_IF_ERROR(c->Subtract(input_size, filter_size, output_size));
171       }
172       TF_RETURN_IF_ERROR(c->Add(*output_size, stride, output_size));
173       TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
174                                    /*evenly_divisible=*/false, output_size));
175       break;
176     case Padding::SAME:
177       TF_RETURN_IF_ERROR(c->Add(input_size, stride - 1, output_size));
178       TF_RETURN_IF_ERROR(c->Divide(*output_size, stride,
179                                    /*evenly_divisible=*/false, output_size));
180       break;
181   }
182   return Status::OK();
183 }
184 
GetWindowedOutputSizeFromDims(shape_inference::InferenceContext * c,shape_inference::DimensionHandle input_size,shape_inference::DimensionOrConstant filter_size,int64 stride,Padding padding_type,shape_inference::DimensionHandle * output_size)185 Status GetWindowedOutputSizeFromDims(
186     shape_inference::InferenceContext* c,
187     shape_inference::DimensionHandle input_size,
188     shape_inference::DimensionOrConstant filter_size, int64 stride,
189     Padding padding_type, shape_inference::DimensionHandle* output_size) {
190   if (padding_type == Padding::EXPLICIT) {
191     return errors::Internal(
192         "GetWindowedOutputSizeFromDims does not handle EXPLICIT padding; call "
193         "GetWindowedOutputSizeFromDimsV2 instead");
194   }
195   return GetWindowedOutputSizeFromDimsV2(c, input_size, filter_size,
196                                          /*dilation_rate=*/1, stride,
197                                          padding_type,
198                                          // Give dummy values of -1 to
199                                          // padding_before and padding_after,
200                                          // since explicit padding is not used.
201                                          -1, -1, output_size);
202 }
203 
UnchangedShape(shape_inference::InferenceContext * c)204 Status UnchangedShape(shape_inference::InferenceContext* c) {
205   c->set_output(0, c->input(0));
206   auto* handle_data = c->input_handle_shapes_and_types(0);
207   if (handle_data != nullptr) {
208     c->set_output_handle_shapes_and_types(0, *handle_data);
209   }
210   return Status::OK();
211 }
212 
MatMulShape(shape_inference::InferenceContext * c)213 Status MatMulShape(shape_inference::InferenceContext* c) {
214   ShapeHandle a;
215   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a));
216 
217   ShapeHandle b;
218   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b));
219 
220   bool transpose_a, transpose_b;
221   TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a));
222   TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b));
223   DimensionHandle output_rows = transpose_a ? c->Dim(a, 1) : c->Dim(a, 0);
224   DimensionHandle output_cols = transpose_b ? c->Dim(b, 0) : c->Dim(b, 1);
225 
226   // Validate that the inner shapes are compatible.
227   DimensionHandle inner_a = transpose_a ? c->Dim(a, 0) : c->Dim(a, 1);
228   DimensionHandle inner_b = transpose_b ? c->Dim(b, 1) : c->Dim(b, 0);
229   DimensionHandle merged;
230   TF_RETURN_IF_ERROR(c->Merge(inner_a, inner_b, &merged));
231 
232   c->set_output(0, c->Matrix(output_rows, output_cols));
233   return Status::OK();
234 }
235 
BiasAddShape(shape_inference::InferenceContext * c)236 Status BiasAddShape(shape_inference::InferenceContext* c) {
237   ShapeHandle input_shape;
238 
239   // Fetch the data_format attribute, which may not exist.
240   string data_format;
241   Status s = c->GetAttr("data_format", &data_format);
242 
243   if (s.ok() && data_format == "NCHW") {
244     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
245   } else {
246     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
247   }
248 
249   ShapeHandle bias_shape;
250   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &bias_shape));
251   DimensionHandle bias_dim = c->Dim(bias_shape, 0);
252 
253   // If rank unknown, return unknown shape.
254   if (!c->RankKnown(input_shape)) {
255     c->set_output(0, c->UnknownShape());
256     return Status::OK();
257   }
258 
259   // Output has the same shape as the input, and matches the length of
260   // the bias in its bias dimension.
261   ShapeHandle output_shape;
262   if (s.ok() && data_format == "NCHW") {
263     // Merge the length of bias_shape into the third to last dimension
264     ShapeHandle first;
265     TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, 1, &first));
266 
267     ShapeHandle last;
268     TF_RETURN_IF_ERROR(c->Subshape(input_shape, 2, &last));
269 
270     DimensionHandle input_bias_dim = c->Dim(input_shape, 1);
271     DimensionHandle merged_bias_dim;
272     TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));
273     ShapeHandle merged_bias = c->Vector(merged_bias_dim);
274 
275     ShapeHandle temp;
276     TF_RETURN_IF_ERROR(c->Concatenate(first, merged_bias, &temp));
277     TF_RETURN_IF_ERROR(c->Concatenate(temp, last, &output_shape));
278   } else {
279     ShapeHandle all_but_bias;
280     TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -1, &all_but_bias));
281 
282     DimensionHandle input_bias_dim = c->Dim(input_shape, -1);
283     DimensionHandle merged_bias_dim;
284     TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim));
285 
286     ShapeHandle merged_bias = c->Vector(merged_bias_dim);
287     TF_RETURN_IF_ERROR(
288         c->Concatenate(all_but_bias, merged_bias, &output_shape));
289   }
290 
291   c->set_output(0, output_shape);
292   return Status::OK();
293 }
294 
BiasAddGradShape(shape_inference::InferenceContext * c)295 Status BiasAddGradShape(shape_inference::InferenceContext* c) {
296   ShapeHandle input_shape;
297   // Fetch the data_format attribute, which may not exist.
298   string data_format;
299   Status s = c->GetAttr("data_format", &data_format);
300 
301   if (s.ok() && data_format == "NCHW") {
302     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
303     c->set_output(0, c->Vector(c->Dim(input_shape, 1)));
304   } else {
305     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
306     c->set_output(0, c->Vector(c->Dim(input_shape, -1)));
307   }
308 
309   return Status::OK();
310 }
311 
CheckFormatConstraintsOnShape(const TensorFormat tensor_format,const ShapeHandle shape_handle,const string & tensor_name,shape_inference::InferenceContext * c)312 Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format,
313                                      const ShapeHandle shape_handle,
314                                      const string& tensor_name,
315                                      shape_inference::InferenceContext* c) {
316   if (tensor_format == FORMAT_NCHW_VECT_C) {
317     // Check that the vect dim has size 4.
318     const int num_dims = c->Rank(shape_handle);
319     DimensionHandle vect_dim = c->Dim(
320         shape_handle, GetTensorInnerFeatureDimIndex(num_dims, tensor_format));
321     DimensionHandle unused_vect_dim;
322     TF_RETURN_IF_ERROR(c->WithValue(vect_dim, 4, &unused_vect_dim));
323   }
324 
325   return Status::OK();
326 }
327 
MakeShapeFromFormat(TensorFormat format,DimensionOrConstant N,const std::vector<DimensionOrConstant> & spatial,DimensionOrConstant C,ShapeHandle * out,shape_inference::InferenceContext * context)328 Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N,
329                            const std::vector<DimensionOrConstant>& spatial,
330                            DimensionOrConstant C, ShapeHandle* out,
331                            shape_inference::InferenceContext* context) {
332   const int num_dims = GetTensorDimsFromSpatialDims(spatial.size(), format);
333   std::vector<DimensionHandle> dims_actual(num_dims);
334   dims_actual[GetTensorBatchDimIndex(num_dims, format)] = context->MakeDim(N);
335   int outer_c_index = GetTensorFeatureDimIndex(num_dims, format);
336   dims_actual[outer_c_index] = context->MakeDim(C);
337   if (format == FORMAT_NCHW_VECT_C) {
338     dims_actual[GetTensorInnerFeatureDimIndex(num_dims, format)] =
339         context->MakeDim(4);
340   } else if (format == FORMAT_NHWC_VECT_W) {
341     dims_actual[GetTensorInnerWidthDimIndex(num_dims, format)] =
342         context->MakeDim(4);
343   }
344   for (int spatial_dim = 0; spatial_dim < spatial.size(); spatial_dim++) {
345     dims_actual[GetTensorSpatialDimIndex(num_dims, format, spatial_dim)] =
346         context->MakeDim(spatial[spatial_dim]);
347   }
348   *out = context->MakeShape(dims_actual);
349   return Status::OK();
350 }
351 
DimensionsFromShape(ShapeHandle shape,TensorFormat format,DimensionHandle * batch_dim,gtl::MutableArraySlice<DimensionHandle> spatial_dims,DimensionHandle * filter_dim,InferenceContext * context)352 Status DimensionsFromShape(ShapeHandle shape, TensorFormat format,
353                            DimensionHandle* batch_dim,
354                            gtl::MutableArraySlice<DimensionHandle> spatial_dims,
355                            DimensionHandle* filter_dim,
356                            InferenceContext* context) {
357   const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
358   // Batch.
359   *batch_dim = context->Dim(shape, GetTensorBatchDimIndex(rank, format));
360   // Spatial.
361   for (int spatial_dim_index = 0; spatial_dim_index < spatial_dims.size();
362        ++spatial_dim_index) {
363     spatial_dims[spatial_dim_index] = context->Dim(
364         shape, GetTensorSpatialDimIndex(rank, format, spatial_dim_index));
365   }
366   // Channel.
367   *filter_dim = context->Dim(shape, GetTensorFeatureDimIndex(rank, format));
368   if (format == FORMAT_NCHW_VECT_C) {
369     TF_RETURN_IF_ERROR(context->Multiply(
370         *filter_dim,
371         context->Dim(shape, GetTensorInnerFeatureDimIndex(rank, format)),
372         filter_dim));
373   }
374   return Status::OK();
375 }
376 
ShapeFromDimensions(DimensionHandle batch_dim,gtl::ArraySlice<DimensionHandle> spatial_dims,DimensionHandle filter_dim,TensorFormat format,InferenceContext * context,ShapeHandle * shape)377 Status ShapeFromDimensions(DimensionHandle batch_dim,
378                            gtl::ArraySlice<DimensionHandle> spatial_dims,
379                            DimensionHandle filter_dim, TensorFormat format,
380                            InferenceContext* context, ShapeHandle* shape) {
381   const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format);
382   std::vector<DimensionHandle> out_dims(rank);
383 
384   // Batch.
385   out_dims[tensorflow::GetTensorBatchDimIndex(rank, format)] = batch_dim;
386   // Spatial.
387   for (int spatial_dim_index = 0; spatial_dim_index < spatial_dims.size();
388        ++spatial_dim_index) {
389     out_dims[tensorflow::GetTensorSpatialDimIndex(
390         rank, format, spatial_dim_index)] = spatial_dims[spatial_dim_index];
391   }
392   // Channel.
393   if (format == tensorflow::FORMAT_NCHW_VECT_C) {
394     // When format is NCHW_VECT_C, factor the feature map count
395     // into the outer feature count and the inner feature count (=4).
396     TF_RETURN_IF_ERROR(context->Divide(
397         filter_dim, 4, /*evenly_divisible=*/true,
398         &out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)]));
399     out_dims[GetTensorInnerFeatureDimIndex(rank, format)] = context->MakeDim(4);
400   } else {
401     out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)] = filter_dim;
402   }
403 
404   *shape = context->MakeShape(out_dims);
405   return tensorflow::Status::OK();
406 }
407 
408 namespace {
409 
Conv2DShapeImpl(shape_inference::InferenceContext * c,bool supports_explicit_padding)410 Status Conv2DShapeImpl(shape_inference::InferenceContext* c,
411                        bool supports_explicit_padding) {
412   string data_format_str, filter_format_str;
413   if (!c->GetAttr("data_format", &data_format_str).ok()) {
414     data_format_str = "NHWC";
415   }
416   if (!c->GetAttr("filter_format", &filter_format_str).ok()) {
417     filter_format_str = "HWIO";
418   }
419 
420   TensorFormat data_format;
421   if (!FormatFromString(data_format_str, &data_format)) {
422     return errors::InvalidArgument("Invalid data format string: ",
423                                    data_format_str);
424   }
425   FilterTensorFormat filter_format;
426   if (!FilterFormatFromString(filter_format_str, &filter_format)) {
427     return errors::InvalidArgument("Invalid filter format string: ",
428                                    filter_format_str);
429   }
430 
431   constexpr int num_spatial_dims = 2;
432   const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format);
433   ShapeHandle conv_input_shape;
434   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &conv_input_shape));
435   TF_RETURN_IF_ERROR(CheckFormatConstraintsOnShape(
436       data_format, conv_input_shape, "conv_input", c));
437 
438   // The filter rank should match the input (4 for NCHW, 5 for NCHW_VECT_C).
439   ShapeHandle filter_shape;
440   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape));
441   TF_RETURN_IF_ERROR(
442       CheckFormatConstraintsOnShape(data_format, filter_shape, "filter", c));
443 
444   std::vector<int32> dilations;
445   TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations));
446 
447   if (dilations.size() != 4) {
448     return errors::InvalidArgument(
449         "Conv2D requires the dilation attribute to contain 4 values, but got: ",
450         dilations.size());
451   }
452 
453   std::vector<int32> strides;
454   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
455 
456   // strides.size() should be 4 (NCHW) even if the input is 5 (NCHW_VECT_C).
457   if (strides.size() != 4) {
458     return errors::InvalidArgument("Conv2D on data format ", data_format_str,
459                                    " requires the stride attribute to contain"
460                                    " 4 values, but got: ",
461                                    strides.size());
462   }
463 
464   const int32 stride_rows = GetTensorDim(strides, data_format, 'H');
465   const int32 stride_cols = GetTensorDim(strides, data_format, 'W');
466   const int32 dilation_rows = GetTensorDim(dilations, data_format, 'H');
467   const int32 dilation_cols = GetTensorDim(dilations, data_format, 'W');
468 
469   DimensionHandle batch_size_dim;
470   DimensionHandle input_depth_dim;
471   gtl::InlinedVector<DimensionHandle, 2> input_spatial_dims(2);
472   TF_RETURN_IF_ERROR(DimensionsFromShape(
473       conv_input_shape, data_format, &batch_size_dim,
474       absl::MakeSpan(input_spatial_dims), &input_depth_dim, c));
475 
476   DimensionHandle output_depth_dim = c->Dim(
477       filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
478   DimensionHandle filter_rows_dim = c->Dim(
479       filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'H'));
480   DimensionHandle filter_cols_dim = c->Dim(
481       filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'W'));
482   DimensionHandle filter_input_depth_dim;
483   if (filter_format == FORMAT_OIHW_VECT_I) {
484     TF_RETURN_IF_ERROR(c->Multiply(
485         c->Dim(filter_shape,
486                GetFilterDimIndex<num_spatial_dims>(filter_format, 'I')),
487         c->Dim(filter_shape,
488                GetFilterTensorInnerInputChannelsDimIndex(rank, filter_format)),
489         &filter_input_depth_dim));
490   } else {
491     filter_input_depth_dim = c->Dim(
492         filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'I'));
493   }
494 
495   // Check that the input tensor and the filter tensor agree on the input
496   // channel count.
497   DimensionHandle unused;
498   TF_RETURN_IF_ERROR(
499       c->Merge(input_depth_dim, filter_input_depth_dim, &unused));
500 
501   Padding padding;
502   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
503 
504   std::vector<int64> explicit_paddings;
505   if (supports_explicit_padding) {
506     Status s = c->GetAttr("explicit_paddings", &explicit_paddings);
507     // Use the default value, which is an empty list, if the attribute is not
508     // found. Otherwise return the error to the caller.
509     if (!s.ok() && !errors::IsNotFound(s)) {
510       return s;
511     }
512     TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings,
513                                          /*num_dims=*/4, data_format));
514   } else {
515     DCHECK(padding != Padding::EXPLICIT);
516   }
517 
518   DimensionHandle output_rows, output_cols;
519   int64 pad_rows_before = -1, pad_rows_after = -1;
520   int64 pad_cols_before = -1, pad_cols_after = -1;
521   if (padding == Padding::EXPLICIT) {
522     GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
523                              &pad_rows_before, &pad_rows_after);
524     GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
525                              &pad_cols_before, &pad_cols_after);
526   }
527   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
528       c, input_spatial_dims[0], filter_rows_dim, dilation_rows, stride_rows,
529       padding, pad_rows_before, pad_rows_after, &output_rows));
530   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
531       c, input_spatial_dims[1], filter_cols_dim, dilation_cols, stride_cols,
532       padding, pad_cols_before, pad_cols_after, &output_cols));
533 
534   ShapeHandle output_shape;
535   TF_RETURN_IF_ERROR(
536       ShapeFromDimensions(batch_size_dim, {output_rows, output_cols},
537                           output_depth_dim, data_format, c, &output_shape));
538   c->set_output(0, output_shape);
539   return Status::OK();
540 }
541 
542 }  // namespace
543 
544 // Shape function for Conv2D-like operations that support explicit padding.
Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext * c)545 Status Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext* c) {
546   return Conv2DShapeImpl(c, true);
547 }
548 
549 // Shape function for Conv2D-like operations that do not support explicit
550 // padding.
Conv2DShape(shape_inference::InferenceContext * c)551 Status Conv2DShape(shape_inference::InferenceContext* c) {
552   return Conv2DShapeImpl(c, false);
553 }
554 
555 // TODO(mjanusz): Unify all conv/pooling shape functions.
Conv3DShape(shape_inference::InferenceContext * c)556 Status Conv3DShape(shape_inference::InferenceContext* c) {
557   ShapeHandle input_shape;
558   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
559   ShapeHandle filter_shape;
560   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &filter_shape));
561 
562   string data_format;
563   Status s = c->GetAttr("data_format", &data_format);
564 
565   std::vector<int32> dilations;
566   TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations));
567 
568   if (dilations.size() != 5) {
569     return errors::InvalidArgument(
570         "Conv3D requires the dilation attribute to contain 5 values, but got: ",
571         dilations.size());
572   }
573 
574   std::vector<int32> strides;
575   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
576   if (strides.size() != 5) {
577     return errors::InvalidArgument(
578         "Conv3D requires the stride attribute to contain 5 values, but got: ",
579         strides.size());
580   }
581 
582   int32 stride_planes, stride_rows, stride_cols;
583   int32 dilation_planes, dilation_rows, dilation_cols;
584   if (s.ok() && data_format == "NCDHW") {
585     // Convert input_shape to NDHWC.
586     auto dim = [&](char dimension) {
587       return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
588     };
589     input_shape =
590         c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
591     stride_planes = strides[2];
592     stride_rows = strides[3];
593     stride_cols = strides[4];
594     dilation_planes = dilations[2];
595     dilation_cols = dilations[3];
596     dilation_rows = dilations[4];
597   } else {
598     stride_planes = strides[1];
599     stride_rows = strides[2];
600     stride_cols = strides[3];
601     dilation_planes = dilations[1];
602     dilation_cols = dilations[2];
603     dilation_rows = dilations[3];
604   }
605 
606   DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
607   DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
608   DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
609   DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
610 
611   DimensionHandle filter_planes_dim = c->Dim(filter_shape, 0);
612   DimensionHandle filter_rows_dim = c->Dim(filter_shape, 1);
613   DimensionHandle filter_cols_dim = c->Dim(filter_shape, 2);
614   DimensionHandle output_depth_dim = c->Dim(filter_shape, 4);
615 
616   DimensionHandle unused;
617   TF_RETURN_IF_ERROR(
618       c->Merge(c->Dim(input_shape, 4), c->Dim(filter_shape, 3), &unused));
619 
620   Padding padding;
621   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
622   DimensionHandle output_planes, output_rows, output_cols;
623 
624   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
625       c, in_planes_dim, filter_planes_dim, dilation_planes, stride_planes,
626       padding, -1, -1, &output_planes));
627   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
628       c, in_rows_dim, filter_rows_dim, dilation_rows, stride_rows, padding, -1,
629       -1, &output_rows));
630   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
631       c, in_cols_dim, filter_cols_dim, dilation_cols, stride_cols, padding, -1,
632       -1, &output_cols));
633 
634   ShapeHandle output_shape;
635   if (data_format == "NCDHW") {
636     output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
637                                  output_planes, output_rows, output_cols});
638   } else {
639     output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
640                                  output_cols, output_depth_dim});
641   }
642   c->set_output(0, output_shape);
643   return Status::OK();
644 }
645 
DepthwiseConv2DNativeShape(shape_inference::InferenceContext * c)646 Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
647   ShapeHandle input_shape;
648   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
649   ShapeHandle filter_shape;
650   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape));
651 
652   std::vector<int32> strides;
653   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
654 
655   if (strides.size() != 4) {
656     return errors::InvalidArgument(
657         "DepthwiseConv2D requires the stride attribute to contain 4 values, "
658         "but got: ",
659         strides.size());
660   }
661 
662   string data_format;
663   Status s = c->GetAttr("data_format", &data_format);
664   int32 stride_rows;
665   int32 stride_cols;
666   if (s.ok() && data_format == "NCHW") {
667     // Canonicalize input shape to NHWC so the shape inference code below can
668     // process it.
669     input_shape =
670         c->MakeShape({{c->Dim(input_shape, 0), c->Dim(input_shape, 2),
671                        c->Dim(input_shape, 3), c->Dim(input_shape, 1)}});
672     stride_rows = strides[2];
673     stride_cols = strides[3];
674   } else {
675     stride_rows = strides[1];
676     stride_cols = strides[2];
677   }
678 
679   DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
680   DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
681   DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
682 
683   DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
684   DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
685   DimensionHandle input_depth = c->Dim(filter_shape, 2);
686   DimensionHandle depth_multiplier = c->Dim(filter_shape, 3);
687 
688   // Check that the input depths are compatible.
689   TF_RETURN_IF_ERROR(
690       c->Merge(c->Dim(input_shape, 3), input_depth, &input_depth));
691 
692   DimensionHandle output_depth;
693   TF_RETURN_IF_ERROR(c->Multiply(input_depth, depth_multiplier, &output_depth));
694 
695   Padding padding;
696   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
697 
698   // TODO(mrry,shlens): Raise an error if the stride would cause
699   // information in the input to be ignored. This will require a change
700   // in the kernel implementation.
701   DimensionHandle output_rows, output_cols;
702 
703   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
704       c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows));
705   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
706       c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols));
707 
708   ShapeHandle output_shape;
709   if (data_format == "NCHW") {
710     output_shape =
711         c->MakeShape({batch_size_dim, output_depth, output_rows, output_cols});
712   } else {
713     output_shape =
714         c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth});
715   }
716   c->set_output(0, output_shape);
717   return Status::OK();
718 }
719 
AvgPoolShape(shape_inference::InferenceContext * c)720 Status AvgPoolShape(shape_inference::InferenceContext* c) {
721   string data_format_str;
722   TensorFormat data_format;
723   Status s = c->GetAttr("data_format", &data_format_str);
724   if (s.ok()) {
725     FormatFromString(data_format_str, &data_format);
726   } else {
727     data_format = FORMAT_NHWC;
728   }
729 
730   const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
731   ShapeHandle input_shape;
732   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
733 
734   TF_RETURN_IF_ERROR(
735       CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));
736 
737   std::vector<int32> strides;
738   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
739   if (strides.size() != 4) {
740     return errors::InvalidArgument(
741         "AvgPool requires the stride attribute to contain 4 values, but got: ",
742         strides.size());
743   }
744 
745   std::vector<int32> kernel_sizes;
746   TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
747   if (kernel_sizes.size() != 4) {
748     return errors::InvalidArgument(
749         "AvgPool requires the ksize attribute to contain 4 values, but got: ",
750         kernel_sizes.size());
751   }
752 
753   int32 stride_rows = GetTensorDim(strides, data_format, 'H');
754   int32 stride_cols = GetTensorDim(strides, data_format, 'W');
755   int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
756   int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');
757 
758   constexpr int num_spatial_dims = 2;
759   DimensionHandle batch_size_dim = c->Dim(
760       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
761   DimensionHandle in_rows_dim = c->Dim(
762       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
763   DimensionHandle in_cols_dim = c->Dim(
764       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
765   DimensionHandle depth_dim = c->Dim(
766       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
767 
768   Padding padding;
769   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
770 
771   // TODO(mrry,shlens): Raise an error if the stride would cause
772   // information in the input to be ignored. This will require a change
773   // in the kernel implementation.
774 
775   DimensionHandle output_rows, output_cols;
776   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
777       c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
778   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
779       c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
780 
781   ShapeHandle output_shape;
782   TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
783                                          {output_rows, output_cols}, depth_dim,
784                                          &output_shape, c));
785   c->set_output(0, output_shape);
786   return Status::OK();
787 }
788 
FusedBatchNormShape(shape_inference::InferenceContext * c)789 Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
790   ShapeHandle x;
791   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));
792 
793   bool is_training;
794   TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
795   int number_inputs = (is_training) ? 3 : 5;
796   string data_format_str;
797   TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
798   TensorFormat data_format;
799   if (!FormatFromString(data_format_str, &data_format)) {
800     return errors::InvalidArgument("Invalid data format string: ",
801                                    data_format_str);
802   }
803   int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
804   DimensionHandle channel_dim = c->Dim(x, channel_dim_index);
805 
806   // covers scale, offset, and if is_training is false, mean, variance
807   for (int i = 1; i < number_inputs; ++i) {
808     ShapeHandle vec;
809     TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
810     TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
811   }
812 
813   ShapeHandle y;
814   TF_RETURN_IF_ERROR(c->ReplaceDim(x, channel_dim_index, channel_dim, &y));
815   c->set_output(0, y);
816   ShapeHandle vector_shape = c->Vector(channel_dim);
817   c->set_output(1, vector_shape);
818   c->set_output(2, vector_shape);
819   c->set_output(3, vector_shape);
820   c->set_output(4, vector_shape);
821   return Status::OK();
822 }
823 
FusedBatchNormGradShape(shape_inference::InferenceContext * c)824 Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
825   ShapeHandle y_backprop;
826   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop));
827   ShapeHandle x;
828   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x));
829 
830   bool is_training;
831   TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
832   string data_format_str;
833   TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
834   TensorFormat data_format;
835   if (!FormatFromString(data_format_str, &data_format)) {
836     return errors::InvalidArgument("Invalid data format string: ",
837                                    data_format_str);
838   }
839   int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
840   DimensionHandle channel_dim = c->Dim(y_backprop, channel_dim_index);
841   TF_RETURN_IF_ERROR(
842       c->Merge(channel_dim, c->Dim(x, channel_dim_index), &channel_dim));
843 
844   // covers scale, mean (reserve_space_1), variance (reserve_space_2)
845   for (int i = 2; i < 5; ++i) {
846     ShapeHandle vec;
847     TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
848     TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
849   }
850 
851   ShapeHandle x_backprop;
852   TF_RETURN_IF_ERROR(
853       c->ReplaceDim(y_backprop, channel_dim_index, channel_dim, &x_backprop));
854   c->set_output(0, x_backprop);
855   c->set_output(1, c->Vector(channel_dim));
856   c->set_output(2, c->Vector(channel_dim));
857   // Set the correct shapes for reserve_spaces
858   // so that gradients can be performed when
859   // the op is in a symbolic condition.
860   if (is_training) {
861     c->set_output(3, c->Vector(0));
862     c->set_output(4, c->Vector(0));
863   } else {
864     c->set_output(3, c->Vector(channel_dim));
865     c->set_output(4, c->Vector(channel_dim));
866   }
867   return Status::OK();
868 }
869 
MaxPoolShape(shape_inference::InferenceContext * c)870 Status MaxPoolShape(shape_inference::InferenceContext* c) {
871   string data_format_str;
872   TensorFormat data_format;
873   Status s = c->GetAttr("data_format", &data_format_str);
874   if (s.ok()) {
875     FormatFromString(data_format_str, &data_format);
876   } else {
877     data_format = FORMAT_NHWC;
878   }
879 
880   const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
881   ShapeHandle input_shape;
882   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
883 
884   TF_RETURN_IF_ERROR(
885       CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));
886 
887   std::vector<int32> strides;
888   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
889   if (strides.size() != 4) {
890     return errors::InvalidArgument(
891         "MaxPool requires the stride attribute to contain 4 values, but got: ",
892         strides.size());
893   }
894 
895   std::vector<int32> kernel_sizes;
896   TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
897   if (kernel_sizes.size() != 4) {
898     return errors::InvalidArgument(
899         "MaxPool requires the ksize attribute to contain 4 values, but got: ",
900         kernel_sizes.size());
901   }
902 
903   int32 stride_depth = GetTensorDim(strides, data_format, 'C');
904   int32 stride_rows = GetTensorDim(strides, data_format, 'H');
905   int32 stride_cols = GetTensorDim(strides, data_format, 'W');
906   int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
907   int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
908   int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');
909 
910   constexpr int num_spatial_dims = 2;
911   DimensionHandle batch_size_dim = c->Dim(
912       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
913   DimensionHandle in_rows_dim = c->Dim(
914       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
915   DimensionHandle in_cols_dim = c->Dim(
916       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
917   DimensionHandle in_depth_dim = c->Dim(
918       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
919 
920   Padding padding;
921   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
922 
923   ShapeHandle output_shape;
924   DimensionHandle output_rows, output_cols, output_depth;
925   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
926       c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
927   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
928       c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
929   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
930       c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth));
931 
932   TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
933                                          {output_rows, output_cols},
934                                          output_depth, &output_shape, c));
935 
936   c->set_output(0, output_shape);
937   return Status::OK();
938 }
939 
MaxPoolV2Shape(shape_inference::InferenceContext * c,int num_inputs)940 Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) {
941   string data_format_str;
942   TensorFormat data_format;
943   Status s = c->GetAttr("data_format", &data_format_str);
944   if (s.ok()) {
945     FormatFromString(data_format_str, &data_format);
946   } else {
947     data_format = FORMAT_NHWC;
948   }
949 
950   const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
951   ShapeHandle input_shape;
952   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
953 
954   TF_RETURN_IF_ERROR(
955       CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));
956 
957   std::vector<int32> kernel_sizes;
958   std::vector<int32> strides;
959 
960   if (c->num_inputs() + 2 == num_inputs) {
961     TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
962 
963     TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
964   } else {
965     // Verify shape of ksize and strides input.
966     ShapeHandle size;
967     DimensionHandle unused;
968     TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 2), 1, &size));
969     TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));
970     TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &size));
971     TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused));
972 
973     const Tensor* kernel_sizes_tensor = c->input_tensor(c->num_inputs() - 2);
974     if (kernel_sizes_tensor == nullptr) {
975       c->set_output(0, c->UnknownShape());
976       return Status::OK();
977     }
978     kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements());
979     auto kernel_sizes_vec = kernel_sizes_tensor->flat<int32>();
980     std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(),
981                 kernel_sizes.begin());
982 
983     const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1);
984     if (strides_tensor == nullptr) {
985       c->set_output(0, c->UnknownShape());
986       return Status::OK();
987     }
988     strides.resize(strides_tensor->shape().num_elements());
989     auto strides_vec = strides_tensor->flat<int32>();
990     std::copy_n(&strides_vec(0), strides.size(), strides.begin());
991   }
992 
993   if (strides.size() != 4) {
994     return errors::InvalidArgument(
995         "MaxPool requires the stride attribute to contain 4 values, but "
996         "got: ",
997         strides.size());
998   }
999   if (kernel_sizes.size() != 4) {
1000     return errors::InvalidArgument(
1001         "MaxPool requires the ksize attribute to contain 4 values, but got: ",
1002         kernel_sizes.size());
1003   }
1004 
1005   int32 stride_depth = GetTensorDim(strides, data_format, 'C');
1006   int32 stride_rows = GetTensorDim(strides, data_format, 'H');
1007   int32 stride_cols = GetTensorDim(strides, data_format, 'W');
1008   int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
1009   int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
1010   int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');
1011 
1012   constexpr int num_spatial_dims = 2;
1013   DimensionHandle batch_size_dim = c->Dim(
1014       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
1015   DimensionHandle in_rows_dim = c->Dim(
1016       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
1017   DimensionHandle in_cols_dim = c->Dim(
1018       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
1019   DimensionHandle in_depth_dim = c->Dim(
1020       input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
1021 
1022   Padding padding;
1023   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
1024 
1025   ShapeHandle output_shape;
1026   DimensionHandle output_rows, output_cols, output_depth;
1027   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1028       c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
1029   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1030       c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
1031   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1032       c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth));
1033 
1034   TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
1035                                          {output_rows, output_cols},
1036                                          output_depth, &output_shape, c));
1037 
1038   c->set_output(0, output_shape);
1039   return Status::OK();
1040 }
1041 
Pool3DShape(shape_inference::InferenceContext * c)1042 Status Pool3DShape(shape_inference::InferenceContext* c) {
1043   ShapeHandle input_shape;
1044   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
1045 
1046   string data_format;
1047   Status s = c->GetAttr("data_format", &data_format);
1048 
1049   std::vector<int32> strides;
1050   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
1051   if (strides.size() != 5) {
1052     return errors::InvalidArgument(
1053         "Pool3D ops require the stride attribute to contain 5 values, but "
1054         "got: ",
1055         strides.size());
1056   }
1057 
1058   std::vector<int32> kernel_sizes;
1059   TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
1060   if (kernel_sizes.size() != 5) {
1061     return errors::InvalidArgument(
1062         "Pool3D requires the ksize attribute to contain 5 values, but got: ",
1063         kernel_sizes.size());
1064   }
1065 
1066   int32 stride_planes, stride_rows, stride_cols;
1067   int32 kernel_planes, kernel_rows, kernel_cols;
1068 
1069   if (s.ok() && data_format == "NCDHW") {
1070     // Convert input_shape to NDHWC.
1071     auto dim = [&](char dimension) {
1072       return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension));
1073     };
1074     input_shape =
1075         c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
1076     stride_planes = strides[2];
1077     stride_rows = strides[3];
1078     stride_cols = strides[4];
1079     kernel_planes = kernel_sizes[2];
1080     kernel_rows = kernel_sizes[3];
1081     kernel_cols = kernel_sizes[4];
1082   } else {
1083     stride_planes = strides[1];
1084     stride_rows = strides[2];
1085     stride_cols = strides[3];
1086     kernel_planes = kernel_sizes[1];
1087     kernel_rows = kernel_sizes[2];
1088     kernel_cols = kernel_sizes[3];
1089   }
1090 
1091   DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
1092   DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
1093   DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
1094   DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
1095   DimensionHandle output_depth_dim = c->Dim(input_shape, 4);
1096 
1097   Padding padding;
1098   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
1099 
1100   // TODO(mrry,shlens): Raise an error if the stride would cause
1101   // information in the input to be ignored. This will require a change
1102   // in the kernel implementation.
1103   DimensionHandle output_planes, output_rows, output_cols;
1104   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1105       c, in_planes_dim, kernel_planes, stride_planes, padding, &output_planes));
1106   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1107       c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows));
1108   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
1109       c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols));
1110 
1111   ShapeHandle output_shape;
1112   if (data_format == "NCDHW") {
1113     output_shape = c->MakeShape({batch_size_dim, output_depth_dim,
1114                                  output_planes, output_rows, output_cols});
1115   } else {
1116     output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows,
1117                                  output_cols, output_depth_dim});
1118   }
1119 
1120   c->set_output(0, output_shape);
1121   return Status::OK();
1122 }
1123 
UnknownShape(shape_inference::InferenceContext * c)1124 Status UnknownShape(shape_inference::InferenceContext* c) {
1125   for (int i = 0; i < c->num_outputs(); ++i) {
1126     c->set_output(i, c->UnknownShape());
1127   }
1128   return Status::OK();
1129 }
1130 
1131 template <typename T>
ReductionShapeHelper(const Tensor * reduction_indices_t,const int32 input_rank,std::set<int64> * true_indices)1132 Status ReductionShapeHelper(const Tensor* reduction_indices_t,
1133                             const int32 input_rank,
1134                             std::set<int64>* true_indices) {
1135   auto reduction_indices = reduction_indices_t->flat<T>();
1136   for (int i = 0; i < reduction_indices_t->NumElements(); ++i) {
1137     const T reduction_index = reduction_indices(i);
1138     if (reduction_index < -input_rank || reduction_index >= input_rank) {
1139       return errors::InvalidArgument("Invalid reduction dimension ",
1140                                      reduction_index, " for input with ",
1141                                      input_rank, " dimensions.");
1142     }
1143 
1144     auto wrapped_index = reduction_index;
1145     if (wrapped_index < 0) {
1146       wrapped_index += input_rank;
1147     }
1148 
1149     true_indices->insert(wrapped_index);
1150   }
1151   return Status::OK();
1152 }
1153 
ReductionShape(InferenceContext * c)1154 Status ReductionShape(InferenceContext* c) {
1155   ShapeHandle input = c->input(0);
1156 
1157   ShapeHandle indices;
1158   // Older versions of TensorFlow accidentally allowed higher rank tensors like
1159   // [[1,2]] or [[1],[2]] to represent axis=[1,2].
1160   if (c->graph_def_version() < 21) {
1161     indices = c->input(1);
1162   } else {
1163     TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices));
1164   }
1165 
1166   bool keep_dims;
1167   TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));
1168 
1169   const Tensor* reduction_indices_t = c->input_tensor(1);
1170   if (reduction_indices_t == nullptr || !c->RankKnown(input)) {
1171     // If we do not have the reduction values at runtime, or the
1172     // rank of the input, we don't know the output shape.
1173 
1174     if (keep_dims && c->RankKnown(input)) {
1175       // output rank matches input input if <keep_dims>.
1176       c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
1177       return Status::OK();
1178     } else {
1179       return shape_inference::UnknownShape(c);
1180     }
1181   }
1182 
1183   const int32 input_rank = c->Rank(input);
1184   std::set<int64> true_indices;
1185   if (reduction_indices_t->dtype() == DataType::DT_INT32) {
1186     TF_RETURN_IF_ERROR(ReductionShapeHelper<int32>(reduction_indices_t,
1187                                                    input_rank, &true_indices));
1188   } else if (reduction_indices_t->dtype() == DataType::DT_INT64) {
1189     TF_RETURN_IF_ERROR(ReductionShapeHelper<int64>(reduction_indices_t,
1190                                                    input_rank, &true_indices));
1191   } else {
1192     return errors::InvalidArgument(
1193         "reduction_indices can only be int32 or int64");
1194   }
1195 
1196   std::vector<DimensionHandle> dims;
1197   for (int i = 0; i < input_rank; ++i) {
1198     if (true_indices.count(i) > 0) {
1199       if (keep_dims) {
1200         dims.emplace_back(c->MakeDim(1));
1201       }
1202     } else {
1203       dims.emplace_back(c->Dim(input, i));
1204     }
1205   }
1206 
1207   c->set_output(0, c->MakeShape(dims));
1208   return Status::OK();
1209 }
1210 
ConcatShapeHelper(InferenceContext * c,int start_value_index,int end_value_index,int dim_index)1211 Status ConcatShapeHelper(InferenceContext* c, int start_value_index,
1212                          int end_value_index, int dim_index) {
1213   ShapeHandle unused;
1214   TF_RETURN_IF_ERROR(c->WithRank(c->input(dim_index), 0, &unused));
1215   const Tensor* concat_dim_t = c->input_tensor(dim_index);
1216   if (concat_dim_t == nullptr) {
1217     // Return an unknown shape with same rank as inputs, or an unknown rank
1218     // if no input's rank is known.
1219 
1220     // Find rank.
1221     int32 rank = InferenceContext::kUnknownRank;
1222     for (int i = start_value_index; i < end_value_index; ++i) {
1223       if (rank == InferenceContext::kUnknownRank) rank = c->Rank(c->input(i));
1224       if (rank != InferenceContext::kUnknownRank) {
1225         break;
1226       }
1227     }
1228     if (rank == InferenceContext::kUnknownRank) {
1229       c->set_output(0, c->UnknownShape());
1230       return Status::OK();
1231     } else if (rank == 0) {
1232       return errors::InvalidArgument(
1233           "Can't concatenate scalars (use tf.stack instead)");
1234     } else {
1235       for (int i = start_value_index; i < end_value_index; ++i) {
1236         // Check that all the inputs are of the correct rank.
1237         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), rank, &unused));
1238       }
1239     }
1240     // Build result of <rank> different unknown dims.
1241     std::vector<DimensionHandle> dims;
1242     dims.reserve(rank);
1243     for (int i = 0; i < rank; ++i) dims.push_back(c->UnknownDim());
1244     c->set_output(0, c->MakeShape(dims));
1245     return Status::OK();
1246   }
1247 
1248   // Merge all the non-concat dims, and sum the concat dim to make an output
1249   // shape.
1250   const int32 concat_dim = concat_dim_t->scalar<int32>()();
1251 
1252   // Minimum required number of dimensions.
1253   const int min_rank = concat_dim < 0 ? -concat_dim : concat_dim + 1;
1254 
1255   ShapeHandle output_before;
1256   ShapeHandle output_after;
1257 
1258   ShapeHandle input = c->input(end_value_index - 1);
1259   TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
1260   TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &output_before));
1261   DimensionHandle output_middle = c->Dim(input, concat_dim);
1262   if (concat_dim == -1) {
1263     output_after = c->Scalar();  // no dimensions.
1264   } else {
1265     TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &output_after));
1266   }
1267 
1268   for (int i = end_value_index - 2; i >= start_value_index; --i) {
1269     ShapeHandle before;
1270     ShapeHandle after;
1271     input = c->input(i);
1272     TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input));
1273     TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &before));
1274     DimensionHandle middle = c->Dim(input, concat_dim);
1275     if (concat_dim == -1) {
1276       after = c->Scalar();
1277     } else {
1278       TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &after));
1279     }
1280 
1281     TF_RETURN_IF_ERROR(c->Merge(before, output_before, &output_before));
1282     TF_RETURN_IF_ERROR(c->Add(output_middle, middle, &output_middle));
1283     TF_RETURN_IF_ERROR(c->Merge(after, output_after, &output_after));
1284   }
1285 
1286   ShapeHandle s;
1287   TF_RETURN_IF_ERROR(
1288       c->Concatenate(output_before, c->Vector(output_middle), &s));
1289   TF_RETURN_IF_ERROR(c->Concatenate(s, output_after, &s));
1290   c->set_output(0, s);
1291   return Status::OK();
1292 }
1293 
ConcatShape(InferenceContext * c,int num_inputs_to_concat)1294 Status ConcatShape(InferenceContext* c, int num_inputs_to_concat) {
1295   return ConcatShapeHelper(c, 1 /* start_value_index */,
1296                            1 + num_inputs_to_concat /* end_value_index */,
1297                            0 /* dim_index */);
1298 }
1299 
ConcatV2Shape(InferenceContext * c)1300 Status ConcatV2Shape(InferenceContext* c) {
1301   return ConcatShapeHelper(c, 0 /* start_value_index */,
1302                            c->num_inputs() - 1 /* end_value_index */,
1303                            c->num_inputs() - 1 /* dim_index */);
1304 }
1305 
QuantizedConcatV2Shape(InferenceContext * c,int num_inputs_to_concat)1306 Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat) {
1307   return ConcatShapeHelper(c, 0 /* start_value_index */,
1308                            num_inputs_to_concat /* end_value_index */,
1309                            num_inputs_to_concat /* dim_index */);
1310 }
1311 
BroadcastBinaryOpOutputShapeFnHelper(InferenceContext * c,ShapeHandle shape_x,ShapeHandle shape_y,ShapeHandle * out)1312 Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
1313                                             ShapeHandle shape_x,
1314                                             ShapeHandle shape_y,
1315                                             ShapeHandle* out) {
1316   CHECK_NOTNULL(out);
1317   if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
1318     *out = c->UnknownShape();
1319     return Status::OK();
1320   }
1321   const int32 rank_x = c->Rank(shape_x);
1322   const int32 rank_y = c->Rank(shape_y);
1323   const int32 rank_out = std::max(rank_x, rank_y);
1324 
1325   // To compute the broadcast dimensions, we zip together shape_x and shape_y
1326   // and
1327   // pad with 1 to make them the same length.
1328   std::vector<DimensionHandle> dims;
1329   DimensionHandle dim_one;
1330   if (rank_x != rank_y) dim_one = c->MakeDim(1);
1331   for (int i = 0; i < rank_out; ++i) {
1332     const auto dim_x = i < (rank_out - rank_x)
1333                            ? dim_one
1334                            : c->Dim(shape_x, i - (rank_out - rank_x));
1335     const bool dim_y_is_one = (i < (rank_out - rank_y));
1336     const auto dim_y =
1337         dim_y_is_one ? dim_one : c->Dim(shape_y, i - (rank_out - rank_y));
1338     if (!c->ValueKnown(dim_x) || !c->ValueKnown(dim_y)) {
1339       // One or both dimensions is unknown.
1340       //
1341       // - If either dimension is greater than 1, we assume that the program is
1342       // correct, and the other dimension will be broadcast to match it.
1343       // TODO(cwhipkey): For shape inference, if we eliminate the shape checks
1344       // in C++ op code, we must still assert that the unknown dim is either 1
1345       // or the same as the known dim.
1346       // - If either dimension is 1, the other dimension is the output.
1347       if (c->Value(dim_x) > 1) {
1348         dims.push_back(dim_x);
1349       } else if (c->Value(dim_y) > 1) {
1350         dims.push_back(dim_y);
1351       } else if (c->Value(dim_x) == 1) {
1352         dims.push_back(dim_y);
1353       } else if (c->Value(dim_y) == 1) {
1354         dims.push_back(dim_x);
1355       } else if (dim_y.SameHandle(dim_x)) {
1356         dims.push_back(dim_x);
1357       } else {
1358         dims.push_back(c->UnknownDim());
1359       }
1360     } else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) {
1361       if (c->Value(dim_x) == 1 && !dim_y_is_one) {
1362         // We will broadcast dim_x to dim_y.
1363         dims.push_back(dim_y);
1364       } else {
1365         DCHECK_EQ(c->Value(dim_y), 1);
1366         // We will broadcast dim_y to dim_x.
1367         dims.push_back(dim_x);
1368       }
1369     } else {
1370       DimensionHandle dim;
1371       TF_RETURN_IF_ERROR(c->Merge(dim_x, dim_y, &dim));
1372       dims.push_back(dim);
1373     }
1374   }
1375 
1376   *out = c->MakeShape(dims);
1377   return Status::OK();
1378 }
1379 
RandomShape(shape_inference::InferenceContext * c)1380 Status RandomShape(shape_inference::InferenceContext* c) {
1381   shape_inference::ShapeHandle out;
1382   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
1383   c->set_output(0, out);
1384   return Status::OK();
1385 }
1386 
1387 namespace {
1388 
1389 // This SliceHelper processes the output shape of the `slice`
1390 // when the tensor of `sizes` is available.
1391 template <typename T>
SliceHelper(InferenceContext * c,ShapeHandle begin_value,const Tensor * sizes_value,std::vector<DimensionHandle> * dims)1392 Status SliceHelper(InferenceContext* c, ShapeHandle begin_value,
1393                    const Tensor* sizes_value,
1394                    std::vector<DimensionHandle>* dims) {
1395   auto sizes_vec = sizes_value->vec<T>();
1396   for (int i = 0; i < sizes_value->NumElements(); ++i) {
1397     DimensionHandle dim = c->Dim(c->input(0), i);
1398     if (sizes_vec(i) != -1) {
1399       auto dim_val = c->Value(dim);
1400       if (sizes_vec(i) < 0) {
1401         return errors::InvalidArgument(
1402             "Out of bounds slicing on dimension ", i, " of length ", dim_val,
1403             ": sizes vector cannot be < -1, but was ", sizes_vec(i));
1404       }
1405 
1406       dims->emplace_back(c->MakeDim(sizes_vec(i)));
1407     } else {
1408       DimensionHandle result;
1409       TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result));
1410       dims->emplace_back(result);
1411     }
1412   }
1413 
1414   return Status::OK();
1415 }
1416 }  // namespace
1417 
SliceShape(InferenceContext * c)1418 Status SliceShape(InferenceContext* c) {
1419   ShapeHandle input = c->input(0);
1420   ShapeHandle begin_shape;
1421   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
1422   ShapeHandle sizes_shape;
1423   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape));
1424 
1425   // Merge to check compatibility of begin and sizes tensors.
1426   TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape));
1427 
1428   DimensionHandle ndims = c->Dim(begin_shape, 0);
1429   if (c->ValueKnown(ndims)) {
1430     TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input));
1431   }
1432 
1433   // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known
1434   // values, even though the `begin` value does not represent a shape.
1435   ShapeHandle begin_value;
1436   TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value));
1437 
1438   // We check the tensor value here and will only use
1439   // `MakeShapeFromShapeTensor` when `sizes_value` is null.
1440   // The reason is that `sizes` might contain -1, which can't
1441   // be represented (-1 in the ShapeHandle would mean "unknown").
1442   const Tensor* sizes_value = c->input_tensor(2);
1443 
1444   if (sizes_value != nullptr) {
1445     TF_RETURN_IF_ERROR(
1446         c->WithRank(begin_value, sizes_value->NumElements(), &begin_value));
1447     std::vector<DimensionHandle> dims;
1448     // If the begin and sizes tensors are available, then
1449     // we can be precise about the shape of the output.
1450     if (sizes_value->dtype() == DT_INT64) {
1451       TF_RETURN_IF_ERROR(
1452           SliceHelper<int64>(c, begin_value, sizes_value, &dims));
1453     } else {
1454       TF_RETURN_IF_ERROR(
1455           SliceHelper<int32>(c, begin_value, sizes_value, &dims));
1456     }
1457     c->set_output(0, c->MakeShape(dims));
1458     return Status::OK();
1459   } else {
1460     // In case `sizes` is not available (`sizes_value` is null),
1461     // we could try to use `MakeShapeFromShapeTensor` here.
1462     // If sizes contain -1, we will simply consider it as `Unknown`.
1463     // This is less than ideal but still an improvement of shape inference.
1464     // The following is an example that returns [None, 1, None] with this
1465     // code path:
1466     //   z = tf.zeros((1, 2, 3))
1467     //   m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1])
1468     //   m.get_shape().as_list()
1469     ShapeHandle sizes_value;
1470     TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value));
1471     if (c->RankKnown(sizes_value)) {
1472       TF_RETURN_IF_ERROR(
1473           c->WithRank(begin_value, c->Rank(sizes_value), &begin_value));
1474       std::vector<DimensionHandle> dims;
1475       dims.reserve(c->Rank(sizes_value));
1476       for (int i = 0; i < c->Rank(sizes_value); ++i) {
1477         dims.emplace_back(c->Dim(sizes_value, i));
1478       }
1479       c->set_output(0, c->MakeShape(dims));
1480       return Status::OK();
1481     }
1482     // We might know the rank of the input.
1483     if (c->RankKnown(input)) {
1484       c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
1485       return Status::OK();
1486     } else {
1487       return shape_inference::UnknownShape(c);
1488     }
1489   }
1490 
1491   return Status::OK();
1492 }
1493 
ValidateSparseTensor(InferenceContext * c,ShapeHandle indices_shape,ShapeHandle values_shape,ShapeHandle shape_shape)1494 Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
1495                             ShapeHandle values_shape, ShapeHandle shape_shape) {
1496   // Validate ranks.
1497   ShapeHandle unused_shape;
1498   TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &unused_shape));
1499   TF_RETURN_IF_ERROR(c->WithRank(values_shape, 1, &unused_shape));
1500   TF_RETURN_IF_ERROR(c->WithRank(shape_shape, 1, &unused_shape));
1501 
1502   // Number of elements in indices and values must match.
1503   DimensionHandle num_index_elements_dim = c->Dim(indices_shape, 0);
1504   if (c->ValueKnown(num_index_elements_dim)) {
1505     DimensionHandle num_values_elements_dim = c->Dim(values_shape, 0);
1506     if (c->ValueKnown(num_values_elements_dim)) {
1507       int64 num_index_elements = c->Value(num_index_elements_dim);
1508       int64 num_values_elements = c->Value(num_values_elements_dim);
1509       if (num_index_elements != num_values_elements) {
1510         return errors::InvalidArgument("Number of elements in index (",
1511                                        num_index_elements, ") and values (",
1512                                        num_values_elements, ") do not match.");
1513       }
1514     }
1515   }
1516 
1517   // Rank embedded in indices must match shape.
1518   DimensionHandle index_rank_dim = c->Dim(indices_shape, 1);
1519   if (c->ValueKnown(index_rank_dim)) {
1520     DimensionHandle shape_rank_dim = c->Dim(shape_shape, 0);
1521     if (c->ValueKnown(shape_rank_dim)) {
1522       int64 index_rank = c->Value(index_rank_dim);
1523       int32 shape_rank = c->Value(shape_rank_dim);
1524       if (index_rank != shape_rank) {
1525         return errors::InvalidArgument("Index rank (", index_rank,
1526                                        ") and shape rank (", shape_rank,
1527                                        ") do not match.");
1528       }
1529     }
1530   }
1531 
1532   return Status::OK();
1533 }
1534 
ScatterNdUpdateShape(InferenceContext * c)1535 Status ScatterNdUpdateShape(InferenceContext* c) {
1536   ShapeHandle input_shape = c->input(0);
1537   if (c->input_handle_shapes_and_types(0) != nullptr) {
1538     // This is called for tf.scatter_nd_update; input is a Variable handle.
1539     const auto& shape_and_type = *(c->input_handle_shapes_and_types(0));
1540     if (shape_and_type.size() == 1) {
1541       input_shape = shape_and_type[0].shape;
1542     }
1543   }
1544   ShapeHandle indices_shape;
1545   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
1546   ShapeHandle updates_shape;
1547   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape));
1548 
1549   if (c->Value(c->NumElements(input_shape)) == 0 &&
1550       (c->Value(c->NumElements(indices_shape)) > 0 ||
1551        c->Value(c->NumElements(updates_shape)) > 0)) {
1552     return errors::InvalidArgument(
1553         "Indices and updates specified for empty output shape");
1554   }
1555 
1556   if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) {
1557     const int64 num_outer_dims = c->Rank(indices_shape) - 1;
1558     const DimensionHandle index_size = c->Dim(indices_shape, -1);
1559 
1560     // We can only do more validation if the last dimension of indices
1561     // is a known value.
1562     if (c->ValueKnown(index_size)) {
1563       const int64 ix = c->Value(index_size);
1564       ShapeHandle unused;
1565       ShapeHandle prefix_indices;
1566       TF_RETURN_IF_ERROR(
1567           c->Subshape(indices_shape, 0, num_outer_dims, &prefix_indices));
1568       ShapeHandle prefix_updates;
1569       TF_RETURN_IF_ERROR(
1570           c->Subshape(updates_shape, 0, num_outer_dims, &prefix_updates));
1571 
1572       Status s = c->Merge(prefix_indices, prefix_updates, &unused);
1573       if (!s.ok()) {
1574         return errors::InvalidArgument(
1575             "The outer ", num_outer_dims,
1576             " dimensions of indices.shape=", c->DebugString(indices_shape),
1577             " must match the outer ", num_outer_dims,
1578             " dimensions of updates.shape=", c->DebugString(updates_shape),
1579             ": ", s.error_message());
1580       }
1581 
1582       ShapeHandle input_suffix;
1583       TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &input_suffix));
1584       ShapeHandle suffix_updates;
1585       TF_RETURN_IF_ERROR(
1586           c->Subshape(updates_shape, num_outer_dims, &suffix_updates));
1587       s = c->Merge(input_suffix, suffix_updates, &unused);
1588       if (!s.ok()) {
1589         return errors::InvalidArgument(
1590             "The inner ", c->Rank(input_shape) - ix,
1591             " dimensions of input.shape=", c->DebugString(input_shape),
1592             " must match the inner ", c->Rank(updates_shape) - num_outer_dims,
1593             " dimensions of updates.shape=", c->DebugString(updates_shape),
1594             ": ", s.error_message());
1595       }
1596     }
1597   }
1598 
1599   if (c->input_handle_shapes_and_types(0) == nullptr && c->num_outputs() > 0) {
1600     // This is called for tf.scatter_nd; output is a tensor with this shape.
1601     c->set_output(0, input_shape);
1602   }
1603   return Status::OK();
1604 }
1605 
ExplicitShape(InferenceContext * c)1606 Status ExplicitShape(InferenceContext* c) {
1607   PartialTensorShape shape;
1608   TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
1609   ShapeHandle output_shape;
1610   TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &output_shape));
1611   c->set_output(0, output_shape);
1612   return Status::OK();
1613 }
1614 
ExplicitShapes(InferenceContext * c)1615 Status ExplicitShapes(InferenceContext* c) {
1616   std::vector<PartialTensorShape> shapes;
1617   TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
1618   if (shapes.empty()) {
1619     return errors::Internal("shapes attribute is empty");
1620   }
1621   for (int i = 0; i < shapes.size(); ++i) {
1622     ShapeHandle output_shape;
1623     TF_RETURN_IF_ERROR(
1624         c->MakeShapeFromPartialTensorShape(shapes[i], &output_shape));
1625     c->set_output(i, output_shape);
1626   }
1627   return Status::OK();
1628 }
1629 
SparseReduceShapeFn(InferenceContext * c)1630 Status SparseReduceShapeFn(InferenceContext* c) {
1631   // Input 0: input_indices
1632   // Input 1: input_values
1633   // Input 2: input_shape
1634   // Input 3: reduction_axes
1635   // Attr: keep_dims
1636   bool keep_dims = false;
1637   TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));
1638 
1639   const Tensor* shape_tensor = c->input_tensor(2);
1640   const Tensor* axes_tensor = c->input_tensor(3);
1641   if (shape_tensor != nullptr && axes_tensor != nullptr) {
1642     auto shape_vec = shape_tensor->flat<int64>();
1643     auto axes_vec = axes_tensor->flat<int32>();
1644 
1645     int64 ndims = shape_vec.size();
1646     std::unordered_set<int64> axes;
1647     for (int i = 0; i < axes_vec.size(); i++) {
1648       axes.insert((axes_vec(i) + ndims) % ndims);
1649     }
1650 
1651     std::vector<DimensionHandle> dims;
1652     if (keep_dims) {
1653       dims.reserve(ndims);
1654       for (int d = 0; d < ndims; ++d) {
1655         if (axes.find(d) == axes.end()) {
1656           dims.push_back(c->MakeDim(shape_vec(d)));
1657         } else {
1658           dims.push_back(c->MakeDim(1));
1659         }
1660       }
1661     } else {
1662       for (int d = 0; d < ndims; ++d) {
1663         if (axes.find(d) == axes.end()) {
1664           dims.push_back(c->MakeDim(shape_vec(d)));
1665         }
1666       }
1667     }
1668 
1669     c->set_output(0, c->MakeShape(dims));
1670     return Status::OK();
1671   }
1672   return UnknownShape(c);
1673 }
1674 
1675 }  // namespace shape_inference
1676 
1677 }  // namespace tensorflow
1678