/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.h"

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace gpu {

namespace {
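// Returns true if the given forward convolution can be handed to cuDNN as-is:
// its window has symmetric, non-negative padding and no dilation.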
bool IsForwardConvolutionCanonical(const HloInstruction& conv) {
  CHECK(conv.custom_call_target() == kCudnnConvForwardCallTarget ||
        conv.custom_call_target() == kCudnnConvBiasActivationForwardCallTarget);
  return window_util::HasSymmetricPadding(conv.window()) &&
         !window_util::HasNegativePadding(conv.window()) &&
         !window_util::HasDilation(conv.window());
}

// If the (positive and negative) padding on the input operand of a convolution
// can't be folded into a cuDNN convolution libcall (e.g. uneven padding and
// dilation), returns kPad and/or kSlice instructions that explicitly apply the
// padding; otherwise returns the original input operand. When there is both
// positive padding (including dilation) and negative padding, we insert both
// kPad and kSlice. Modifies 'conv_window' accordingly if any padding was moved
// into a kPad or kSlice op.
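//
// For example, with illustrative values, a spatial dimension carrying
// padding_low=3, padding_high=-2, and base_dilation=2 becomes
//   padded = Pad(input, edge_padding_low=3, interior_padding=1)
//   input  = Slice(padded, trimming 2 elements from the high end)
// and the corresponding window dimension is reset to padding_low=0,
// padding_high=0, base_dilation=1.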
HloInstruction* MaybePaddedAndSlicedInput(
    Window* conv_window, const ConvolutionDimensionNumbers& conv_dnums,
    HloInstruction* input) {
  HloComputation* computation = input->parent();
  if (!window_util::HasSymmetricPadding(*conv_window) ||
      window_util::HasBaseDilation(*conv_window)) {
    // If padding is uneven or has dilation, we insert a kPad instruction that
    // applies positive padding and dilation.
    //
    // TODO(phawkins): If conv_window has asymmetric padding, perhaps instead of
    // moving all the padding into an explicit pad op, we should keep as much
    // padding inside of cudnn as possible, on the assumption that padding
    // within cudnn is basically free, whereas a kPad's cost increases as the
    // amount of padding increases.
    PaddingConfig padding_config =
        MakeNoPaddingConfig(input->shape().dimensions_size());
    for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
      int64 dim = conv_dnums.input_spatial_dimensions(i);
      if (conv_window->dimensions(i).padding_low() > 0) {
        padding_config.mutable_dimensions(dim)->set_edge_padding_low(
            conv_window->dimensions(i).padding_low());
        conv_window->mutable_dimensions(i)->set_padding_low(0);
      }
      if (conv_window->dimensions(i).padding_high() > 0) {
        padding_config.mutable_dimensions(dim)->set_edge_padding_high(
            conv_window->dimensions(i).padding_high());
        conv_window->mutable_dimensions(i)->set_padding_high(0);
      }
      if (conv_window->dimensions(i).base_dilation() != 1) {
        padding_config.mutable_dimensions(dim)->set_interior_padding(
            conv_window->dimensions(i).base_dilation() - 1);
        conv_window->mutable_dimensions(i)->set_base_dilation(1);
      }
    }
    PrimitiveType element_type = input->shape().element_type();
    HloInstruction* padding = computation->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
    input = MakePadHlo(input, padding, padding_config).ValueOrDie();
  }

  if (window_util::HasNegativePadding(*conv_window)) {
    // If the window has negative padding, insert a kSlice that explicitly
    // applies negative padding.
    //
    // For each dimension, initialize the start index to 0 and the limit index
    // to the size of that dimension.
    std::vector<int64> start_indices(input->shape().dimensions_size(), 0);
    std::vector<int64> limit_indices(input->shape().dimensions().begin(),
                                     input->shape().dimensions().end());
    std::vector<int64> strides(input->shape().dimensions_size(), 1);
    for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
      int64 dim = conv_dnums.input_spatial_dimensions(i);
      // If dimension "dim" has negative padding, increase the start index or
      // decrease the limit index by the amount of negative padding.
      if (conv_window->dimensions(i).padding_low() < 0) {
        start_indices[dim] += -conv_window->dimensions(i).padding_low();
        conv_window->mutable_dimensions(i)->set_padding_low(0);
      }
      if (conv_window->dimensions(i).padding_high() < 0) {
        limit_indices[dim] -= -conv_window->dimensions(i).padding_high();
        conv_window->mutable_dimensions(i)->set_padding_high(0);
      }
    }

    input =
        MakeSliceHlo(input, start_indices, limit_indices, strides).ValueOrDie();
  }

  return input;
}

// If the padding on the kernel operand of a convolution can't be folded into a
// cuDNN convolution libcall (e.g. dilation), returns a kPad instruction that
// explicitly applies the padding; otherwise returns the original kernel
// operand.
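//
// For example, with window_dilation=2, a kernel [x y z] is rewritten as
// [x 0 y 0 z]: each spatial dimension gets interior_padding =
// window_dilation - 1, growing a kernel of size k to
// k + (k-1)*(window_dilation-1).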
HloInstruction* MaybePaddedKernel(const Window& conv_window,
                                  const ConvolutionDimensionNumbers& conv_dnums,
                                  HloInstruction* kernel) {
  if (!window_util::HasWindowDilation(conv_window)) {
    return kernel;
  }

  // Compute the shape and padding config of the pad to be inserted.
  PaddingConfig padding_config;
  for (size_t i = 0; i < kernel->shape().dimensions_size(); ++i) {
    padding_config.add_dimensions();
  }
  for (size_t i = 0; i < conv_dnums.kernel_spatial_dimensions().size(); ++i) {
    int64 dim = conv_dnums.kernel_spatial_dimensions(i);
    padding_config.mutable_dimensions(dim)->set_interior_padding(
        conv_window.dimensions(i).window_dilation() - 1);
  }

  HloComputation* computation = kernel->parent();
  PrimitiveType element_type = kernel->shape().element_type();
  HloInstruction* padding = computation->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
  return MakePadHlo(kernel, padding, padding_config).ValueOrDie();
}
}  // namespace

bool GpuConvPaddingLegalization::CanonicalizeForwardConvolution(
    HloInstruction* conv) {
  if (IsForwardConvolutionCanonical(*conv)) {
    return false;
  }

  // Insert slices and/or pads between the convolution and its input and/or
  // kernel operand.
  Window new_conv_window = conv->window();
  HloInstruction* new_input = MaybePaddedAndSlicedInput(
      &new_conv_window, conv->convolution_dimension_numbers(),
      conv->mutable_operand(0));
  HloInstruction* new_kernel =
      MaybePaddedKernel(new_conv_window, conv->convolution_dimension_numbers(),
                        conv->mutable_operand(1));

  // Remove the window dilation from the convolution's window field. The
  // dilation is made explicit by the pad inserted by MaybePaddedKernel().
  for (size_t i = 0; i < new_conv_window.dimensions_size(); ++i) {
    WindowDimension* dim = new_conv_window.mutable_dimensions(i);

    // The size of the kernel may have changed, so update the Window to match.
    dim->set_size(new_kernel->shape().dimensions(
        conv->convolution_dimension_numbers().kernel_spatial_dimensions(i)));
    dim->set_window_dilation(1);
  }

  // The conv CustomCall returns a tuple (conv_result, scratch_buffer); cloning
  // with conv->shape() below preserves that tuple shape.
  VLOG(1) << "Canonicalizing forward conv";
  std::vector<HloInstruction*> operands(conv->operands().begin(),
                                        conv->operands().end());
  operands[0] = new_input;
  operands[1] = new_kernel;
  auto new_conv = conv->parent()->AddInstruction(
      conv->CloneWithNewOperands(conv->shape(), operands));
  new_conv->set_window(new_conv_window);
  VLOG(1) << "Replacing:\n  " << conv->ToString() << "\nwith:\n  "
          << new_conv->ToString();
  TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
  return true;
}

namespace {
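// Adjusts one end of a WindowDimension's padding in place. delta may be
// negative, in which case the padding shrinks.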
void IncreasePaddingLowBy(int64 delta, WindowDimension* window_dim) {
  window_dim->set_padding_low(window_dim->padding_low() + delta);
}

void IncreasePaddingHighBy(int64 delta, WindowDimension* window_dim) {
  window_dim->set_padding_high(window_dim->padding_high() + delta);
}
}  // namespace

bool GpuConvPaddingLegalization::CanonicalizeBackwardFilterConvolution(
    HloInstruction* backward_conv) {
  CHECK_EQ(backward_conv->custom_call_target(),
           kCudnnConvBackwardFilterCallTarget);
  if (window_util::HasSymmetricPadding(backward_conv->window())) {
    return false;
  }

  // A backward filter convolution with uneven padding can be canonicalized to
  // one with even padding by padding the activations (input) beforehand. For
  // example,
  //   BackwardFilterConv(ABCD, xyz, padding_low=1, padding_high=2)
  // is equivalent to
  //   ABCD0 = Pad(ABCD, padding_high=1)
  //   BackwardFilterConv(ABCD0, xyz, padding_low=padding_high=1)
  // We choose the lesser of padding_low and padding_high as the new padding.
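  // In the example above, new_conv_padding = min(1, 2) = 1, so the inserted
  // kPad gets edge_padding_low = 1 - 1 = 0 and edge_padding_high = 2 - 1 = 1.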
  HloInstruction* input = backward_conv->mutable_operand(0);
  Window new_backward_conv_window = backward_conv->window();
  // input_padding_config is the config of the kPad to be inserted.
  PaddingConfig input_padding_config =
      MakeNoPaddingConfig(input->shape().rank());
  ConvolutionDimensionNumbers backward_conv_dnums =
      backward_conv->convolution_dimension_numbers();
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
    int64 padding_low = backward_conv->window().dimensions(i).padding_low();
    int64 padding_high = backward_conv->window().dimensions(i).padding_high();
    if (padding_low < 0 || padding_high < 0) {
      // TODO(b/32744257): The following canonicalization wouldn't remove
      // negative padding in a backward convolution, and would therefore cause
      // cuDNN convolution (which doesn't support negative padding) to fail.
      return false;
    }
    // Compute the new, even padding for the backward conv operation.
    int64 new_conv_padding = std::min(padding_low, padding_high);
    int64 dim = backward_conv_dnums.input_spatial_dimensions(i);
    input_padding_config.mutable_dimensions(dim)->set_edge_padding_low(
        padding_low - new_conv_padding);
    input_padding_config.mutable_dimensions(dim)->set_edge_padding_high(
        padding_high - new_conv_padding);

    // Since we move some padding from the backward convolution to the kPad, we
    // need to accordingly reduce the padding amount of the backward convolution
    // and its inner forward convolution.
    auto* new_dim = new_backward_conv_window.mutable_dimensions(i);
    new_dim->set_padding_low(new_conv_padding);
    new_dim->set_padding_high(new_conv_padding);
  }

  // Create a new backward convolution replacing the old one.
  HloComputation* computation = backward_conv->parent();
  HloInstruction* output = backward_conv->mutable_operand(1);
  HloInstruction* padding =
      computation->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::Zero(input->shape().element_type())));
  HloInstruction* padded_input =
      MakePadHlo(input, padding, input_padding_config).ValueOrDie();

  // The backward_conv CustomCall returns a tuple (conv_result, scratch_buffer);
  // cloning with backward_conv->shape() below preserves that tuple shape.
  HloInstruction* new_backward_conv =
      computation->AddInstruction(backward_conv->CloneWithNewOperands(
          backward_conv->shape(), {padded_input, output}));
  new_backward_conv->set_window(new_backward_conv_window);

  VLOG(1) << "Canonicalizing backward filter conv";
  VLOG(1) << "Replacing:\n  " << backward_conv->ToString() << "\nwith:\n  "
          << new_backward_conv->ToString();

  TF_CHECK_OK(
      computation->ReplaceInstruction(backward_conv, new_backward_conv));
  return true;
}

bool GpuConvPaddingLegalization::CanonicalizeBackwardInputConvolution(
    HloInstruction* backward_conv) {
  if (window_util::HasSymmetricPadding(backward_conv->window())) {
    return false;
  }

  Window new_backward_conv_window = backward_conv->window();
  ConvolutionDimensionNumbers backward_conv_dnums =
      backward_conv->convolution_dimension_numbers();

  // The backward_conv CustomCall returns a tuple (conv_result, scratch_memory).
  // Get the shape of conv_result.
  Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);

  Shape new_backward_conv_shape = backward_conv_shape;
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
    int64 padding_low = backward_conv->window().dimensions(i).padding_low();
    int64 padding_high = backward_conv->window().dimensions(i).padding_high();
    if (padding_low < 0 || padding_high < 0) {
      // TODO(b/32744257): The following canonicalization wouldn't remove
      // negative padding in a backward convolution, and would therefore cause
      // cuDNN convolution (which doesn't support negative padding) to fail.
      return false;
    }
    // If the backward convolution has uneven padding on the activations, we
    // move some padding on the larger end to "internal" padding, so that the
    // backward convolution produces larger activations which get sliced later.
    //
    // For example, suppose we have a non-canonical HLO
    //   [A] = BackwardInputConvolve([a b], [x y z], padding=(low=2,high=1))
    // where the amount of low padding is larger. We can canonicalize it to
    //   [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1))
    //   [A] = Slice([B A])
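    // In the example above, IncreasePaddingLowBy is called with delta =
    // 1 - 2 = -1, making the padding symmetric at 1; the conv result grows by
    // |2 - 1| = 1 element ([B A]), and the slice built below removes the extra
    // element from the low end.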
    if (padding_low > padding_high) {
      IncreasePaddingLowBy(padding_high - padding_low,
                           new_backward_conv_window.mutable_dimensions(i));
    } else if (padding_low < padding_high) {
      IncreasePaddingHighBy(padding_low - padding_high,
                            new_backward_conv_window.mutable_dimensions(i));
    }
    // Decreasing the padding by X *increases* the size of our output by X.
    // Note that we have swapped input spatial dimensions with output spatial
    // dimensions to be compatible with the cuDNN API, so
    // input_spatial_dimensions(i) gives the i-th spatial dimension of the
    // output.
    int64 dim = backward_conv_dnums.input_spatial_dimensions(i);
    new_backward_conv_shape.set_dimensions(
        dim, new_backward_conv_shape.dimensions(dim) +
                 std::abs(padding_low - padding_high));
  }

  // Create a new backward convolution replacing the old one.
  HloComputation* computation = backward_conv->parent();
  HloInstruction* output = backward_conv->mutable_operand(0);
  HloInstruction* filter = backward_conv->mutable_operand(1);

  HloInstruction* new_backward_conv_call =
      computation->AddInstruction(backward_conv->CloneWithNewOperands(
          ShapeUtil::MakeTupleShape(
              {new_backward_conv_shape, ShapeUtil::MakeShape(U8, {0})}),
          {output, filter}));
  new_backward_conv_call->set_window(new_backward_conv_window);

  // The CustomCall created above returns a tuple (conv_result, scratch_memory).
  // Extract out the two elements.
  HloInstruction* new_backward_conv =
      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
          new_backward_conv_shape, new_backward_conv_call, 0));
  HloInstruction* new_backward_conv_scratch =
      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
          new_backward_conv_call->shape().tuple_shapes(1),
          new_backward_conv_call, 1));

  // Slice the new backward convolution.
  //
  // Initialize start_indices and limit_indices to cover the full output, i.e.,
  // no slicing.
  std::vector<int64> start_indices(new_backward_conv->shape().dimensions_size(),
                                   0LL);
  std::vector<int64> limit_indices(
      new_backward_conv->shape().dimensions().begin(),
      new_backward_conv->shape().dimensions().end());
  std::vector<int64> strides(new_backward_conv->shape().dimensions_size(), 1LL);
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
    int64 padding_low = backward_conv->window().dimensions(i).padding_low();
    int64 padding_high = backward_conv->window().dimensions(i).padding_high();
    // Note that we have swapped input spatial dimensions with output spatial
    // dimensions to be compatible with the cuDNN API, so
    // input_spatial_dimensions(i) gives the i-th spatial dimension of the
    // output.
    int64 dim = backward_conv_dnums.input_spatial_dimensions(i);
    if (padding_low > padding_high) {
      // If the amount of low padding (of the old backward convolution) is
      // larger, we internally pad the low end of the activations and slice
      // internal padding out here.
      start_indices[dim] += padding_low - padding_high;
    } else if (padding_low < padding_high) {
      // If the amount of high padding is larger, we slice out the internal
      // padding on the high end.
      limit_indices[dim] -= padding_high - padding_low;
    }
  }

  // Replace the old backward convolution with the slice.
  Shape slice_shape =
      ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices,
                                      limit_indices, strides)
          .ConsumeValueOrDie();
  CHECK(ShapeUtil::Compatible(slice_shape, backward_conv_shape))
      << ShapeUtil::HumanString(slice_shape) << " vs "
      << ShapeUtil::HumanString(backward_conv_shape);

  HloInstruction* slice = computation->AddInstruction(
      HloInstruction::CreateSlice(backward_conv_shape, new_backward_conv,
                                  start_indices, limit_indices, strides));
  HloInstruction* new_tuple = computation->AddInstruction(
      HloInstruction::CreateTuple({slice, new_backward_conv_scratch}));

  VLOG(1) << "Canonicalizing backward input conv";
  VLOG(1) << "Replacing:\n  " << backward_conv->ToString() << "\nwith:\n  "
          << new_tuple->ToString();

  TF_CHECK_OK(computation->ReplaceInstruction(backward_conv, new_tuple));
  return true;
}

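// Collects all cuDNN convolution custom-calls in the computation, then
// canonicalizes each one according to its convolution kind.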
StatusOr<bool> GpuConvPaddingLegalization::RunOnComputation(
    HloComputation* computation) {
  bool changed = false;
  std::vector<HloCustomCallInstruction*> convs;
  for (auto* instr : computation->instructions()) {
    if (IsCustomCallToDnnConvolution(*instr)) {
      convs.push_back(Cast<HloCustomCallInstruction>(instr));
    }
  }
  for (HloCustomCallInstruction* instruction : convs) {
    TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instruction));
    changed |= [&] {
      switch (kind) {
        case CudnnConvKind::kForward:
        case CudnnConvKind::kForwardActivation:
          return CanonicalizeForwardConvolution(instruction);
        case CudnnConvKind::kBackwardInput:
          return CanonicalizeBackwardInputConvolution(instruction);
        case CudnnConvKind::kBackwardFilter:
          return CanonicalizeBackwardFilterConvolution(instruction);
      }
    }();
  }
  return changed;
}

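// Runs the pass over every non-fusion computation in the module; returns true
// if any convolution was rewritten.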
StatusOr<bool> GpuConvPaddingLegalization::Run(HloModule* module) {
  bool changed = false;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
    changed |= result;
  }
  return changed;
}

}  // namespace gpu
}  // namespace xla