1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/window_util.h"
17
18 #include <vector>
19
20 #include "absl/algorithm/container.h"
21 #include "absl/strings/str_cat.h"
22 #include "tensorflow/compiler/xla/types.h"
23 #include "tensorflow/compiler/xla/xla_data.pb.h"
24 #include "tensorflow/core/platform/logging.h"
25
26 namespace xla {
27 namespace window_util {
28
MakeWindow(absl::Span<const int64_t> sizes)29 Window MakeWindow(absl::Span<const int64_t> sizes) {
30 Window window;
31 for (int64_t size : sizes) {
32 auto* dimension = window.add_dimensions();
33 dimension->set_size(size);
34 dimension->set_stride(1);
35 dimension->set_base_dilation(1);
36 dimension->set_window_dilation(1);
37 }
38 return window;
39 }
40
MakeWindow(absl::Span<const int64_t> sizes,absl::Span<const int64_t> strides)41 Window MakeWindow(absl::Span<const int64_t> sizes,
42 absl::Span<const int64_t> strides) {
43 Window window;
44 CHECK_EQ(sizes.size(), strides.size());
45 for (auto nb = 0; nb < sizes.size(); ++nb) {
46 auto* dimension = window.add_dimensions();
47 dimension->set_size(sizes[nb]);
48 dimension->set_stride(strides[nb]);
49 dimension->set_base_dilation(1);
50 dimension->set_window_dilation(1);
51 }
52 return window;
53 }
54
MakeSymmetricPadding(absl::Span<const int64_t> sizes)55 PaddingConfig MakeSymmetricPadding(absl::Span<const int64_t> sizes) {
56 PaddingConfig config;
57 for (int64_t size : sizes) {
58 auto* dimension = config.add_dimensions();
59 dimension->set_edge_padding_low(size);
60 dimension->set_edge_padding_high(size);
61 }
62 return config;
63 }
64
ToString(const WindowDimension & dim)65 /* static */ std::string ToString(const WindowDimension& dim) {
66 using absl::StrAppend;
67 using absl::StrCat;
68 std::string str = StrCat("(size=", dim.size());
69 if (dim.stride() != 1) {
70 StrAppend(&str, ",stride=", dim.stride());
71 }
72 if (dim.padding_low() != 0) {
73 StrAppend(&str, ",padding_low=", dim.padding_low());
74 }
75 if (dim.padding_high() != 0) {
76 StrAppend(&str, ",padding_high=", dim.padding_high());
77 }
78 if (dim.base_dilation() != 1) {
79 StrAppend(&str, ",base_dilation=", dim.base_dilation());
80 }
81 if (dim.window_dilation() != 1) {
82 StrAppend(&str, ",window_dilation=", dim.window_dilation());
83 }
84 if (dim.window_reversal()) {
85 StrAppend(&str, ",window_reversal");
86 }
87 StrAppend(&str, ")");
88 return str;
89 }
90
ToString(const Window & window)91 std::string ToString(const Window& window) {
92 using absl::StrAppend;
93 using absl::StrCat;
94
95 std::string str;
96 const auto add_field =
97 [&](const char* heading,
98 std::function<std::string(const WindowDimension&)> format) {
99 StrAppend(&str, heading, "=");
100 const char* prefix = "";
101 for (const auto& window_dimension : window.dimensions()) {
102 StrAppend(&str, prefix, format(window_dimension));
103 prefix = "x";
104 }
105 };
106
107 if (window.dimensions_size() > 0) {
108 add_field("size",
109 [](const WindowDimension& dim) { return StrCat(dim.size()); });
110 }
111 if (HasStride(window)) {
112 add_field(" stride",
113 [](const WindowDimension& dim) { return StrCat(dim.stride()); });
114 }
115 if (HasPadding(window)) {
116 add_field(" pad", [](const WindowDimension& dim) {
117 return StrCat(dim.padding_low(), "_", dim.padding_high());
118 });
119 }
120 if (HasBaseDilation(window)) {
121 add_field(" lhs_dilate", [](const WindowDimension& dim) {
122 return StrCat(dim.base_dilation());
123 });
124 }
125 if (HasWindowDilation(window)) {
126 add_field(" rhs_dilate", [](const WindowDimension& dim) {
127 return StrCat(dim.window_dilation());
128 });
129 }
130 if (HasWindowReversal(window)) {
131 add_field(" rhs_reversal", [](const WindowDimension& dim) {
132 return StrCat(dim.window_reversal() ? 1 : 0);
133 });
134 }
135 return str;
136 }
137
HasStride(const Window & window)138 bool HasStride(const Window& window) {
139 for (const auto& dim : window.dimensions()) {
140 if (dim.stride() != 1) {
141 return true;
142 }
143 }
144 return false;
145 }
146
HasPadding(const Window & window)147 bool HasPadding(const Window& window) {
148 for (const auto& dim : window.dimensions()) {
149 if (dim.padding_low() != 0 || dim.padding_high() != 0) {
150 return true;
151 }
152 }
153 return false;
154 }
155
HasSymmetricPadding(const Window & window)156 bool HasSymmetricPadding(const Window& window) {
157 return absl::c_all_of(window.dimensions(), [](const WindowDimension& dim) {
158 return dim.padding_low() == dim.padding_high();
159 });
160 }
161
HasSymmetricPadding(const PaddingConfig & padding_config)162 bool HasSymmetricPadding(const PaddingConfig& padding_config) {
163 return absl::c_all_of(padding_config.dimensions(),
164 [](const PaddingConfig::PaddingConfigDimension& dim) {
165 return dim.edge_padding_low() ==
166 dim.edge_padding_high();
167 });
168 }
169
HasNegativePadding(const Window & window)170 bool HasNegativePadding(const Window& window) {
171 return absl::c_any_of(window.dimensions(), [](const WindowDimension& dim) {
172 return dim.padding_low() < 0 || dim.padding_high() < 0;
173 });
174 }
175
HasBaseDilation(const Window & window)176 bool HasBaseDilation(const Window& window) {
177 for (const auto& dim : window.dimensions()) {
178 if (dim.base_dilation() != 1) {
179 return true;
180 }
181 }
182 return false;
183 }
184
HasWindowDilation(const Window & window)185 bool HasWindowDilation(const Window& window) {
186 for (const auto& dim : window.dimensions()) {
187 if (dim.window_dilation() != 1) {
188 return true;
189 }
190 }
191 return false;
192 }
193
HasWindowReversal(const Window & window)194 bool HasWindowReversal(const Window& window) {
195 for (const auto& dim : window.dimensions()) {
196 if (dim.window_reversal()) {
197 return true;
198 }
199 }
200 return false;
201 }
202
AllOrNoneReversed(const Window & window)203 bool AllOrNoneReversed(const Window& window) {
204 if (window.dimensions().empty()) {
205 return true;
206 }
207 bool reversed = window.dimensions()[0].window_reversal();
208 return absl::c_all_of(window.dimensions(), [&](const WindowDimension& dim) {
209 return dim.window_reversal() == reversed;
210 });
211 }
212
HasDilation(const Window & window)213 bool HasDilation(const Window& window) {
214 return HasBaseDilation(window) || HasWindowDilation(window);
215 }
216
IsTrivialWindowDimension(const WindowDimension & window_dimension)217 bool IsTrivialWindowDimension(const WindowDimension& window_dimension) {
218 return window_dimension.size() == 1 && window_dimension.stride() == 1 &&
219 window_dimension.padding_low() == 0 &&
220 window_dimension.padding_high() == 0 &&
221 window_dimension.window_dilation() == 1 &&
222 window_dimension.base_dilation() == 1;
223 }
224
HasOverlappingWindow(const Window & window)225 bool HasOverlappingWindow(const Window& window) {
226 for (const auto& dim : window.dimensions()) {
227 if (dim.size() > dim.stride()) {
228 return true;
229 }
230 }
231 return false;
232 }
233
DilatedBound(int64_t bound,int64_t dilation)234 int64_t DilatedBound(int64_t bound, int64_t dilation) {
235 CHECK_GE(bound, 0);
236 CHECK_GE(dilation, 1);
237 if (bound == 0) {
238 return 0;
239 }
240
241 // Suppose the array has three entries 123 and the dilation factor is 4. Then
242 // the dilated array has 9 entries 1xxx2xxx3. Here, each original entry except
243 // the last expands into 4 entries, so that is (bound - 1) * dilation. Then we
244 // add 1 to account for the final input element.
245 return (bound - 1) * dilation + 1;
246 }
247
StridedBound(int64_t bound,int64_t window_size,int64_t stride)248 int64_t StridedBound(int64_t bound, int64_t window_size, int64_t stride) {
249 CHECK_GE(window_size, 0);
250 CHECK_GE(bound, 0);
251 CHECK_GE(stride, 1);
252
253 if (bound == 0 || window_size > bound) {
254 return 0;
255 }
256
257 // Without considering stride, the maximum valid offset is bound -
258 // window_size. Taking stride into account, the valid offsets then have the
259 // form q * stride for q = 0, ..., Q such that q * stride <= bound -
260 // window_size. This implies that Q equals floor(bound - window_size /
261 // stride). There are Q + 1 valid values of q, yielding the formula below.
262 return (bound - window_size) / stride + 1;
263 }
264
265 } // namespace window_util
266 } // namespace xla
267