• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/core/framework/kernel_shape_util.h"

#include <algorithm>

#include "tensorflow/core/lib/core/errors.h"
18 
19 namespace tensorflow {
GetWindowedOutputSizeVerboseV2(int64 input_size,int64 filter_size,int64 dilation_rate,int64 stride,Padding padding_type,int64 * output_size,int64 * padding_before,int64 * padding_after)20 Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size,
21                                       int64 dilation_rate, int64 stride,
22                                       Padding padding_type, int64* output_size,
23                                       int64* padding_before,
24                                       int64* padding_after) {
25   if (stride <= 0) {
26     return errors::InvalidArgument("Stride must be > 0, but got ", stride);
27   }
28   if (dilation_rate < 1) {
29     return errors::InvalidArgument("Dilation rate must be >= 1, but got ",
30                                    dilation_rate);
31   }
32 
33   // See also the parallel implementation in GetWindowedOutputSizeFromDimsV2.
34   int64 effective_filter_size = (filter_size - 1) * dilation_rate + 1;
35   switch (padding_type) {
36     case Padding::VALID:
37       *output_size = (input_size - effective_filter_size + stride) / stride;
38       *padding_before = *padding_after = 0;
39       break;
40     case Padding::EXPLICIT:
41       *output_size = (input_size + *padding_before + *padding_after -
42                       effective_filter_size + stride) /
43                      stride;
44       break;
45     case Padding::SAME:
46       *output_size = (input_size + stride - 1) / stride;
47       const int64 padding_needed =
48           std::max(int64{0}, (*output_size - 1) * stride +
49                                  effective_filter_size - input_size);
50       // For odd values of total padding, add more padding at the 'right'
51       // side of the given dimension.
52       *padding_before = padding_needed / 2;
53       *padding_after = padding_needed - *padding_before;
54       break;
55   }
56   if (*output_size < 0) {
57     return errors::InvalidArgument(
58         "Computed output size would be negative: ", *output_size,
59         " [input_size: ", input_size,
60         ", effective_filter_size: ", effective_filter_size,
61         ", stride: ", stride, "]");
62   }
63   return Status::OK();
64 }
65 
GetWindowedOutputSizeVerbose(int64 input_size,int64 filter_size,int64 stride,Padding padding_type,int64 * output_size,int64 * padding_before,int64 * padding_after)66 Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size,
67                                     int64 stride, Padding padding_type,
68                                     int64* output_size, int64* padding_before,
69                                     int64* padding_after) {
70   return GetWindowedOutputSizeVerboseV2(input_size, filter_size,
71                                         /*dilation_rate=*/1, stride,
72                                         padding_type, output_size,
73                                         padding_before, padding_after);
74 }
75 
GetWindowedOutputSize(int64 input_size,int64 filter_size,int64 stride,Padding padding_type,int64 * output_size,int64 * padding_size)76 Status GetWindowedOutputSize(int64 input_size, int64 filter_size, int64 stride,
77                              Padding padding_type, int64* output_size,
78                              int64* padding_size) {
79   if (padding_type == Padding::EXPLICIT) {
80     return errors::Internal(
81         "GetWindowedOutputSize does not handle EXPLICIT padding; call "
82         "GetWindowedOutputSizeVerbose instead");
83   }
84   int64 padding_after_unused;
85   return GetWindowedOutputSizeVerbose(input_size, filter_size, stride,
86                                       padding_type, output_size, padding_size,
87                                       &padding_after_unused);
88 }
89 
GetWindowedOutputSizeV2(int64 input_size,int64 filter_size,int64 dilation_rate,int64 stride,Padding padding_type,int64 * output_size,int64 * padding_size)90 Status GetWindowedOutputSizeV2(int64 input_size, int64 filter_size,
91                                int64 dilation_rate, int64 stride,
92                                Padding padding_type, int64* output_size,
93                                int64* padding_size) {
94   if (padding_type == Padding::EXPLICIT) {
95     return errors::Internal(
96         "GetWindowedOutputSizeV2 does not handle EXPLICIT padding; call "
97         "GetWindowedOutputSizeVerboseV2 instead");
98   }
99   int64 padding_after_unused;
100   return GetWindowedOutputSizeVerboseV2(input_size, filter_size, dilation_rate,
101                                         stride, padding_type, output_size,
102                                         padding_size, &padding_after_unused);
103 }
104 
Get3dOutputSize(const std::array<int64,3> & input,const std::array<int64,3> & window,const std::array<int64,3> & strides,Padding padding_type,std::array<int64,3> * output_ptr,std::array<int64,3> * padding_ptr)105 Status Get3dOutputSize(const std::array<int64, 3>& input,
106                        const std::array<int64, 3>& window,
107                        const std::array<int64, 3>& strides,
108                        Padding padding_type, std::array<int64, 3>* output_ptr,
109                        std::array<int64, 3>* padding_ptr) {
110   for (size_t i = 0; i < input.size(); ++i) {
111     TF_RETURN_IF_ERROR(GetWindowedOutputSize(input[i], window[i], strides[i],
112                                              padding_type, &(*output_ptr)[i],
113                                              &(*padding_ptr)[i]));
114   }
115   return Status::OK();
116 }
117 
Get3dOutputSizeV2(const std::array<int64,3> & input,const std::array<int64,3> & window,const std::array<int64,3> & dilations,const std::array<int64,3> & strides,Padding padding_type,std::array<int64,3> * output_ptr,std::array<int64,3> * padding_ptr)118 Status Get3dOutputSizeV2(const std::array<int64, 3>& input,
119                          const std::array<int64, 3>& window,
120                          const std::array<int64, 3>& dilations,
121                          const std::array<int64, 3>& strides,
122                          Padding padding_type, std::array<int64, 3>* output_ptr,
123                          std::array<int64, 3>* padding_ptr) {
124   for (size_t i = 0; i < input.size(); ++i) {
125     TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2(
126         input[i], window[i], dilations[i], strides[i], padding_type,
127         &(*output_ptr)[i], &(*padding_ptr)[i]));
128   }
129   return Status::OK();
130 }
131 }  // namespace tensorflow
132