/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.h"

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace gpu {

namespace {
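// Returns true if the given cuDNN forward convolution is already in the form
// the cuDNN libcall accepts directly: symmetric, non-negative padding and no
// (base or window) dilation. Convolutions failing this test are rewritten
// below.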
bool IsForwardConvolutionCanonical(const HloInstruction& conv) {
  CHECK(conv.custom_call_target() == kCudnnConvForwardCallTarget ||
        conv.custom_call_target() ==
            kCudnnConvBiasActivationForwardCallTarget);
  return window_util::HasSymmetricPadding(conv.window()) &&
         !window_util::HasNegativePadding(conv.window()) &&
         !window_util::HasDilation(conv.window());
}

// If the (positive and negative) padding on the input operand of a convolution
// can't be folded into a cuDNN convolution libcall (e.g. uneven padding or
// dilation), returns kPad and/or kSlice instructions that explicitly apply the
// padding; otherwise returns the original input operand. When there is both
// positive padding (including dilation) and negative padding, we insert both
// kPad and kSlice. Modifies 'conv_window' accordingly if any padding was moved
// into a kPad or kSlice op.
HloInstruction* MaybePaddedAndSlicedInput(
    Window* conv_window, const ConvolutionDimensionNumbers& conv_dnums,
    HloInstruction* input) {
  HloComputation* computation = input->parent();
  if (!window_util::HasSymmetricPadding(*conv_window) ||
      window_util::HasBaseDilation(*conv_window)) {
    // If padding is uneven or has dilation, we insert a kPad instruction that
    // applies positive padding and dilation.
    //
    // TODO(phawkins): If conv_window has asymmetric padding, perhaps instead of
    // moving all the padding into an explicit pad op, we should keep as much
    // padding inside of cudnn as possible, on the assumption that padding
    // within cudnn is basically free, whereas a kPad's cost increases as the
    // amount of padding increases.
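    // For example, asymmetric padding (low=1, high=2) on a spatial dimension
    // is moved entirely into the kPad built below, leaving the convolution
    // itself with (low=0, high=0) on that dimension.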
    PaddingConfig padding_config =
        MakeNoPaddingConfig(input->shape().dimensions_size());
    for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
      int64 dim = conv_dnums.input_spatial_dimensions(i);
      if (conv_window->dimensions(i).padding_low() > 0) {
        padding_config.mutable_dimensions(dim)->set_edge_padding_low(
            conv_window->dimensions(i).padding_low());
        conv_window->mutable_dimensions(i)->set_padding_low(0);
      }
      if (conv_window->dimensions(i).padding_high() > 0) {
        padding_config.mutable_dimensions(dim)->set_edge_padding_high(
            conv_window->dimensions(i).padding_high());
        conv_window->mutable_dimensions(i)->set_padding_high(0);
      }
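      // Base dilation d is equivalent to inserting (d - 1) zeros between
      // adjacent input elements; e.g. with d = 2, [a b c] becomes
      // [a 0 b 0 c], which is exactly interior padding of 1 in the kPad.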
      if (conv_window->dimensions(i).base_dilation() != 1) {
        padding_config.mutable_dimensions(dim)->set_interior_padding(
            conv_window->dimensions(i).base_dilation() - 1);
        conv_window->mutable_dimensions(i)->set_base_dilation(1);
      }
    }
    PrimitiveType element_type = input->shape().element_type();
    HloInstruction* padding = computation->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
    input = MakePadHlo(input, padding, padding_config).ValueOrDie();
  }

  if (window_util::HasNegativePadding(*conv_window)) {
    // If the window has negative padding, insert a kSlice that explicitly
    // applies negative padding.
    //
    // For each dimension, initialize the start index to 0 and the limit index
    // to the size of that dimension.
    std::vector<int64> start_indices(input->shape().dimensions_size(), 0);
    std::vector<int64> limit_indices(input->shape().dimensions().begin(),
                                     input->shape().dimensions().end());
    std::vector<int64> strides(input->shape().dimensions_size(), 1);
    for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
      int64 dim = conv_dnums.input_spatial_dimensions(i);
      // If dimension "dim" has negative padding, increase the start index or
      // decrease the limit index by the amount of negative padding.
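      // E.g. padding_low = -2 on a dimension of size 10 yields a slice with
      // start index 2; padding_high = -1 lowers the limit index to 9.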
      if (conv_window->dimensions(i).padding_low() < 0) {
        start_indices[dim] += -conv_window->dimensions(i).padding_low();
        conv_window->mutable_dimensions(i)->set_padding_low(0);
      }
      if (conv_window->dimensions(i).padding_high() < 0) {
        limit_indices[dim] -= -conv_window->dimensions(i).padding_high();
        conv_window->mutable_dimensions(i)->set_padding_high(0);
      }
    }

    input =
        MakeSliceHlo(input, start_indices, limit_indices, strides).ValueOrDie();
  }

  return input;
}

// If the padding on the kernel operand of a convolution can't be folded into a
// cuDNN convolution libcall (e.g. dilation), returns a kPad instruction that
// explicitly applies the padding; otherwise returns the original kernel
// operand.
HloInstruction* MaybePaddedKernel(const Window& conv_window,
                                  const ConvolutionDimensionNumbers& conv_dnums,
                                  HloInstruction* kernel) {
  if (!window_util::HasWindowDilation(conv_window)) {
    return kernel;
  }

  // Compute the shape and padding config of the pad to be inserted.
  PaddingConfig padding_config;
  for (size_t i = 0; i < kernel->shape().dimensions_size(); ++i) {
    padding_config.add_dimensions();
  }
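  // Only the spatial dimensions receive interior padding; the other kernel
  // dimensions keep the zero-initialized config added above. Window dilation
  // d becomes interior padding (d - 1), e.g. a kernel [x y z] with d = 2 is
  // padded to [x 0 y 0 z], after which the convolution can run undilated.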
  for (size_t i = 0; i < conv_dnums.kernel_spatial_dimensions().size(); ++i) {
    int64 dim = conv_dnums.kernel_spatial_dimensions(i);
    padding_config.mutable_dimensions(dim)->set_interior_padding(
        conv_window.dimensions(i).window_dilation() - 1);
  }

  HloComputation* computation = kernel->parent();
  PrimitiveType element_type = kernel->shape().element_type();
  HloInstruction* padding = computation->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
  return MakePadHlo(kernel, padding, padding_config).ValueOrDie();
}
}  // namespace

bool GpuConvPaddingLegalization::CanonicalizeForwardConvolution(
    HloInstruction* conv) {
  if (IsForwardConvolutionCanonical(*conv)) {
    return false;
  }

  // Insert slices and/or pads between the convolution and its input and/or
  // kernel operand.
  Window new_conv_window = conv->window();
  HloInstruction* new_input = MaybePaddedAndSlicedInput(
      &new_conv_window, conv->convolution_dimension_numbers(),
      conv->mutable_operand(0));
  HloInstruction* new_kernel =
      MaybePaddedKernel(new_conv_window, conv->convolution_dimension_numbers(),
                        conv->mutable_operand(1));

  // Remove the window dilation from the convolution's window field. The
  // dilation is made explicit by the kPad inserted by MaybePaddedKernel().
  for (size_t i = 0; i < new_conv_window.dimensions_size(); ++i) {
    WindowDimension* dim = new_conv_window.mutable_dimensions(i);

    // The size of the kernel may have changed, so update the Window to match.
    dim->set_size(new_kernel->shape().dimensions(
        conv->convolution_dimension_numbers().kernel_spatial_dimensions(i)));
    dim->set_window_dilation(1);
  }

  // The conv CustomCall returns a tuple (conv_result, scratch_buffer); cloning
  // with conv->shape() below preserves that tuple shape.
  VLOG(1) << "Canonicalizing forward conv";
  std::vector<HloInstruction*> operands(conv->operands().begin(),
                                        conv->operands().end());
  operands[0] = new_input;
  operands[1] = new_kernel;
  auto new_conv = conv->parent()->AddInstruction(
      conv->CloneWithNewOperands(conv->shape(), operands));
  new_conv->set_window(new_conv_window);
  VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n "
          << new_conv->ToString();
  TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
  return true;
}

namespace {
void IncreasePaddingLowBy(int64 delta, WindowDimension* window_dim) {
  window_dim->set_padding_low(window_dim->padding_low() + delta);
}

void IncreasePaddingHighBy(int64 delta, WindowDimension* window_dim) {
  window_dim->set_padding_high(window_dim->padding_high() + delta);
}
}  // namespace

bool GpuConvPaddingLegalization::CanonicalizeBackwardFilterConvolution(
    HloInstruction* backward_conv) {
  CHECK_EQ(backward_conv->custom_call_target(),
           kCudnnConvBackwardFilterCallTarget);
  if (window_util::HasSymmetricPadding(backward_conv->window())) {
    return false;
  }

  // A backward filter convolution with uneven padding can be canonicalized to
  // one with even padding by padding the activations (input) beforehand. For
  // example,
  //   BackwardFilterConv(ABCD, xyz, padding_low=1, padding_high=2)
  // is equivalent to
  //   ABCD0 = Pad(ABCD, padding_high=1)
  //   BackwardFilterConv(ABCD0, xyz, padding_low=padding_high=1)
  // We choose the lesser of padding_low and padding_high as the new padding.
  HloInstruction* input = backward_conv->mutable_operand(0);
  Window new_backward_conv_window = backward_conv->window();
  // input_padding_config is the config of the kPad to be inserted.
  PaddingConfig input_padding_config =
      MakeNoPaddingConfig(input->shape().rank());
  ConvolutionDimensionNumbers backward_conv_dnums =
      backward_conv->convolution_dimension_numbers();
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
    int64 padding_low = backward_conv->window().dimensions(i).padding_low();
    int64 padding_high = backward_conv->window().dimensions(i).padding_high();
    if (padding_low < 0 || padding_high < 0) {
      // TODO(b/32744257): The following canonicalization wouldn't remove
      // negative padding in a backward convolution, and would therefore cause
      // cuDNN convolution (which doesn't support negative padding) to fail.
      return false;
    }
    // Compute the new, even padding for the backward conv operation.
    int64 new_conv_padding = std::min(padding_low, padding_high);
    int64 dim = backward_conv_dnums.input_spatial_dimensions(i);
    input_padding_config.mutable_dimensions(dim)->set_edge_padding_low(
        padding_low - new_conv_padding);
    input_padding_config.mutable_dimensions(dim)->set_edge_padding_high(
        padding_high - new_conv_padding);
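    // With the example above (padding_low=1, padding_high=2):
    // new_conv_padding = 1, so the kPad gets edge_padding_low = 0 and
    // edge_padding_high = 1, matching Pad(ABCD, padding_high=1).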

    // Since we move some padding from the backward convolution into the kPad,
    // we need to reduce the padding amount of the backward convolution and of
    // its inner forward convolution accordingly.
    auto* new_dim = new_backward_conv_window.mutable_dimensions(i);
    new_dim->set_padding_low(new_conv_padding);
    new_dim->set_padding_high(new_conv_padding);
  }

  // Create a new backward convolution replacing the old one.
  HloComputation* computation = backward_conv->parent();
  HloInstruction* output = backward_conv->mutable_operand(1);
  HloInstruction* padding =
      computation->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::Zero(input->shape().element_type())));
  HloInstruction* padded_input =
      MakePadHlo(input, padding, input_padding_config).ValueOrDie();

  // The backward_conv CustomCall returns a tuple (conv_result, scratch_buffer);
  // cloning with backward_conv->shape() below preserves that tuple shape.
  HloInstruction* new_backward_conv =
      computation->AddInstruction(backward_conv->CloneWithNewOperands(
          backward_conv->shape(), {padded_input, output}));
  new_backward_conv->set_window(new_backward_conv_window);

  VLOG(1) << "Canonicalizing backward filter conv";
  VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n "
          << new_backward_conv->ToString();

  TF_CHECK_OK(
      computation->ReplaceInstruction(backward_conv, new_backward_conv));
  return true;
}

bool GpuConvPaddingLegalization::CanonicalizeBackwardInputConvolution(
    HloInstruction* backward_conv) {
  if (window_util::HasSymmetricPadding(backward_conv->window())) {
    return false;
  }

  Window new_backward_conv_window = backward_conv->window();
  ConvolutionDimensionNumbers backward_conv_dnums =
      backward_conv->convolution_dimension_numbers();

  // The backward_conv CustomCall returns a tuple (conv_result, scratch_memory).
  // Get the shape of conv_result.
  Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);

  Shape new_backward_conv_shape = backward_conv_shape;
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
    int64 padding_low = backward_conv->window().dimensions(i).padding_low();
    int64 padding_high = backward_conv->window().dimensions(i).padding_high();
    if (padding_low < 0 || padding_high < 0) {
      // TODO(b/32744257): The following canonicalization wouldn't remove
      // negative padding in a backward convolution, and would therefore cause
      // cuDNN convolution (which doesn't support negative padding) to fail.
      return false;
    }
    // If the backward convolution has uneven padding on the activations, we
    // move some padding on the larger end to "internal" padding, so that the
    // backward convolution produces larger activations which get sliced later.
    //
    // For example, suppose we have a non-canonical HLO
    //   [A] = BackwardInputConvolve([a b], [x y z], padding=(low=2,high=1))
    // where the amount of low padding is larger. We can canonicalize it to
    //   [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1))
    //   [A] = Slice([B A])
    if (padding_low > padding_high) {
      IncreasePaddingLowBy(padding_high - padding_low,
                           new_backward_conv_window.mutable_dimensions(i));
    } else if (padding_low < padding_high) {
      IncreasePaddingHighBy(padding_low - padding_high,
                            new_backward_conv_window.mutable_dimensions(i));
    }
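    // Note that the deltas passed above are negative: the larger side of the
    // padding is lowered to match the smaller side.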
    // Decreasing the padding by X *increases* the size of our output by X.
    // Note that we have swapped input spatial dimensions with output spatial
    // dimensions to be compatible with the cuDNN API, so
    // input_spatial_dimensions(i) gives the i-th spatial dimension of the
    // output.
    int64 dim = backward_conv_dnums.input_spatial_dimensions(i);
    new_backward_conv_shape.set_dimensions(
        dim, new_backward_conv_shape.dimensions(dim) +
                 std::abs(padding_low - padding_high));
  }

  // Create a new backward convolution replacing the old one.
  HloComputation* computation = backward_conv->parent();
  HloInstruction* output = backward_conv->mutable_operand(0);
  HloInstruction* filter = backward_conv->mutable_operand(1);

  HloInstruction* new_backward_conv_call =
      computation->AddInstruction(backward_conv->CloneWithNewOperands(
          ShapeUtil::MakeTupleShape(
              {new_backward_conv_shape, ShapeUtil::MakeShape(U8, {0})}),
          {output, filter}));
  new_backward_conv_call->set_window(new_backward_conv_window);

  // The CustomCall created above returns a tuple (conv_result, scratch_memory).
  // Extract out the two elements.
  HloInstruction* new_backward_conv =
      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
          new_backward_conv_shape, new_backward_conv_call, 0));
  HloInstruction* new_backward_conv_scratch =
      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
          new_backward_conv_call->shape().tuple_shapes(1),
          new_backward_conv_call, 1));

  // Slice the new backward convolution.
  //
  // Initialize start_indices and limit_indices as no slicing.
  std::vector<int64> start_indices(new_backward_conv->shape().dimensions_size(),
                                   0LL);
  std::vector<int64> limit_indices(
      new_backward_conv->shape().dimensions().begin(),
      new_backward_conv->shape().dimensions().end());
  std::vector<int64> strides(new_backward_conv->shape().dimensions_size(), 1LL);
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
    int64 padding_low = backward_conv->window().dimensions(i).padding_low();
    int64 padding_high = backward_conv->window().dimensions(i).padding_high();
    // Note that we have swapped input spatial dimensions with output spatial
    // dimensions to be compatible with the cuDNN API, so
    // input_spatial_dimensions(i) gives the i-th spatial dimension of the
    // output.
    int64 dim = backward_conv_dnums.input_spatial_dimensions(i);
    if (padding_low > padding_high) {
      // If the amount of low padding (of the old backward convolution) is
      // larger, we internally pad the low end of the activations and slice
      // internal padding out here.
      start_indices[dim] += padding_low - padding_high;
    } else if (padding_low < padding_high) {
      // If the amount of high padding is larger, we slice out the internal
      // padding on the high end.
      limit_indices[dim] -= padding_high - padding_low;
    }
  }
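  // Each uneven dimension was widened by std::abs(padding_low - padding_high)
  // above, and the slice below removes exactly that many elements; the CHECK
  // that follows verifies that the sliced shape matches the original
  // conv_result shape.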

  // Replace the old backward convolution with the slice.
  Shape slice_shape =
      ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices,
                                      limit_indices, strides)
          .ConsumeValueOrDie();
  CHECK(ShapeUtil::Compatible(slice_shape, backward_conv_shape))
      << ShapeUtil::HumanString(slice_shape) << " vs "
      << ShapeUtil::HumanString(backward_conv_shape);

  HloInstruction* slice = computation->AddInstruction(
      HloInstruction::CreateSlice(backward_conv_shape, new_backward_conv,
                                  start_indices, limit_indices, strides));
  HloInstruction* new_tuple = computation->AddInstruction(
      HloInstruction::CreateTuple({slice, new_backward_conv_scratch}));

  VLOG(1) << "Canonicalizing backward input conv";
  VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n "
          << new_tuple->ToString();

  TF_CHECK_OK(computation->ReplaceInstruction(backward_conv, new_tuple));
  return true;
}

StatusOr<bool> GpuConvPaddingLegalization::RunOnComputation(
    HloComputation* computation) {
  bool changed = false;
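  // Collect the candidate convolutions up front: the canonicalization calls
  // below add and remove instructions, which would invalidate iteration over
  // computation->instructions() while it is being mutated.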
  std::vector<HloCustomCallInstruction*> convs;
  for (auto* instr : computation->instructions()) {
    if (IsCustomCallToDnnConvolution(*instr)) {
      convs.push_back(Cast<HloCustomCallInstruction>(instr));
    }
  }
  for (HloCustomCallInstruction* instruction : convs) {
    TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instruction));
    changed |= [&] {
      switch (kind) {
        case CudnnConvKind::kForward:
        case CudnnConvKind::kForwardActivation:
          return CanonicalizeForwardConvolution(instruction);
        case CudnnConvKind::kBackwardInput:
          return CanonicalizeBackwardInputConvolution(instruction);
        case CudnnConvKind::kBackwardFilter:
          return CanonicalizeBackwardFilterConvolution(instruction);
      }
    }();
  }
  return changed;
}

StatusOr<bool> GpuConvPaddingLegalization::Run(HloModule* module) {
  bool changed = false;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
    changed |= result;
  }
  return changed;
}

}  // namespace gpu
}  // namespace xla