1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/window_util.h"
17
18 #include <vector>
19
20 #include "absl/algorithm/container.h"
21 #include "absl/strings/str_cat.h"
22 #include "tensorflow/compiler/xla/types.h"
23 #include "tensorflow/compiler/xla/xla_data.pb.h"
24 #include "tensorflow/core/platform/logging.h"
25
26 namespace xla {
27 namespace window_util {
28
MakeWindow(absl::Span<const int64> sizes)29 Window MakeWindow(absl::Span<const int64> sizes) {
30 Window window;
31 for (int64 size : sizes) {
32 auto* dimension = window.add_dimensions();
33 dimension->set_size(size);
34 dimension->set_stride(1);
35 dimension->set_base_dilation(1);
36 dimension->set_window_dilation(1);
37 }
38 return window;
39 }
40
MakeSymmetricPadding(absl::Span<const int64> sizes)41 PaddingConfig MakeSymmetricPadding(absl::Span<const int64> sizes) {
42 PaddingConfig config;
43 for (int64 size : sizes) {
44 auto* dimension = config.add_dimensions();
45 dimension->set_edge_padding_low(size);
46 dimension->set_edge_padding_high(size);
47 }
48 return config;
49 }
50
ToString(const WindowDimension & dim)51 /* static */ string ToString(const WindowDimension& dim) {
52 using absl::StrAppend;
53 using absl::StrCat;
54 string str = StrCat("(size=", dim.size());
55 if (dim.stride() != 1) {
56 StrAppend(&str, ",stride=", dim.stride());
57 }
58 if (dim.padding_low() != 0) {
59 StrAppend(&str, ",padding_low=", dim.padding_low());
60 }
61 if (dim.padding_high() != 0) {
62 StrAppend(&str, ",padding_high=", dim.padding_high());
63 }
64 if (dim.base_dilation() != 1) {
65 StrAppend(&str, ",base_dilation=", dim.base_dilation());
66 }
67 if (dim.window_dilation() != 1) {
68 StrAppend(&str, ",window_dilation=", dim.window_dilation());
69 }
70 if (dim.window_reversal()) {
71 StrAppend(&str, ",window_reversal");
72 }
73 StrAppend(&str, ")");
74 return str;
75 }
76
ToString(const Window & window)77 string ToString(const Window& window) {
78 using absl::StrAppend;
79 using absl::StrCat;
80
81 string str;
82 const auto add_field =
83 [&](const char* heading,
84 std::function<string(const WindowDimension&)> format) {
85 StrAppend(&str, heading, "=");
86 const char* prefix = "";
87 for (const auto& window_dimension : window.dimensions()) {
88 StrAppend(&str, prefix, format(window_dimension));
89 prefix = "x";
90 }
91 };
92
93 add_field("size",
94 [](const WindowDimension& dim) { return StrCat(dim.size()); });
95 if (HasStride(window)) {
96 add_field(" stride",
97 [](const WindowDimension& dim) { return StrCat(dim.stride()); });
98 }
99 if (HasPadding(window)) {
100 add_field(" pad", [](const WindowDimension& dim) {
101 return StrCat(dim.padding_low(), "_", dim.padding_high());
102 });
103 }
104 if (HasBaseDilation(window)) {
105 add_field(" lhs_dilate", [](const WindowDimension& dim) {
106 return StrCat(dim.base_dilation());
107 });
108 }
109 if (HasWindowDilation(window)) {
110 add_field(" rhs_dilate", [](const WindowDimension& dim) {
111 return StrCat(dim.window_dilation());
112 });
113 }
114 if (HasWindowReversal(window)) {
115 add_field(" rhs_reversal", [](const WindowDimension& dim) {
116 return StrCat(dim.window_reversal() ? 1 : 0);
117 });
118 }
119 return str;
120 }
121
HasStride(const Window & window)122 bool HasStride(const Window& window) {
123 for (const auto& dim : window.dimensions()) {
124 if (dim.stride() != 1) {
125 return true;
126 }
127 }
128 return false;
129 }
130
HasPadding(const Window & window)131 bool HasPadding(const Window& window) {
132 for (const auto& dim : window.dimensions()) {
133 if (dim.padding_low() != 0 || dim.padding_high() != 0) {
134 return true;
135 }
136 }
137 return false;
138 }
139
HasSymmetricPadding(const Window & window)140 bool HasSymmetricPadding(const Window& window) {
141 return absl::c_all_of(window.dimensions(), [](const WindowDimension& dim) {
142 return dim.padding_low() == dim.padding_high();
143 });
144 }
145
HasSymmetricPadding(const PaddingConfig & padding_config)146 bool HasSymmetricPadding(const PaddingConfig& padding_config) {
147 return absl::c_all_of(padding_config.dimensions(),
148 [](const PaddingConfig::PaddingConfigDimension& dim) {
149 return dim.edge_padding_low() ==
150 dim.edge_padding_high();
151 });
152 }
153
HasNegativePadding(const Window & window)154 bool HasNegativePadding(const Window& window) {
155 return absl::c_any_of(window.dimensions(), [](const WindowDimension& dim) {
156 return dim.padding_low() < 0 || dim.padding_high() < 0;
157 });
158 }
159
HasBaseDilation(const Window & window)160 bool HasBaseDilation(const Window& window) {
161 for (const auto& dim : window.dimensions()) {
162 if (dim.base_dilation() != 1) {
163 return true;
164 }
165 }
166 return false;
167 }
168
HasWindowDilation(const Window & window)169 bool HasWindowDilation(const Window& window) {
170 for (const auto& dim : window.dimensions()) {
171 if (dim.window_dilation() != 1) {
172 return true;
173 }
174 }
175 return false;
176 }
177
HasWindowReversal(const Window & window)178 bool HasWindowReversal(const Window& window) {
179 for (const auto& dim : window.dimensions()) {
180 if (dim.window_reversal()) {
181 return true;
182 }
183 }
184 return false;
185 }
186
AllOrNoneReversed(const Window & window)187 bool AllOrNoneReversed(const Window& window) {
188 if (window.dimensions().empty()) {
189 return true;
190 }
191 bool reversed = window.dimensions()[0].window_reversal();
192 return absl::c_all_of(window.dimensions(), [&](const WindowDimension& dim) {
193 return dim.window_reversal() == reversed;
194 });
195 }
196
HasDilation(const Window & window)197 bool HasDilation(const Window& window) {
198 return HasBaseDilation(window) || HasWindowDilation(window);
199 }
200
IsInactiveWindowDimension(const Window & window,int64 logical_dim)201 bool IsInactiveWindowDimension(const Window& window, int64 logical_dim) {
202 const WindowDimension& window_dim = window.dimensions(logical_dim);
203 return window_dim.size() == 1 && window_dim.stride() == 1 &&
204 window_dim.padding_low() == 0 && window_dim.padding_high() == 0;
205 }
206
IsTrivialWindowDimension(const WindowDimension & window_dimension)207 bool IsTrivialWindowDimension(const WindowDimension& window_dimension) {
208 return window_dimension.size() == 1 && window_dimension.stride() == 1 &&
209 window_dimension.padding_low() == 0 &&
210 window_dimension.padding_high() == 0 &&
211 window_dimension.window_dilation() == 1 &&
212 window_dimension.base_dilation() == 1;
213 }
214
DilatedBound(int64 bound,int64 dilation)215 int64 DilatedBound(int64 bound, int64 dilation) {
216 CHECK_GE(bound, 0);
217 CHECK_GE(dilation, 1);
218 if (bound == 0) {
219 return 0;
220 }
221
222 // Suppose the array has three entries 123 and the dilation factor is 4. Then
223 // the dilated array has 9 entries 1xxx2xxx3. Here, each original entry except
224 // the last expands into 4 entries, so that is (bound - 1) * dilation. Then we
225 // add 1 to account for the final input element.
226 return (bound - 1) * dilation + 1;
227 }
228
StridedBound(int64 bound,int64 window_size,int64 stride)229 int64 StridedBound(int64 bound, int64 window_size, int64 stride) {
230 CHECK_GE(window_size, 0);
231 CHECK_GE(bound, 0);
232 CHECK_GE(stride, 1);
233
234 if (bound == 0 || window_size > bound) {
235 return 0;
236 }
237
238 // Without considering stride, the maximum valid offset is bound -
239 // window_size. Taking stride into account, the valid offsets then have the
240 // form q * stride for q = 0, ..., Q such that q * stride <= bound -
241 // window_size. This implies that Q equals floor(bound - window_size /
242 // stride). There are Q + 1 valid values of q, yielding the formula below.
243 return (bound - window_size) / stride + 1;
244 }
245
246 } // namespace window_util
247 } // namespace xla
248