• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/window_util.h"
17 
18 #include <vector>
19 
20 #include "absl/algorithm/container.h"
21 #include "absl/strings/str_cat.h"
22 #include "tensorflow/compiler/xla/types.h"
23 #include "tensorflow/compiler/xla/xla_data.pb.h"
24 #include "tensorflow/core/platform/logging.h"
25 
26 namespace xla {
27 namespace window_util {
28 
MakeWindow(absl::Span<const int64> sizes)29 Window MakeWindow(absl::Span<const int64> sizes) {
30   Window window;
31   for (int64 size : sizes) {
32     auto* dimension = window.add_dimensions();
33     dimension->set_size(size);
34     dimension->set_stride(1);
35     dimension->set_base_dilation(1);
36     dimension->set_window_dilation(1);
37   }
38   return window;
39 }
40 
MakeWindow(absl::Span<const int64> sizes,absl::Span<const int64> strides)41 Window MakeWindow(absl::Span<const int64> sizes,
42                   absl::Span<const int64> strides) {
43   Window window;
44   CHECK_EQ(sizes.size(), strides.size());
45   for (auto nb = 0; nb < sizes.size(); ++nb) {
46     auto* dimension = window.add_dimensions();
47     dimension->set_size(sizes[nb]);
48     dimension->set_stride(strides[nb]);
49     dimension->set_base_dilation(1);
50     dimension->set_window_dilation(1);
51   }
52   return window;
53 }
54 
MakeSymmetricPadding(absl::Span<const int64> sizes)55 PaddingConfig MakeSymmetricPadding(absl::Span<const int64> sizes) {
56   PaddingConfig config;
57   for (int64 size : sizes) {
58     auto* dimension = config.add_dimensions();
59     dimension->set_edge_padding_low(size);
60     dimension->set_edge_padding_high(size);
61   }
62   return config;
63 }
64 
ToString(const WindowDimension & dim)65 /* static */ string ToString(const WindowDimension& dim) {
66   using absl::StrAppend;
67   using absl::StrCat;
68   string str = StrCat("(size=", dim.size());
69   if (dim.stride() != 1) {
70     StrAppend(&str, ",stride=", dim.stride());
71   }
72   if (dim.padding_low() != 0) {
73     StrAppend(&str, ",padding_low=", dim.padding_low());
74   }
75   if (dim.padding_high() != 0) {
76     StrAppend(&str, ",padding_high=", dim.padding_high());
77   }
78   if (dim.base_dilation() != 1) {
79     StrAppend(&str, ",base_dilation=", dim.base_dilation());
80   }
81   if (dim.window_dilation() != 1) {
82     StrAppend(&str, ",window_dilation=", dim.window_dilation());
83   }
84   if (dim.window_reversal()) {
85     StrAppend(&str, ",window_reversal");
86   }
87   StrAppend(&str, ")");
88   return str;
89 }
90 
ToString(const Window & window)91 string ToString(const Window& window) {
92   using absl::StrAppend;
93   using absl::StrCat;
94 
95   string str;
96   const auto add_field =
97       [&](const char* heading,
98           std::function<string(const WindowDimension&)> format) {
99         StrAppend(&str, heading, "=");
100         const char* prefix = "";
101         for (const auto& window_dimension : window.dimensions()) {
102           StrAppend(&str, prefix, format(window_dimension));
103           prefix = "x";
104         }
105       };
106 
107   if (window.dimensions_size() > 0) {
108     add_field("size",
109               [](const WindowDimension& dim) { return StrCat(dim.size()); });
110   }
111   if (HasStride(window)) {
112     add_field(" stride",
113               [](const WindowDimension& dim) { return StrCat(dim.stride()); });
114   }
115   if (HasPadding(window)) {
116     add_field(" pad", [](const WindowDimension& dim) {
117       return StrCat(dim.padding_low(), "_", dim.padding_high());
118     });
119   }
120   if (HasBaseDilation(window)) {
121     add_field(" lhs_dilate", [](const WindowDimension& dim) {
122       return StrCat(dim.base_dilation());
123     });
124   }
125   if (HasWindowDilation(window)) {
126     add_field(" rhs_dilate", [](const WindowDimension& dim) {
127       return StrCat(dim.window_dilation());
128     });
129   }
130   if (HasWindowReversal(window)) {
131     add_field(" rhs_reversal", [](const WindowDimension& dim) {
132       return StrCat(dim.window_reversal() ? 1 : 0);
133     });
134   }
135   return str;
136 }
137 
HasStride(const Window & window)138 bool HasStride(const Window& window) {
139   for (const auto& dim : window.dimensions()) {
140     if (dim.stride() != 1) {
141       return true;
142     }
143   }
144   return false;
145 }
146 
HasPadding(const Window & window)147 bool HasPadding(const Window& window) {
148   for (const auto& dim : window.dimensions()) {
149     if (dim.padding_low() != 0 || dim.padding_high() != 0) {
150       return true;
151     }
152   }
153   return false;
154 }
155 
HasSymmetricPadding(const Window & window)156 bool HasSymmetricPadding(const Window& window) {
157   return absl::c_all_of(window.dimensions(), [](const WindowDimension& dim) {
158     return dim.padding_low() == dim.padding_high();
159   });
160 }
161 
HasSymmetricPadding(const PaddingConfig & padding_config)162 bool HasSymmetricPadding(const PaddingConfig& padding_config) {
163   return absl::c_all_of(padding_config.dimensions(),
164                         [](const PaddingConfig::PaddingConfigDimension& dim) {
165                           return dim.edge_padding_low() ==
166                                  dim.edge_padding_high();
167                         });
168 }
169 
HasNegativePadding(const Window & window)170 bool HasNegativePadding(const Window& window) {
171   return absl::c_any_of(window.dimensions(), [](const WindowDimension& dim) {
172     return dim.padding_low() < 0 || dim.padding_high() < 0;
173   });
174 }
175 
HasBaseDilation(const Window & window)176 bool HasBaseDilation(const Window& window) {
177   for (const auto& dim : window.dimensions()) {
178     if (dim.base_dilation() != 1) {
179       return true;
180     }
181   }
182   return false;
183 }
184 
HasWindowDilation(const Window & window)185 bool HasWindowDilation(const Window& window) {
186   for (const auto& dim : window.dimensions()) {
187     if (dim.window_dilation() != 1) {
188       return true;
189     }
190   }
191   return false;
192 }
193 
HasWindowReversal(const Window & window)194 bool HasWindowReversal(const Window& window) {
195   for (const auto& dim : window.dimensions()) {
196     if (dim.window_reversal()) {
197       return true;
198     }
199   }
200   return false;
201 }
202 
AllOrNoneReversed(const Window & window)203 bool AllOrNoneReversed(const Window& window) {
204   if (window.dimensions().empty()) {
205     return true;
206   }
207   bool reversed = window.dimensions()[0].window_reversal();
208   return absl::c_all_of(window.dimensions(), [&](const WindowDimension& dim) {
209     return dim.window_reversal() == reversed;
210   });
211 }
212 
HasDilation(const Window & window)213 bool HasDilation(const Window& window) {
214   return HasBaseDilation(window) || HasWindowDilation(window);
215 }
216 
IsTrivialWindowDimension(const WindowDimension & window_dimension)217 bool IsTrivialWindowDimension(const WindowDimension& window_dimension) {
218   return window_dimension.size() == 1 && window_dimension.stride() == 1 &&
219          window_dimension.padding_low() == 0 &&
220          window_dimension.padding_high() == 0 &&
221          window_dimension.window_dilation() == 1 &&
222          window_dimension.base_dilation() == 1;
223 }
224 
HasOverlappingWindow(const Window & window)225 bool HasOverlappingWindow(const Window& window) {
226   for (const auto& dim : window.dimensions()) {
227     if (dim.size() > dim.stride()) {
228       return true;
229     }
230   }
231   return false;
232 }
233 
DilatedBound(int64 bound,int64 dilation)234 int64 DilatedBound(int64 bound, int64 dilation) {
235   CHECK_GE(bound, 0);
236   CHECK_GE(dilation, 1);
237   if (bound == 0) {
238     return 0;
239   }
240 
241   // Suppose the array has three entries 123 and the dilation factor is 4. Then
242   // the dilated array has 9 entries 1xxx2xxx3. Here, each original entry except
243   // the last expands into 4 entries, so that is (bound - 1) * dilation. Then we
244   // add 1 to account for the final input element.
245   return (bound - 1) * dilation + 1;
246 }
247 
StridedBound(int64 bound,int64 window_size,int64 stride)248 int64 StridedBound(int64 bound, int64 window_size, int64 stride) {
249   CHECK_GE(window_size, 0);
250   CHECK_GE(bound, 0);
251   CHECK_GE(stride, 1);
252 
253   if (bound == 0 || window_size > bound) {
254     return 0;
255   }
256 
257   // Without considering stride, the maximum valid offset is bound -
258   // window_size. Taking stride into account, the valid offsets then have the
259   // form q * stride for q = 0, ..., Q such that q * stride <= bound -
260   // window_size. This implies that Q equals floor(bound - window_size /
261   // stride). There are Q + 1 valid values of q, yielding the formula below.
262   return (bound - window_size) / stride + 1;
263 }
264 
265 }  // namespace window_util
266 }  // namespace xla
267