/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cudnn_simplify_padding.h"

#include <algorithm>
#include <cstdio>
#include <iterator>
#include <optional>
#include <sstream>
#include <vector>

#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla::gpu {

namespace {
namespace m = ::xla::match;
// If exactly one index of `vals` is false, returns that index. If 0 or more
// than one index is false, returns nullopt.
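//
// For example, FindFalseIndex({true, false, true}) returns 1, while
// FindFalseIndex({false, false, true}) and FindFalseIndex({true, true}) both
// return nullopt.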
std::optional<int64_t> FindFalseIndex(absl::Span<const bool> vals) {
  std::optional<int64_t> missing_dim;
  for (int i = 0; i < vals.size(); i++) {
    if (vals[i]) {
      continue;
    }
    if (missing_dim.has_value()) {
      VLOG(2) << "Multiple dimensions are missing from conv dnums; can't "
                 "determine which is vect_c dimension";
      return std::nullopt;
    }
    missing_dim = i;
  }
  return missing_dim;
}

// Finds the vect_c dimension in the convolution's output.
//
// The vect_c dimension is the dimension of the output that's not mentioned in
// `dnums`. If there's zero or more than one such dimension, returns nullopt.
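//
// For example, if the conv's output has shape [N, C/k, H, W, k] (a vectorized
// int8 layout, hypothetical sizes) and `dnums` mentions dims {0, 1, 2, 3} as
// the batch, feature, and spatial dims, the leftover dim 4 is the vect_c dim.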
std::optional<int64_t> FindOutputVectCDim(HloInstruction* conv) {
  const ConvolutionDimensionNumbers& dnums =
      conv->convolution_dimension_numbers();
  int64_t num_dims = conv->shape().tuple_shapes(0).dimensions_size();
  absl::InlinedVector<bool, 5> seen_dims(num_dims);
  seen_dims[dnums.output_batch_dimension()] = true;
  seen_dims[dnums.output_feature_dimension()] = true;
  for (int64_t d : dnums.output_spatial_dimensions()) {
    seen_dims[d] = true;
  }
  return FindFalseIndex(seen_dims);
}

// Finds the vect_c dimension in the convolution's kernel.
std::optional<int64_t> FindKernelVectCDim(HloInstruction* conv) {
  const ConvolutionDimensionNumbers& dnums =
      conv->convolution_dimension_numbers();
  int64_t num_dims = conv->operand(1)->shape().dimensions_size();
  absl::InlinedVector<bool, 5> seen_dims(num_dims);
  seen_dims[dnums.kernel_input_feature_dimension()] = true;
  seen_dims[dnums.kernel_output_feature_dimension()] = true;
  for (int64_t d : dnums.kernel_spatial_dimensions()) {
    seen_dims[d] = true;
  }
  return FindFalseIndex(seen_dims);
}

// Attempts to count the number of output features at the end of conv that are
// guaranteed to be 0.
//
// This is the same as counting the number of values o at the end of the kernel
// for which kernel[i,o,h,w] is 0 for all values i,h,w.
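//
// For example, if the conv's weights came from pad(w, 0) with
// edge_padding_high=4 on the kernel output-feature dim, then the last 4 output
// features of the conv are known to be all zeros.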
std::optional<int64_t> NumTrailingZeroOutputFeatures(HloInstruction* conv) {
  const ConvolutionDimensionNumbers& dnums =
      conv->convolution_dimension_numbers();
  int64_t feature_dim = dnums.kernel_output_feature_dimension();
  const HloInstruction* weights = conv->operand(1);
  VLOG(2) << "Computing NumTrailingZeroOutputFeatures of " << conv->ToString()
          << "\nwith weights " << weights->ToString();
  if (Match(weights, m::Pad(m::Op(), m::ConstantEffectiveScalar(0)))) {
    const PaddingConfig::PaddingConfigDimension& padding_config =
        weights->padding_config().dimensions(feature_dim);
    // The last N output feature weights are all 0.
    VLOG(2) << "Success: Weights is a pad; padding on output feature dim is "
            << padding_config.edge_padding_high();
    return padding_config.edge_padding_high();
  } else if (const HloInstruction* pad; Match(
                 weights, m::Reshape(m::Pad(&pad, m::Op(),
                                            m::ConstantEffectiveScalar(0))))) {
    // Check that the reshape merely adds a VECT_C to the kernel input features.
    // That is, we reshape from [I,O,H,W] (in some order) to [I/k,k,O,H,W] (in
    // the same order) for some constant k (probably 32). Then check how much
    // the pad adds to the O dimension.
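    //
    // Illustration (hypothetical sizes): if the pad produces f32[64,70,H,W]
    // and the reshape produces f32[2,32,70,H,W], then expected_pad_dim_sizes
    // below is computed as {2*32, 70, H, W} = {64, 70, H, W}, which must match
    // the pad's shape.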
    std::optional<int64_t> vect_c_dim = FindKernelVectCDim(conv);
    if (!vect_c_dim.has_value()) {
      VLOG(2) << "fail: Can't find vect_c dimension in conv.";
      return std::nullopt;
    }
    if (*vect_c_dim != dnums.kernel_input_feature_dimension() + 1) {
      VLOG(2) << "fail: vect_c dim is in the wrong place; should be right "
                 "after kernel input feature dims in conv.";
      return std::nullopt;
    }
    absl::InlinedVector<int64_t, 5> expected_pad_dim_sizes(
        weights->shape().dimensions().begin(),
        weights->shape().dimensions().end());
    expected_pad_dim_sizes[dnums.kernel_input_feature_dimension()] *=
        weights->shape().dimensions(*vect_c_dim);
    expected_pad_dim_sizes.erase(expected_pad_dim_sizes.begin() + *vect_c_dim);
    if (pad->shape().dimensions() != expected_pad_dim_sizes) {
      VLOG(2) << "fail: Reshape doesn't simply merge vect_c dimension into "
                 "input features dim "
              << weights->ToString() << " but expected dims "
              << absl::StrJoin(expected_pad_dim_sizes, ",");
      return std::nullopt;
    }

    // If the filter dnums are e.g. [I,O,H,W], then after the reshape they are
    // [I/k,k,O,H,W] and the new index of O is greater than before the reshape
    // (which we know only adds the I/k and k dims, which we also know are
    // contiguous). OTOH if the O comes before the I in the original, then the
    // index of O doesn't change after the reshape.
    int64_t feature_dim_before_reshape = feature_dim;
    if (dnums.kernel_output_feature_dimension() >
        dnums.kernel_input_feature_dimension()) {
      feature_dim_before_reshape--;
    }
    const PaddingConfig::PaddingConfigDimension& padding_config =
        pad->padding_config().dimensions(feature_dim_before_reshape);

    // The last N output feature weights are all 0.
    VLOG(2) << "Success: Weights is a reshape of a pad; padding on output "
               "feature dim is "
            << padding_config.edge_padding_high();
    return padding_config.edge_padding_high();
  } else if (Match(weights, m::Constant())) {
    // Iterate backwards over `weights` to find the index of the first nonzero
    // value.
    //
    // TODO(jlebar): This may be slow, because it iterates over potentially the
    // whole constant and does a multi_index -> linear_index conversion for each
    // element. If necessary we could rewrite this by using linear indices, but
    // we'd have to be careful of the fact that literals can have arbitrary
    // layouts, so you can't just iterate over the literal's bytes.
    const Literal& lit = weights->literal();
    const auto& dims = weights->shape().dimensions();
    absl::InlinedVector<int64_t, 5> multi_index;
    for (int64_t dim : dims) {
      multi_index.push_back(dim - 1);
    }
    while (true) {
      if (!lit.IsZero(multi_index)) {
        break;
      }
      multi_index[multi_index.size() - 1]--;
      for (int i = multi_index.size() - 1; i > 0; i--) {
        if (multi_index[i] == -1) {
          multi_index[i] = dims[i] - 1;
          multi_index[i - 1]--;
        } else {
          break;
        }
      }
      if (multi_index[0] == -1) {
        break;
      }
    }

    VLOG(2) << "First nonzero index in weights constant is "
            << absl::StrJoin(multi_index, ",");
    int64_t first_nonzero_feature = multi_index[feature_dim];
    // "Round up" the first nonzero feature index if the other indices at the
    // stopping point are not all zero.
    for (int i = 0; i < multi_index.size(); i++) {
      if (i != feature_dim && multi_index[i] != 0) {
        first_nonzero_feature++;
        break;
      }
    }
    int64_t ret = std::max<int64_t>(
        0, weights->shape().dimensions(feature_dim) - first_nonzero_feature);
    VLOG(2) << "Success: weights is a constant; num zero trailing output "
               "features is "
            << ret;
    return ret;
  }
  return std::nullopt;
}

StatusOr<bool> TrySimplifyPadding(HloInstruction* instr) {
  // Match one of the following patterns.
  //   conv -> slice -> pad
  //   conv -> reshape -> slice -> pad
  //   conv -> transpose -> reshape -> slice -> pad
  //
  // where `pad` (the root of the pattern) is `instr`.
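  //
  // Sketch of the rewrite for the simplest pattern (hypothetical HLO, shapes
  // elided):
  //   conv  = custom-call(...), target=kCudnnConvForwardCallTarget
  //   gte   = get-tuple-element(conv), index=0
  //   slice = slice(gte)     // removes trailing output features known to be 0
  //   pad   = pad(slice, 0)  // pads zeros back onto the same dim
  // is rewritten so that `pad` consumes `gte` directly, with its
  // edge_padding_high reduced by the number of sliced-off features.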
  HloInstruction* conv;
  HloInstruction* transpose = nullptr;  // optional
  HloInstruction* reshape = nullptr;    // optional
  HloInstruction* slice;
  HloInstruction* pad;
  auto conv_matcher = m::GetTupleElement(
      m::CustomCall(&conv).WithPredicate([](const HloInstruction* instr) {
        return instr->custom_call_target() == kCudnnConvForwardCallTarget ||
               instr->custom_call_target() ==
                   kCudnnConvBiasActivationForwardCallTarget;
      }),
      0);
  auto pad_matcher = m::Pad(m::Op(), m::ConstantEffectiveScalar(0));
  if (!MatchAndLogIfFailed(instr, "conv-slice-pad",
                           m::Pad(&pad, m::Slice(&slice, conv_matcher),
                                  m::ConstantEffectiveScalar(0)),
                           VLOG_IS_ON(3), pad_matcher) &&
      !MatchAndLogIfFailed(
          instr, "conv-reshape-slice-pad",
          m::Pad(&pad, m::Slice(&slice, m::Reshape(&reshape, conv_matcher)),
                 m::ConstantEffectiveScalar(0)),
          VLOG_IS_ON(3), pad_matcher) &&
      !MatchAndLogIfFailed(
          instr, "conv-transpose-reshape-slice-pad",
          m::Pad(&pad,
                 m::Slice(&slice,
                          m::Reshape(&reshape,
                                     m::Transpose(&transpose, conv_matcher))),
                 m::ConstantEffectiveScalar(0)),
          VLOG_IS_ON(3), pad_matcher)) {
    return false;
  }

  VLOG(2) << "Found pattern to attempt to simplify:\n"
          << "conv: " << conv->ToString()  //
          << "\ntranspose: "
          << (transpose != nullptr ? transpose->ToString() : "(null)")
          << "\nreshape: "
          << (reshape != nullptr ? reshape->ToString() : "(null)")
          << "\nslice: " << slice->ToString()  //
          << "\npad: " << pad->ToString();

  // Now check that we can merge the slice into the pad, because the slice is
  // slicing off elements that we know are 0 and the pad is just adding those 0s
  // back.
  //
  // First, we have to check whether any of the output features at the end of
  // the conv are known to be 0.
  std::optional<int64_t> num_known_zero_output_features =
      NumTrailingZeroOutputFeatures(conv);
  if (!num_known_zero_output_features.has_value() ||
      *num_known_zero_output_features == 0) {
    VLOG(2) << "fail: Didn't find any known-zero output features";
    return false;
  }

  // We now know that the last *num_known_zero_output_features output features
  // of the conv are zero. Check if the optional-reshape + optional-transpose +
  // slice + pad combination is setting all of these features to 0. If so, we
  // can merge the slice into the pad.
  const auto& dnums = conv->convolution_dimension_numbers();
  int64_t output_feature_dim;
  if (reshape == nullptr) {
    CHECK_EQ(transpose, nullptr);
    output_feature_dim = dnums.output_feature_dimension();
  } else {
    std::optional<int64_t> vect_c_dim_before_transpose =
        FindOutputVectCDim(conv);
    if (!vect_c_dim_before_transpose.has_value()) {
      VLOG(2) << "Couldn't find vect_c output dim in conv.";
      return false;
    }

    // If there's no transpose, check that the vect_c dim is immediately after
    // the feature dim. OTOH if there is a transpose, check that the transpose
    // moves the vect_c dim immediately after the feature dim.
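    //
    // E.g. (hypothetical shapes) a conv output of [N, C/32, H, W, 32] has
    // feature dim 1 and vect_c dim 4; a transpose with dimensions={0,1,4,2,3}
    // produces [N, C/32, 32, H, W], so after the transpose the feature dim is
    // 1 and the vect_c dim is 2, satisfying the check below.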
    int64_t feature_dim_after_transpose;
    int64_t vect_c_dim_after_transpose;
    if (transpose == nullptr) {
      feature_dim_after_transpose = dnums.output_feature_dimension();
      vect_c_dim_after_transpose = *vect_c_dim_before_transpose;
    } else {
      const auto& transpose_dims = transpose->dimensions();
      feature_dim_after_transpose = std::distance(
          transpose->dimensions().begin(),
          absl::c_find(transpose_dims, dnums.output_feature_dimension()));
      vect_c_dim_after_transpose = std::distance(
          transpose->dimensions().begin(),
          absl::c_find(transpose_dims, *vect_c_dim_before_transpose));
    }
    if (vect_c_dim_after_transpose != feature_dim_after_transpose + 1) {
      VLOG(2) << "fail: after transpose (if present), vect_c dim must appear "
                 "immediately after output feature dim: Computed "
                 "vect_c_dim_after_transpose to be "
              << vect_c_dim_after_transpose;
      return false;
    }

    // Now check that the reshape merges the feature + vect_c dims and
    // doesn't do anything else.
    absl::InlinedVector<int64_t, 5> expected_reshape_dim_sizes(
        reshape->operand(0)->shape().dimensions().begin(),
        reshape->operand(0)->shape().dimensions().end());
    expected_reshape_dim_sizes[feature_dim_after_transpose] *=
        expected_reshape_dim_sizes[vect_c_dim_after_transpose];
    expected_reshape_dim_sizes.erase(expected_reshape_dim_sizes.begin() +
                                     vect_c_dim_after_transpose);
    if (reshape->shape().dimensions() != expected_reshape_dim_sizes) {
      VLOG(2) << "fail: Reshape doesn't merge vect_c with feature dimension.";
      return false;
    }

    output_feature_dim = feature_dim_after_transpose;
  }

  // Check that `slice` slices only the output feature dimension.
  if (!absl::c_all_of(slice->slice_starts(), [](auto v) { return v == 0; }) ||
      !absl::c_all_of(slice->slice_strides(), [](auto v) { return v == 1; })) {
    VLOG(2) << "fail: Slice doesn't start at the front or has stride != 1.";
    return false;
  }

  // We're only allowed to slice the feature dim.
  for (int64_t dim = 0; dim < slice->slice_limits().size(); dim++) {
    if (dim != output_feature_dim &&
        slice->slice_limits(dim) !=
            slice->operand(0)->shape().dimensions(dim)) {
      VLOG(2) << "fail: Slice removes something other than the features dim.";
      return false;
    }
  }
  int64_t num_sliced_from_feature_dim =
      slice->operand(0)->shape().dimensions(output_feature_dim) -
      slice->slice_limits(output_feature_dim);

  // If we slice off more than the known-zero output features, then we need to
  // keep the slice -- it's "doing something".
  if (num_sliced_from_feature_dim > *num_known_zero_output_features) {
    VLOG(2) << "fail: Slice removes " << num_sliced_from_feature_dim
            << " features from the conv, but only "
            << *num_known_zero_output_features
            << " features in the conv are known to be zero.";
    return false;
  }

  // Check if we can merge the slice into the pad.
  if (pad->padding_config().dimensions(output_feature_dim).interior_padding() !=
      0) {
    VLOG(2)
        << "fail: Can't merge slice into pad because pad adds interior padding "
           "in feature dimension.";
    return false;
  }

  // Okay! If we got here, it's legal to fold the slice into the pad. We pad
  // less, because we know that the sliced-off elements are all 0. Ideally, the
  // pad becomes a nop and gets eliminated by algsimp later.
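  //
  // E.g. (hypothetical sizes) if the slice removed 5 trailing features and the
  // pad re-added 7 at the high edge, the new pad below gets
  // edge_padding_high = 7 - 5 = 2 and consumes the slice's operand directly.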
  VLOG(1) << "Eliminating " << num_sliced_from_feature_dim
          << " elements of padding from conv " << conv->name();
  PaddingConfig new_padding_config = pad->padding_config();
  PaddingConfig::PaddingConfigDimension* new_pad_feature_dim =
      new_padding_config.mutable_dimensions(output_feature_dim);
  // This is safe even if the new edge_padding_high is negative -- negative
  // padding is allowed.
  new_pad_feature_dim->set_edge_padding_high(
      new_pad_feature_dim->edge_padding_high() - num_sliced_from_feature_dim);
  TF_ASSIGN_OR_RETURN(HloInstruction * new_pad,
                      MakePadHlo(slice->mutable_operand(0),
                                 pad->mutable_operand(1), new_padding_config));
  TF_RETURN_IF_ERROR(pad->parent()->ReplaceInstruction(pad, new_pad));
  return true;
}

}  // anonymous namespace

StatusOr<bool> CudnnSimplifyPadding::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  bool changed = false;
  for (HloComputation* comp :
       module->MakeNonfusionComputations(execution_threads)) {
    for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
      TF_ASSIGN_OR_RETURN(bool c, TrySimplifyPadding(instr));
      changed |= c;
    }
  }
  return changed;
}

}  // namespace xla::gpu