/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cudnn_simplify_padding.h"

#include <algorithm>
#include <cstdio>
#include <iterator>
#include <optional>
#include <sstream>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla::gpu {

namespace {
namespace m = ::xla::match;

// If exactly one index of `vals` is false, returns that index.  If zero or
// more than one index is false, returns nullopt.
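// A few sample inputs/outputs (illustrative):
//   FindFalseIndex({true, false, true})  -> 1
//   FindFalseIndex({false, false, true}) -> nullopt  (two false entries)
//   FindFalseIndex({true, true})         -> nullopt  (no false entry)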
std::optional<int64_t> FindFalseIndex(absl::Span<const bool> vals) {
  std::optional<int64_t> missing_dim;
  for (int i = 0; i < vals.size(); i++) {
    if (vals[i]) {
      continue;
    }
    if (missing_dim.has_value()) {
      VLOG(2) << "Multiple dimensions are missing from conv dnums; can't "
                 "determine which is vect_c dimension";
      return std::nullopt;
    }
    missing_dim = i;
  }
  return missing_dim;
}

// Finds the vect_c dimension in the convolution's output.
//
// The vect_c dimension is the output dimension that's not mentioned in
// `dnums`.  If there's zero or more than one such dimension, returns nullopt.
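//
// For example (illustrative): for an int8x32 conv whose output is laid out as
// [N, C/32, H, W, 32], dnums name the batch dim (0), the feature dim (1), and
// the spatial dims (2, 3), leaving index 4 as the vect_c dim.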
std::optional<int64_t> FindOutputVectCDim(HloInstruction* conv) {
  const ConvolutionDimensionNumbers& dnums =
      conv->convolution_dimension_numbers();
  int64_t num_dims = conv->shape().tuple_shapes(0).dimensions_size();
  absl::InlinedVector<bool, 5> seen_dims(num_dims);
  seen_dims[dnums.output_batch_dimension()] = true;
  seen_dims[dnums.output_feature_dimension()] = true;
  for (int64_t d : dnums.output_spatial_dimensions()) {
    seen_dims[d] = true;
  }
  return FindFalseIndex(seen_dims);
}

// Finds the vect_c dimension in the convolution's kernel.
std::optional<int64_t> FindKernelVectCDim(HloInstruction* conv) {
  const ConvolutionDimensionNumbers& dnums =
      conv->convolution_dimension_numbers();
  int64_t num_dims = conv->operand(1)->shape().dimensions_size();
  absl::InlinedVector<bool, 5> seen_dims(num_dims);
  seen_dims[dnums.kernel_input_feature_dimension()] = true;
  seen_dims[dnums.kernel_output_feature_dimension()] = true;
  for (int64_t d : dnums.kernel_spatial_dimensions()) {
    seen_dims[d] = true;
  }
  return FindFalseIndex(seen_dims);
}

// Attempts to count the number of output features at the end of conv that are
// guaranteed to be 0.
//
// This is the same as counting the number of values o at the end of the kernel
// for which kernel[i,o,h,w] is 0 for all values i,h,w.
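//
// For example (illustrative): if the kernel's output-feature dimension has
// size 64 but only features 0..59 contain any nonzero weight, this returns 4.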
std::optional<int64_t> NumTrailingZeroOutputFeatures(HloInstruction* conv) {
  const ConvolutionDimensionNumbers& dnums =
      conv->convolution_dimension_numbers();
  int64_t feature_dim = dnums.kernel_output_feature_dimension();
  const HloInstruction* weights = conv->operand(1);
  VLOG(2) << "Computing NumTrailingZeroOutputFeatures of " << conv->ToString()
          << "\nwith weights " << weights->ToString();
  if (Match(weights, m::Pad(m::Op(), m::ConstantEffectiveScalar(0)))) {
    const PaddingConfig::PaddingConfigDimension& padding_config =
        weights->padding_config().dimensions(feature_dim);
    // The last N output feature weights are all 0.
    VLOG(2) << "Success: Weights is a pad; padding on output feature dim is "
            << padding_config.edge_padding_high();
    return padding_config.edge_padding_high();
  } else if (const HloInstruction* pad; Match(
                 weights, m::Reshape(m::Pad(&pad, m::Op(),
                                            m::ConstantEffectiveScalar(0))))) {
    // Check that the reshape merely adds a VECT_C to the kernel input features.
    // That is, we reshape from [I,O,H,W] (in some order) to [I/k,k,O,H,W] (in
    // the same order) for some constant k (probably 32).  Then check how much
    // the pad adds to the O dimension.
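    //
    // Worked example (illustrative numbers): if the reshaped weights have
    // shape [2, 32, 128, 3, 3] with kernel_input_feature_dimension() == 0 and
    // vect_c at index 1, then expected_pad_dim_sizes below starts as
    // {2, 32, 128, 3, 3}, index 0 becomes 2 * 32 = 64, index 1 is erased, and
    // we require the pad to produce shape [64, 128, 3, 3].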
    std::optional<int64_t> vect_c_dim = FindKernelVectCDim(conv);
    if (!vect_c_dim.has_value()) {
      VLOG(2) << "fail: Can't find vect_c dimension in conv.";
      return std::nullopt;
    }
    if (*vect_c_dim != dnums.kernel_input_feature_dimension() + 1) {
      VLOG(2) << "fail: vect_c dim is in the wrong place; should be right "
                 "after kernel input feature dims in conv.";
      return std::nullopt;
    }
    absl::InlinedVector<int64_t, 5> expected_pad_dim_sizes(
        weights->shape().dimensions().begin(),
        weights->shape().dimensions().end());
    expected_pad_dim_sizes[dnums.kernel_input_feature_dimension()] *=
        weights->shape().dimensions(*vect_c_dim);
    expected_pad_dim_sizes.erase(expected_pad_dim_sizes.begin() + *vect_c_dim);
    if (pad->shape().dimensions() != expected_pad_dim_sizes) {
      VLOG(2) << "fail: Reshape doesn't simply merge vect_c dimension into "
                 "input features dim "
              << weights->ToString() << " but expected dims "
              << absl::StrJoin(expected_pad_dim_sizes, ",");
      return std::nullopt;
    }

    // If the filter dnums are e.g. [I,O,H,W] then after reshape they are
    // [I/k,k,O,H,W] and the new index of O is one greater than before the
    // reshape (which we know only adds the I/k and k dims, which we also know
    // are contiguous).  OTOH if the O comes before the I in the original, then
    // the index of O doesn't change after the reshape.
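    //
    // Concretely (illustrative): if the reshaped dnums put I at 0, vect_c at
    // 1, and O at 2, then the pad's (pre-reshape) layout is [I,O,H,W], so O
    // lives at index 2 - 1 = 1 before the reshape.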
    int64_t feature_dim_before_reshape = feature_dim;
    if (dnums.kernel_output_feature_dimension() >
        dnums.kernel_input_feature_dimension()) {
      feature_dim_before_reshape--;
    }
    const PaddingConfig::PaddingConfigDimension& padding_config =
        pad->padding_config().dimensions(feature_dim_before_reshape);

    // The last N output feature weights are all 0.
    VLOG(2) << "Success: Weights is a reshape of a pad; padding on output "
               "feature dim is "
            << padding_config.edge_padding_high();
    return padding_config.edge_padding_high();
  } else if (Match(weights, m::Constant())) {
    // Iterate backwards over `weights` to find the index of the first nonzero
    // value.
    //
    // TODO(jlebar): This may be slow, because it iterates over potentially the
    // whole constant and does a multi_index -> linear_index conversion for each
    // element. If necessary we could rewrite this by using linear indices, but
    // we'd have to be careful of the fact that literals can have arbitrary
    // layouts, so you can't just iterate over the literal's bytes.
    const Literal& lit = weights->literal();
    const auto& dims = weights->shape().dimensions();
    absl::InlinedVector<int64_t, 5> multi_index;
    for (int64_t dim : dims) {
      multi_index.push_back(dim - 1);
    }
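    // The loop below steps multi_index backwards like an odometer, minor digit
    // first.  For example (illustrative), with dims {2, 3} it visits
    // {1,2} -> {1,1} -> {1,0} -> {0,2} -> ... -> {0,0}, stopping as soon as a
    // nonzero element is found or the most-major index underflows.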
    while (true) {
      if (!lit.IsZero(multi_index)) {
        break;
      }
      multi_index[multi_index.size() - 1]--;
      for (int i = multi_index.size() - 1; i > 0; i--) {
        if (multi_index[i] == -1) {
          multi_index[i] = dims[i] - 1;
          multi_index[i - 1]--;
        } else {
          break;
        }
      }
      if (multi_index[0] == -1) {
        break;
      }
    }

    VLOG(2) << "First nonzero index in weights constant is "
            << absl::StrJoin(multi_index, ",");
    int64_t first_nonzero_feature = multi_index[feature_dim];
    // "round up" the first nonzero feature index if it's not *all* zeros.
    for (int i = 0; i < multi_index.size(); i++) {
      if (i != feature_dim && multi_index[i] != 0) {
        first_nonzero_feature++;
        break;
      }
    }
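    // For example (illustrative): if the feature dim has size 64 and the first
    // nonzero element sits at feature index 59 with some other index nonzero,
    // first_nonzero_feature becomes 60 and we report 64 - 60 = 4 trailing zero
    // features.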
    int64_t ret = std::max<int64_t>(
        0, weights->shape().dimensions(feature_dim) - first_nonzero_feature);
    VLOG(2) << "Success: weights is a constant; num zero trailing output "
               "features is "
            << ret;
    return ret;
  }
  return std::nullopt;
}

StatusOr<bool> TrySimplifyPadding(HloInstruction* instr) {
  // Match one of the following patterns.
  //   conv -> slice -> pad
  //   conv -> reshape -> slice -> pad
  //   conv -> transpose -> reshape -> slice -> pad
  //
  // where `pad` (the root of the pattern) is `instr`.
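  //
  // Schematically (illustrative, not real HLO), the longest form looks like:
  //   %conv  = custom-call(...), custom_call_target="__cudnn$convForward"
  //   %gte   = get-tuple-element(%conv), index=0
  //   %t     = transpose(%gte)
  //   %r     = reshape(%t)
  //   %s     = slice(%r)
  //   %instr = pad(%s, 0)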
  HloInstruction* conv;
  HloInstruction* transpose = nullptr;  // optional
  HloInstruction* reshape = nullptr;    // optional
  HloInstruction* slice;
  HloInstruction* pad;
  auto conv_matcher = m::GetTupleElement(
      m::CustomCall(&conv).WithPredicate([](const HloInstruction* instr) {
        return instr->custom_call_target() == kCudnnConvForwardCallTarget ||
               instr->custom_call_target() ==
                   kCudnnConvBiasActivationForwardCallTarget;
      }),
      0);
  auto pad_matcher = m::Pad(m::Op(), m::ConstantEffectiveScalar(0));
  if (!MatchAndLogIfFailed(instr, "conv-slice-pad",
                           m::Pad(&pad, m::Slice(&slice, conv_matcher),
                                  m::ConstantEffectiveScalar(0)),
                           VLOG_IS_ON(3), pad_matcher) &&
      !MatchAndLogIfFailed(
          instr, "conv-reshape-slice-pad",
          m::Pad(&pad, m::Slice(&slice, m::Reshape(&reshape, conv_matcher)),
                 m::ConstantEffectiveScalar(0)),
          VLOG_IS_ON(3), pad_matcher) &&
      !MatchAndLogIfFailed(
          instr, "conv-transpose-reshape-slice-pad",
          m::Pad(&pad,
                 m::Slice(&slice,
                          m::Reshape(&reshape,
                                     m::Transpose(&transpose, conv_matcher))),
                 m::ConstantEffectiveScalar(0)),
          VLOG_IS_ON(3), pad_matcher)) {
    return false;
  }

  VLOG(2) << "Found pattern to attempt to simplify:\n"
          << "conv: " << conv->ToString()  //
          << "\ntranspose: "
          << (transpose != nullptr ? transpose->ToString() : "(null)")
          << "\nreshape: "
          << (reshape != nullptr ? reshape->ToString() : "(null)")
          << "\nslice: " << slice->ToString()  //
          << "\npad: " << pad->ToString();

  // Now check that we can merge the slice into the pad, because the slice is
  // slicing off elements that we know are 0 and the pad is just adding those 0s
  // back.
  //
  // First, we have to check whether any of the output features at the end of
  // the conv are known to be 0.
  std::optional<int64_t> num_known_zero_output_features =
      NumTrailingZeroOutputFeatures(conv);
  if (!num_known_zero_output_features.has_value() ||
      *num_known_zero_output_features == 0) {
    VLOG(2) << "fail: Didn't find any known-zero output features";
    return false;
  }

  // We now know that the last *num_known_zero_output_features output features
  // of the conv are zero.  Check if the optional-reshape + optional-transpose
  // + slice + pad combination is setting all of these features to 0.  If so,
  // we can merge the slice into the pad.
  const auto& dnums = conv->convolution_dimension_numbers();
  int64_t output_feature_dim;
  if (reshape == nullptr) {
    CHECK_EQ(transpose, nullptr);
    output_feature_dim = dnums.output_feature_dimension();
  } else {
    std::optional<int64_t> vect_c_dim_before_transpose =
        FindOutputVectCDim(conv);
    if (!vect_c_dim_before_transpose.has_value()) {
      VLOG(2) << "Couldn't find vect_c output dim in conv.";
      return false;
    }

    // If there's no transpose, check that the vect_c dim is immediately after
    // the feature dim.  OTOH if there is a transpose, check that the transpose
    // moves the vect_c dim immediately after the feature dim.
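    //
    // For example (illustrative): if the conv output is [N, C/32, H, W, 32]
    // with feature dim 1 and vect_c dim 4, a transpose with dimensions
    // {0, 1, 4, 2, 3} puts vect_c at index 2, immediately after the feature
    // dim (which stays at index 1).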
    int64_t feature_dim_after_transpose;
    int64_t vect_c_dim_after_transpose;
    if (transpose == nullptr) {
      feature_dim_after_transpose = dnums.output_feature_dimension();
      vect_c_dim_after_transpose = *vect_c_dim_before_transpose;
    } else {
      const auto& transpose_dims = transpose->dimensions();
      feature_dim_after_transpose = std::distance(
          transpose->dimensions().begin(),
          absl::c_find(transpose_dims, dnums.output_feature_dimension()));
      vect_c_dim_after_transpose = std::distance(
          transpose->dimensions().begin(),
          absl::c_find(transpose_dims, *vect_c_dim_before_transpose));
    }
    if (vect_c_dim_after_transpose != feature_dim_after_transpose + 1) {
      VLOG(2) << "fail: after transpose (if present), vect_c dim must appear "
                 "immediately after output feature dim: Computed "
                 "vect_c_dim_after_transpose to be "
              << vect_c_dim_after_transpose;
      return false;
    }

    // Now check that the reshape merges the feature + vect_c dims and
    // doesn't do anything else.
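    //
    // E.g. (illustrative): a reshape operand of shape [1, 2, 32, 7, 7] with
    // feature dim 1 and vect_c dim 2 must reshape to exactly [1, 64, 7, 7].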
    absl::InlinedVector<int64_t, 5> expected_reshape_dim_sizes(
        reshape->operand(0)->shape().dimensions().begin(),
        reshape->operand(0)->shape().dimensions().end());
    expected_reshape_dim_sizes[feature_dim_after_transpose] *=
        expected_reshape_dim_sizes[vect_c_dim_after_transpose];
    expected_reshape_dim_sizes.erase(expected_reshape_dim_sizes.begin() +
                                     vect_c_dim_after_transpose);
    if (reshape->shape().dimensions() != expected_reshape_dim_sizes) {
      VLOG(2) << "fail: Reshape doesn't merge vect_c with feature dimension.";
      return false;
    }

    output_feature_dim = feature_dim_after_transpose;
  }

  // Check that `slice` slices only the output feature dimension.
  if (!absl::c_all_of(slice->slice_starts(), [](auto v) { return v == 0; }) ||
      !absl::c_all_of(slice->slice_strides(), [](auto v) { return v == 1; })) {
    VLOG(2) << "fail: Slice doesn't start at the front or has stride != 1.";
    return false;
  }

  // We're only allowed to slice the feature dim.
  for (int64_t dim = 0; dim < slice->slice_limits().size(); dim++) {
    if (dim != output_feature_dim &&
        slice->slice_limits(dim) != slice->shape().dimensions(dim)) {
      VLOG(2) << "fail: Slice removes something other than the features dim.";
      return false;
    }
  }
  int64_t num_sliced_from_feature_dim =
      slice->operand(0)->shape().dimensions(output_feature_dim) -
      slice->slice_limits(output_feature_dim);

  // If we slice off more than the known-zero output features, then we need to
  // keep the slice -- it's "doing something".
  if (num_sliced_from_feature_dim > *num_known_zero_output_features) {
    VLOG(2) << "fail: Slice removes " << num_sliced_from_feature_dim
            << " features from the conv, but only "
            << *num_known_zero_output_features
            << " features in the conv are known to be zero.";
    return false;
  }

  // Check if we can merge the slice into the pad.
  if (pad->padding_config().dimensions(output_feature_dim).interior_padding() !=
      0) {
    VLOG(2)
        << "fail: Can't merge slice into pad because pad adds interior padding "
           "in feature dimension.";
    return false;
  }

  // Okay!  If we got here, it's legal to fold the slice into the pad.  We pad
  // less, because we know that the sliced-off elements are all 0.  Ideally, the
  // pad becomes a nop and gets eliminated by algsimp later.
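  //
  // Numeric example (illustrative): if the slice drops 4 trailing features and
  // the pad's edge_padding_high on the feature dim was 4, the rewritten pad
  // has edge_padding_high 0 and algsimp can remove it entirely.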
  VLOG(1) << "Eliminating " << num_sliced_from_feature_dim
          << " elements of padding from conv " << conv->name();
  PaddingConfig new_padding_config = pad->padding_config();
  PaddingConfig::PaddingConfigDimension* new_pad_feature_dim =
      new_padding_config.mutable_dimensions(output_feature_dim);
  // This is safe even if the new edge_padding_high is negative -- negative
  // padding is allowed.
  new_pad_feature_dim->set_edge_padding_high(
      new_pad_feature_dim->edge_padding_high() - num_sliced_from_feature_dim);
  TF_ASSIGN_OR_RETURN(HloInstruction * new_pad,
                      MakePadHlo(slice->mutable_operand(0),
                                 pad->mutable_operand(1), new_padding_config));
  TF_RETURN_IF_ERROR(pad->parent()->ReplaceInstruction(pad, new_pad));
  return true;
}

}  // anonymous namespace

StatusOr<bool> CudnnSimplifyPadding::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  bool changed = false;
  for (HloComputation* comp :
       module->MakeNonfusionComputations(execution_threads)) {
    for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
      TF_ASSIGN_OR_RETURN(bool c, TrySimplifyPadding(instr));
      changed |= c;
    }
  }
  return changed;
}

}  // namespace xla::gpu