/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.h"

#include <memory>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace gpu {

namespace {
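// Returns true if the given forward convolution's window can be passed to
// cuDNN as-is: the padding is symmetric and non-negative, and there is no
// base or window dilation.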
bool IsForwardConvolutionCanonical(const HloInstruction& conv) {
  CHECK(conv.custom_call_target() == kCudnnConvForwardCallTarget ||
        conv.custom_call_target() == kCudnnConvBiasActivationForwardCallTarget);
  return window_util::HasSymmetricPadding(conv.window()) &&
         !window_util::HasNegativePadding(conv.window()) &&
         !window_util::HasDilation(conv.window());
}

// If the (positive and negative) padding on the input operand of a convolution
// can't be folded into a cuDNN convolution libcall (e.g. uneven padding and
// dilation), returns kPad and/or kSlice instructions that explicitly apply the
// padding; otherwise returns the original input operand. When there is both
// positive padding (including dilation) and negative padding, we insert both
// kPad and kSlice. Modifies 'conv_window' accordingly if any padding was moved
// into a kPad or kSlice op.
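//
// Illustrative example (values chosen for exposition): for a spatial
// dimension with padding_low=3, padding_high=-1, and base_dilation=2, the
// kPad below gets edge_padding_low=3 and interior_padding=1, the kSlice trims
// one element off the high end, and the window dimension is reset to
// padding_low=0, padding_high=0, base_dilation=1.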
HloInstruction* MaybePaddedAndSlicedInput(
    Window* conv_window, const ConvolutionDimensionNumbers& conv_dnums,
    HloInstruction* input) {
  HloComputation* computation = input->parent();
  if (!window_util::HasSymmetricPadding(*conv_window) ||
      window_util::HasBaseDilation(*conv_window)) {
    // If padding is uneven or has dilation, we insert a kPad instruction that
    // applies positive padding and dilation.
    //
    // TODO(phawkins): If conv_window has asymmetric padding, perhaps instead of
    // moving all the padding into an explicit pad op, we should keep as much
    // padding inside of cudnn as possible, on the assumption that padding
    // within cudnn is basically free, whereas a kPad's cost increases as the
    // amount of padding increases.
    PaddingConfig padding_config =
        MakeNoPaddingConfig(input->shape().dimensions_size());
    for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
      int64_t dim = conv_dnums.input_spatial_dimensions(i);
      if (conv_window->dimensions(i).padding_low() > 0) {
        padding_config.mutable_dimensions(dim)->set_edge_padding_low(
            conv_window->dimensions(i).padding_low());
        conv_window->mutable_dimensions(i)->set_padding_low(0);
      }
      if (conv_window->dimensions(i).padding_high() > 0) {
        padding_config.mutable_dimensions(dim)->set_edge_padding_high(
            conv_window->dimensions(i).padding_high());
        conv_window->mutable_dimensions(i)->set_padding_high(0);
      }
      if (conv_window->dimensions(i).base_dilation() != 1) {
        padding_config.mutable_dimensions(dim)->set_interior_padding(
            conv_window->dimensions(i).base_dilation() - 1);
        conv_window->mutable_dimensions(i)->set_base_dilation(1);
      }
    }
    PrimitiveType element_type = input->shape().element_type();
    HloInstruction* padding = computation->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
    input = MakePadHlo(input, padding, padding_config, &input->metadata())
                .ValueOrDie();
  }

  if (window_util::HasNegativePadding(*conv_window)) {
    // If the window has negative padding, insert a kSlice that explicitly
    // applies negative padding.
    //
    // For each dimension, initialize the start index to 0 and the limit index
    // to the size of that dimension.
    std::vector<int64_t> start_indices(input->shape().dimensions_size(), 0);
    std::vector<int64_t> limit_indices(input->shape().dimensions().begin(),
                                       input->shape().dimensions().end());
    std::vector<int64_t> strides(input->shape().dimensions_size(), 1);
    for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
      int64_t dim = conv_dnums.input_spatial_dimensions(i);
      // If dimension "dim" has negative padding, increase the start index or
      // decrement the limit index by the amount of negative padding.
      if (conv_window->dimensions(i).padding_low() < 0) {
        start_indices[dim] += -conv_window->dimensions(i).padding_low();
        conv_window->mutable_dimensions(i)->set_padding_low(0);
      }
      if (conv_window->dimensions(i).padding_high() < 0) {
        limit_indices[dim] -= -conv_window->dimensions(i).padding_high();
        conv_window->mutable_dimensions(i)->set_padding_high(0);
      }
    }

    input =
        MakeSliceHlo(input, start_indices, limit_indices, strides).ValueOrDie();
  }

  return input;
}

// If the padding on the kernel operand of a convolution can't be folded into a
// cuDNN convolution libcall (e.g. dilation), returns a kPad instruction that
// explicitly applies the padding; otherwise returns the original kernel
// operand.
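//
// Illustrative example (values chosen for exposition): window_dilation=2 on a
// kernel spatial dimension of size 3 becomes interior_padding=1, i.e. the
// kernel [x y z] is padded to [x 0 y 0 z], after which the dilation can be
// dropped from the convolution's window.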
HloInstruction* MaybePaddedKernel(const Window& conv_window,
                                  const ConvolutionDimensionNumbers& conv_dnums,
                                  HloInstruction* kernel) {
  if (!window_util::HasWindowDilation(conv_window)) {
    return kernel;
  }

  // Compute the shape and padding config of the pad to be inserted.
  PaddingConfig padding_config;
  for (size_t i = 0; i < kernel->shape().dimensions_size(); ++i) {
    padding_config.add_dimensions();
  }
  for (size_t i = 0; i < conv_dnums.kernel_spatial_dimensions().size(); ++i) {
    int64_t dim = conv_dnums.kernel_spatial_dimensions(i);
    padding_config.mutable_dimensions(dim)->set_interior_padding(
        conv_window.dimensions(i).window_dilation() - 1);
  }

  HloComputation* computation = kernel->parent();
  PrimitiveType element_type = kernel->shape().element_type();
  HloInstruction* padding = computation->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
  return MakePadHlo(kernel, padding, padding_config, &kernel->metadata())
      .ValueOrDie();
}
}  // namespace

bool GpuConvPaddingLegalization::CanonicalizeForwardConvolution(
    HloInstruction* conv) {
  if (IsForwardConvolutionCanonical(*conv)) {
    return false;
  }

  // Insert slices and/or pads between the convolution and its input and/or
  // kernel operand.
  Window new_conv_window = conv->window();
  HloInstruction* new_input = MaybePaddedAndSlicedInput(
      &new_conv_window, conv->convolution_dimension_numbers(),
      conv->mutable_operand(0));
  HloInstruction* new_kernel =
      MaybePaddedKernel(new_conv_window, conv->convolution_dimension_numbers(),
                        conv->mutable_operand(1));

  // Remove the window dilation from convolution's window field. These paddings
  // are made explicit with the pads inserted by MaybePaddedKernel().
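  // (Illustrative example: a kernel spatial dimension of size 3 with
  // window_dilation=2 was padded above to size 5, so the window dimension's
  // size is updated to 5 and its dilation is reset to 1.)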
  for (size_t i = 0; i < new_conv_window.dimensions_size(); ++i) {
    WindowDimension* dim = new_conv_window.mutable_dimensions(i);

    // The size of the kernel may have changed so update the Window to match.
    dim->set_size(new_kernel->shape().dimensions(
        conv->convolution_dimension_numbers().kernel_spatial_dimensions(i)));
    dim->set_window_dilation(1);
  }

  // The conv CustomCall returns a tuple (conv_result, scratch_buffer).  Extract
  // out the shape of conv_result.
  VLOG(1) << "Canonicalizing forward conv";
  std::vector<HloInstruction*> operands(conv->operands().begin(),
                                        conv->operands().end());
  operands[0] = new_input;
  operands[1] = new_kernel;
  auto new_conv = conv->parent()->AddInstruction(
      conv->CloneWithNewOperands(conv->shape(), operands));
  new_conv->set_window(new_conv_window);
  VLOG(1) << "Replacing:\n  " << conv->ToString() << "\nwith:\n  "
          << new_conv->ToString();
  TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
  return true;
}

namespace {
void IncreasePaddingLowBy(int64_t delta, WindowDimension* window_dim) {
  window_dim->set_padding_low(window_dim->padding_low() + delta);
}

void IncreasePaddingHighBy(int64_t delta, WindowDimension* window_dim) {
  window_dim->set_padding_high(window_dim->padding_high() + delta);
}
}  // namespace

bool GpuConvPaddingLegalization::CanonicalizeBackwardFilterConvolution(
    HloInstruction* backward_conv) {
  CHECK_EQ(backward_conv->custom_call_target(),
           kCudnnConvBackwardFilterCallTarget);
  if (window_util::HasSymmetricPadding(backward_conv->window())) {
    return false;
  }

  // A backward filter convolution with uneven padding can be canonicalized to
  // one with even padding by padding the activations (input) beforehand. For
  // example,
  //   BackwardFilterConv(ABCD, xyz, padding_low=1, padding_high=2)
  // is equivalent to
  //   ABCD0 = Pad(ABCD, padding_high=1)
  //   BackwardFilterConv(ABCD0, xyz, padding_low=padding_high=1)
  // We choose the lesser of padding_low and padding_high as the new padding.
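  //
  // Tracing that example through the loop below: padding_low=1 and
  // padding_high=2 give new_conv_padding = min(1, 2) = 1, so the kPad gets
  // edge_padding_low = 1 - 1 = 0 and edge_padding_high = 2 - 1 = 1, and the
  // new backward convolution uses symmetric padding of 1 on that dimension.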
  HloInstruction* input = backward_conv->mutable_operand(0);
  Window new_backward_conv_window = backward_conv->window();
  // input_padding_config is the config of the kPad to be inserted.
  PaddingConfig input_padding_config =
      MakeNoPaddingConfig(input->shape().rank());
  ConvolutionDimensionNumbers backward_conv_dnums =
      backward_conv->convolution_dimension_numbers();
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
    int64_t padding_low = backward_conv->window().dimensions(i).padding_low();
    int64_t padding_high = backward_conv->window().dimensions(i).padding_high();
    if (padding_low < 0 || padding_high < 0) {
      // TODO(b/32744257): The following canonicalization wouldn't remove
      // negative padding in a backward convolution, and would therefore cause
      // cuDNN convolution (which doesn't support negative padding) to fail.
      return false;
    }
    // Compute the new, even padding for the backward conv operation.
    int64_t new_conv_padding = std::min(padding_low, padding_high);
    int64_t dim = backward_conv_dnums.input_spatial_dimensions(i);
    input_padding_config.mutable_dimensions(dim)->set_edge_padding_low(
        padding_low - new_conv_padding);
    input_padding_config.mutable_dimensions(dim)->set_edge_padding_high(
        padding_high - new_conv_padding);

    // Since we move some padding from the backward convolution to the kPad, we
    // need to accordingly reduce the padding amount of the backward convolution
    // and its inner forward convolution.
    auto* new_dim = new_backward_conv_window.mutable_dimensions(i);
    new_dim->set_padding_low(new_conv_padding);
    new_dim->set_padding_high(new_conv_padding);
  }

  // Create a new backward convolution replacing the old one.
  HloComputation* computation = backward_conv->parent();
  HloInstruction* output = backward_conv->mutable_operand(1);
  HloInstruction* padding =
      computation->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::Zero(input->shape().element_type())));
  HloInstruction* padded_input =
      MakePadHlo(input, padding, input_padding_config).ValueOrDie();

  // The shape of the backward_conv CustomCall is a tuple (conv_result,
  // scratch_buffer).  Extract out the shape of conv_result.
  HloInstruction* new_backward_conv =
      computation->AddInstruction(backward_conv->CloneWithNewOperands(
          backward_conv->shape(), {padded_input, output}));
  new_backward_conv->set_window(new_backward_conv_window);

  VLOG(1) << "Canonicalizing backward filter conv";
  VLOG(1) << "Replacing:\n  " << backward_conv->ToString() << "\nwith:\n  "
          << new_backward_conv->ToString();

  TF_CHECK_OK(
      computation->ReplaceInstruction(backward_conv, new_backward_conv));
  return true;
}

bool GpuConvPaddingLegalization::CanonicalizeBackwardInputConvolution(
    HloInstruction* backward_conv) {
  if (window_util::HasSymmetricPadding(backward_conv->window())) {
    return false;
  }

  Window new_backward_conv_window = backward_conv->window();
  ConvolutionDimensionNumbers backward_conv_dnums =
      backward_conv->convolution_dimension_numbers();

  // The backward_conv CustomCall returns a tuple (conv_result, scratch_memory).
  // Get the shape of conv_result.
  Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);

  Shape new_backward_conv_shape = backward_conv_shape;
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
    int64_t padding_low = backward_conv->window().dimensions(i).padding_low();
    int64_t padding_high = backward_conv->window().dimensions(i).padding_high();
    if (padding_low < 0 || padding_high < 0) {
      // TODO(b/32744257): The following canonicalization wouldn't remove
      // negative padding in a backward convolution, and would therefore cause
      // cuDNN convolution (which doesn't support negative padding) to fail.
      return false;
    }
    // If the backward convolution has uneven padding on the activations, we
    // move some padding on the larger end to "internal" padding, so that the
    // backward convolution produces larger activations which get sliced later.
    //
    // For example, suppose we have a non-canonical HLO
    //   [A] = BackwardInputConvolve([a b], [x y z], padding=(low=2,high=1))
    // where the amount of padding low is larger, we can canonicalize it to
    //   [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1))
    //   [A] = Slice([B A])
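    //
    // Tracing that example through the code below: with padding=(low=2,high=1),
    // IncreasePaddingLowBy(1 - 2 = -1) reduces the low padding to 1, the conv
    // result grows by |2 - 1| = 1 element along that spatial dimension, and
    // the slice further down removes that extra element from the low end.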
    if (padding_low > padding_high) {
      IncreasePaddingLowBy(padding_high - padding_low,
                           new_backward_conv_window.mutable_dimensions(i));
    } else if (padding_low < padding_high) {
      IncreasePaddingHighBy(padding_low - padding_high,
                            new_backward_conv_window.mutable_dimensions(i));
    }
    // Decreasing the padding by X *increases* the size of our output by X.
    // Note that we have swapped input spatial dimensions with output spatial
    // dimensions to be compatible with the cuDNN API, so
    // input_spatial_dimensions(i) gives the i-th spatial dimension of the
    // output.
    int64_t dim = backward_conv_dnums.input_spatial_dimensions(i);
    new_backward_conv_shape.set_dimensions(
        dim, new_backward_conv_shape.dimensions(dim) +
                 std::abs(padding_low - padding_high));
  }

  // Create a new backward convolution replacing the old one.
  HloComputation* computation = backward_conv->parent();
  HloInstruction* output = backward_conv->mutable_operand(0);
  HloInstruction* filter = backward_conv->mutable_operand(1);

  HloInstruction* new_backward_conv_call =
      computation->AddInstruction(backward_conv->CloneWithNewOperands(
          ShapeUtil::MakeTupleShape(
              {new_backward_conv_shape, ShapeUtil::MakeShape(U8, {0})}),
          {output, filter}));
  new_backward_conv_call->set_window(new_backward_conv_window);

  // The CustomCall created above returns a tuple (conv_result, scratch_memory).
  // Extract out the two elements.
  HloInstruction* new_backward_conv =
      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
          new_backward_conv_shape, new_backward_conv_call, 0));
  HloInstruction* new_backward_conv_scratch =
      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
          new_backward_conv_call->shape().tuple_shapes(1),
          new_backward_conv_call, 1));

  // Slice the new backward convolution.
  //
  // Initialize start_indices and limit_indices as no slicing.
  std::vector<int64_t> start_indices(
      new_backward_conv->shape().dimensions_size(), 0LL);
  std::vector<int64_t> limit_indices(
      new_backward_conv->shape().dimensions().begin(),
      new_backward_conv->shape().dimensions().end());
  std::vector<int64_t> strides(new_backward_conv->shape().dimensions_size(),
                               1LL);
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
    int64_t padding_low = backward_conv->window().dimensions(i).padding_low();
    int64_t padding_high = backward_conv->window().dimensions(i).padding_high();
    // Note that we have swapped input spatial dimensions with output spatial
    // dimensions to be compatible with the cuDNN API, so
    // input_spatial_dimensions(i) gives the i-th spatial dimension of the
    // output.
    int64_t dim = backward_conv_dnums.input_spatial_dimensions(i);
    if (padding_low > padding_high) {
      // If the amount of low padding (of the old backward convolution) is
      // larger, we internally pad the low end of the activations and slice
      // internal padding out here.
      start_indices[dim] += padding_low - padding_high;
    } else if (padding_low < padding_high) {
      // If the amount of high padding is larger, we slice out the internal
      // padding on the high end.
      limit_indices[dim] -= padding_high - padding_low;
    }
  }

  // Replace the old backward convolution with the slice.
  Shape slice_shape =
      ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices,
                                      limit_indices, strides)
          .value();
  CHECK(ShapeUtil::Compatible(slice_shape, backward_conv_shape))
      << ShapeUtil::HumanString(slice_shape) << " vs "
      << ShapeUtil::HumanString(backward_conv_shape);

  HloInstruction* slice = computation->AddInstruction(
      HloInstruction::CreateSlice(backward_conv_shape, new_backward_conv,
                                  start_indices, limit_indices, strides));
  HloInstruction* new_tuple = computation->AddInstruction(
      HloInstruction::CreateTuple({slice, new_backward_conv_scratch}));

  VLOG(1) << "Canonicalizing backward input conv";
  VLOG(1) << "Replacing:\n  " << backward_conv->ToString() << "\nwith:\n  "
          << new_tuple->ToString();

  TF_CHECK_OK(computation->ReplaceInstruction(backward_conv, new_tuple));
  return true;
}

StatusOr<bool> GpuConvPaddingLegalization::RunOnComputation(
    HloComputation* computation) {
  bool changed = false;
  std::vector<HloCustomCallInstruction*> convs;
  for (auto* instr : computation->instructions()) {
    if (IsCustomCallToDnnConvolution(*instr)) {
      convs.push_back(Cast<HloCustomCallInstruction>(instr));
    }
  }
  for (HloCustomCallInstruction* instruction : convs) {
    TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instruction));
    changed |= [&] {
      switch (kind) {
        case CudnnConvKind::kForward:
        case CudnnConvKind::kForwardActivation:
          return CanonicalizeForwardConvolution(instruction);
        case CudnnConvKind::kBackwardInput:
          return CanonicalizeBackwardInputConvolution(instruction);
        case CudnnConvKind::kBackwardFilter:
          return CanonicalizeBackwardFilterConvolution(instruction);
      }
    }();
  }
  return changed;
}

StatusOr<bool> GpuConvPaddingLegalization::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  bool changed = false;
  for (HloComputation* computation :
       module->MakeNonfusionComputations(execution_threads)) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
    changed |= result;
  }
  return changed;
}

}  // namespace gpu
}  // namespace xla