/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.h"

#include <memory>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace gpu {

namespace {
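// Returns true if the forward convolution's window can be passed to cuDNN
// as-is: the padding must be symmetric and non-negative, and the window must
// have no base or window dilation.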
bool IsForwardConvolutionCanonical(const HloInstruction& conv) {
  CHECK(conv.custom_call_target() == kCudnnConvForwardCallTarget ||
        conv.custom_call_target() == kCudnnConvBiasActivationForwardCallTarget);
  return window_util::HasSymmetricPadding(conv.window()) &&
         !window_util::HasNegativePadding(conv.window()) &&
         !window_util::HasDilation(conv.window());
}

// If the (positive and negative) padding on the input operand of a convolution
// can't be folded into a cuDNN convolution libcall (e.g. uneven padding and
// dilation), returns kPad and/or kSlice instructions that explicitly apply the
// padding; otherwise returns the original input operand. When there is both
// positive padding (including dilation) and negative padding, we insert both
// kPad and kSlice. Modifies 'conv_window' accordingly if any padding was moved
// into a kPad or kSlice op.
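//
// For illustration (one spatial dimension, hypothetical values): positive
// padding_low=2, padding_high=1, and base_dilation=2 on an input [a b c]
// becomes
//   padded = Pad([a b c], low=2, high=1, interior=1)   // = [0 0 a 0 b 0 c 0]
// and the convolution window keeps padding_low=padding_high=0 and
// base_dilation=1. Negative padding is instead applied by slicing the
// (possibly padded) input.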
HloInstruction* MaybePaddedAndSlicedInput(
    Window* conv_window, const ConvolutionDimensionNumbers& conv_dnums,
    HloInstruction* input) {
  HloComputation* computation = input->parent();
  if (!window_util::HasSymmetricPadding(*conv_window) ||
      window_util::HasBaseDilation(*conv_window)) {
    // If padding is uneven or has dilation, we insert a kPad instruction that
    // applies positive padding and dilation.
    //
    // TODO(phawkins): If conv_window has asymmetric padding, perhaps instead of
    // moving all the padding into an explicit pad op, we should keep as much
    // padding inside of cudnn as possible, on the assumption that padding
    // within cudnn is basically free, whereas a kPad's cost increases as the
    // amount of padding increases.
    PaddingConfig padding_config =
        MakeNoPaddingConfig(input->shape().dimensions_size());
    for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
      int64_t dim = conv_dnums.input_spatial_dimensions(i);
      if (conv_window->dimensions(i).padding_low() > 0) {
        padding_config.mutable_dimensions(dim)->set_edge_padding_low(
            conv_window->dimensions(i).padding_low());
        conv_window->mutable_dimensions(i)->set_padding_low(0);
      }
      if (conv_window->dimensions(i).padding_high() > 0) {
        padding_config.mutable_dimensions(dim)->set_edge_padding_high(
            conv_window->dimensions(i).padding_high());
        conv_window->mutable_dimensions(i)->set_padding_high(0);
      }
      if (conv_window->dimensions(i).base_dilation() != 1) {
        padding_config.mutable_dimensions(dim)->set_interior_padding(
            conv_window->dimensions(i).base_dilation() - 1);
        conv_window->mutable_dimensions(i)->set_base_dilation(1);
      }
    }
    PrimitiveType element_type = input->shape().element_type();
    HloInstruction* padding = computation->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
    input = MakePadHlo(input, padding, padding_config, &input->metadata())
                .ValueOrDie();
  }

  if (window_util::HasNegativePadding(*conv_window)) {
    // If the window has negative padding, insert a kSlice that explicitly
    // applies negative padding.
    //
    // For each dimension, initialize the start index to 0 and the limit index
    // to the size of that dimension.
    std::vector<int64_t> start_indices(input->shape().dimensions_size(), 0);
    std::vector<int64_t> limit_indices(input->shape().dimensions().begin(),
                                       input->shape().dimensions().end());
    std::vector<int64_t> strides(input->shape().dimensions_size(), 1);
    for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) {
      int64_t dim = conv_dnums.input_spatial_dimensions(i);
      // If dimension "dim" has negative padding, increase the start index or
      // decrease the limit index by the amount of negative padding.
      if (conv_window->dimensions(i).padding_low() < 0) {
        start_indices[dim] += -conv_window->dimensions(i).padding_low();
        conv_window->mutable_dimensions(i)->set_padding_low(0);
      }
      if (conv_window->dimensions(i).padding_high() < 0) {
        limit_indices[dim] -= -conv_window->dimensions(i).padding_high();
        conv_window->mutable_dimensions(i)->set_padding_high(0);
      }
    }

    input =
        MakeSliceHlo(input, start_indices, limit_indices, strides).ValueOrDie();
  }

  return input;
}

// If the padding on the kernel operand of a convolution can't be folded into a
// cuDNN convolution libcall (e.g. dilation), returns a kPad instruction that
// explicitly applies the padding; otherwise returns the original kernel
// operand.
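//
// For illustration (one spatial dimension, hypothetical values): a kernel
// [x y z] with window_dilation=2 is materialized as
//   Pad([x y z], interior=1)   // = [x 0 y 0 z]
// after which the convolution's window uses window_dilation=1.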
HloInstruction* MaybePaddedKernel(const Window& conv_window,
                                  const ConvolutionDimensionNumbers& conv_dnums,
                                  HloInstruction* kernel) {
  if (!window_util::HasWindowDilation(conv_window)) {
    return kernel;
  }

  // Compute the shape and padding config of the pad to be inserted.
  PaddingConfig padding_config;
  for (size_t i = 0; i < kernel->shape().dimensions_size(); ++i) {
    padding_config.add_dimensions();
  }
  for (size_t i = 0; i < conv_dnums.kernel_spatial_dimensions().size(); ++i) {
    int64_t dim = conv_dnums.kernel_spatial_dimensions(i);
    padding_config.mutable_dimensions(dim)->set_interior_padding(
        conv_window.dimensions(i).window_dilation() - 1);
  }

  HloComputation* computation = kernel->parent();
  PrimitiveType element_type = kernel->shape().element_type();
  HloInstruction* padding = computation->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
  return MakePadHlo(kernel, padding, padding_config, &kernel->metadata())
      .ValueOrDie();
}
}  // namespace

bool GpuConvPaddingLegalization::CanonicalizeForwardConvolution(
    HloInstruction* conv) {
  if (IsForwardConvolutionCanonical(*conv)) {
    return false;
  }

  // Insert slices and/or pads between the convolution and its input and/or
  // kernel operand.
  Window new_conv_window = conv->window();
  HloInstruction* new_input = MaybePaddedAndSlicedInput(
      &new_conv_window, conv->convolution_dimension_numbers(),
      conv->mutable_operand(0));
  HloInstruction* new_kernel =
      MaybePaddedKernel(new_conv_window, conv->convolution_dimension_numbers(),
                        conv->mutable_operand(1));

  // Remove the window dilation from the convolution's window field. The
  // dilation has been made explicit by the interior padding inserted by
  // MaybePaddedKernel().
  for (size_t i = 0; i < new_conv_window.dimensions_size(); ++i) {
    WindowDimension* dim = new_conv_window.mutable_dimensions(i);

    // The size of the kernel may have changed, so update the Window to match.
    dim->set_size(new_kernel->shape().dimensions(
        conv->convolution_dimension_numbers().kernel_spatial_dimensions(i)));
    dim->set_window_dilation(1);
  }

  // The conv CustomCall returns a tuple (conv_result, scratch_buffer); the
  // clone below keeps the same tuple shape.
  VLOG(1) << "Canonicalizing forward conv";
  std::vector<HloInstruction*> operands(conv->operands().begin(),
                                        conv->operands().end());
  operands[0] = new_input;
  operands[1] = new_kernel;
  auto new_conv = conv->parent()->AddInstruction(
      conv->CloneWithNewOperands(conv->shape(), operands));
  new_conv->set_window(new_conv_window);
  VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n "
          << new_conv->ToString();
  TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
  return true;
}

namespace {
void IncreasePaddingLowBy(int64_t delta, WindowDimension* window_dim) {
  window_dim->set_padding_low(window_dim->padding_low() + delta);
}

void IncreasePaddingHighBy(int64_t delta, WindowDimension* window_dim) {
  window_dim->set_padding_high(window_dim->padding_high() + delta);
}
}  // namespace

bool GpuConvPaddingLegalization::CanonicalizeBackwardFilterConvolution(
    HloInstruction* backward_conv) {
  CHECK_EQ(backward_conv->custom_call_target(),
           kCudnnConvBackwardFilterCallTarget);
  if (window_util::HasSymmetricPadding(backward_conv->window())) {
    return false;
  }

  // A backward filter convolution with uneven padding can be canonicalized to
  // one with even padding by padding the activations (input) beforehand. For
  // example,
  //   BackwardFilterConv(ABCD, xyz, padding_low=1, padding_high=2)
  // is equivalent to
  //   ABCD0 = Pad(ABCD, padding_high=1)
  //   BackwardFilterConv(ABCD0, xyz, padding_low=padding_high=1)
  // We choose the lesser of padding_low and padding_high as the new padding.
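  // In the example above, min(1, 2) = 1 becomes the new symmetric conv padding
  // and the kPad absorbs the remaining 2 - 1 = 1 units of high padding.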
  HloInstruction* input = backward_conv->mutable_operand(0);
  Window new_backward_conv_window = backward_conv->window();
  // input_padding_config is the config of the kPad to be inserted.
  PaddingConfig input_padding_config =
      MakeNoPaddingConfig(input->shape().rank());
  ConvolutionDimensionNumbers backward_conv_dnums =
      backward_conv->convolution_dimension_numbers();
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
    int64_t padding_low = backward_conv->window().dimensions(i).padding_low();
    int64_t padding_high = backward_conv->window().dimensions(i).padding_high();
    if (padding_low < 0 || padding_high < 0) {
      // TODO(b/32744257): The following canonicalization wouldn't remove
      // negative padding in a backward convolution, and would therefore cause
      // cuDNN convolution (which doesn't support negative padding) to fail.
      return false;
    }
    // Compute the new, even padding for the backward conv operation.
    int64_t new_conv_padding = std::min(padding_low, padding_high);
    int64_t dim = backward_conv_dnums.input_spatial_dimensions(i);
    input_padding_config.mutable_dimensions(dim)->set_edge_padding_low(
        padding_low - new_conv_padding);
    input_padding_config.mutable_dimensions(dim)->set_edge_padding_high(
        padding_high - new_conv_padding);

    // Since we move some padding from the backward convolution to the kPad, we
    // need to accordingly reduce the padding amount of the backward convolution
    // and its inner forward convolution.
    auto* new_dim = new_backward_conv_window.mutable_dimensions(i);
    new_dim->set_padding_low(new_conv_padding);
    new_dim->set_padding_high(new_conv_padding);
  }

  // Create a new backward convolution replacing the old one.
  HloComputation* computation = backward_conv->parent();
  HloInstruction* output = backward_conv->mutable_operand(1);
  HloInstruction* padding =
      computation->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::Zero(input->shape().element_type())));
  HloInstruction* padded_input =
      MakePadHlo(input, padding, input_padding_config).ValueOrDie();

  // The backward_conv CustomCall returns a tuple (conv_result, scratch_buffer);
  // cloning with backward_conv->shape() keeps that same tuple shape.
  HloInstruction* new_backward_conv =
      computation->AddInstruction(backward_conv->CloneWithNewOperands(
          backward_conv->shape(), {padded_input, output}));
  new_backward_conv->set_window(new_backward_conv_window);

  VLOG(1) << "Canonicalizing backward filter conv";
  VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n "
          << new_backward_conv->ToString();

  TF_CHECK_OK(
      computation->ReplaceInstruction(backward_conv, new_backward_conv));
  return true;
}

bool GpuConvPaddingLegalization::CanonicalizeBackwardInputConvolution(
    HloInstruction* backward_conv) {
  if (window_util::HasSymmetricPadding(backward_conv->window())) {
    return false;
  }

  Window new_backward_conv_window = backward_conv->window();
  ConvolutionDimensionNumbers backward_conv_dnums =
      backward_conv->convolution_dimension_numbers();

  // The backward_conv CustomCall returns a tuple (conv_result, scratch_memory).
  // Get the shape of conv_result.
  Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);

  Shape new_backward_conv_shape = backward_conv_shape;
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
    int64_t padding_low = backward_conv->window().dimensions(i).padding_low();
    int64_t padding_high = backward_conv->window().dimensions(i).padding_high();
    if (padding_low < 0 || padding_high < 0) {
      // TODO(b/32744257): The following canonicalization wouldn't remove
      // negative padding in a backward convolution, and would therefore cause
      // cuDNN convolution (which doesn't support negative padding) to fail.
      return false;
    }
    // If the backward convolution has uneven padding on the activations, we
    // move some padding on the larger end to "internal" padding, so that the
    // backward convolution produces larger activations which get sliced later.
    //
    // For example, suppose we have a non-canonical HLO
    //   [A] = BackwardInputConvolve([a b], [x y z], padding=(low=2,high=1))
    // where the amount of low padding is larger. We can canonicalize it to
    //   [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1))
    //   [A] = Slice([B A])
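    // (With padding_low=2 > padding_high=1, the new conv output gains one
    // extra element at the low end, which the Slice added below removes by
    // starting at index padding_low - padding_high = 1.)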
    if (padding_low > padding_high) {
      IncreasePaddingLowBy(padding_high - padding_low,
                           new_backward_conv_window.mutable_dimensions(i));
    } else if (padding_low < padding_high) {
      IncreasePaddingHighBy(padding_low - padding_high,
                            new_backward_conv_window.mutable_dimensions(i));
    }
    // Decreasing the padding by X *increases* the size of our output by X.
    // Note that we have swapped input spatial dimensions with output spatial
    // dimensions to be compatible with the cuDNN API, so
    // input_spatial_dimensions(i) gives the i-th spatial dimension of the
    // output.
    int64_t dim = backward_conv_dnums.input_spatial_dimensions(i);
    new_backward_conv_shape.set_dimensions(
        dim, new_backward_conv_shape.dimensions(dim) +
                 std::abs(padding_low - padding_high));
  }

  // Create a new backward convolution replacing the old one.
  HloComputation* computation = backward_conv->parent();
  HloInstruction* output = backward_conv->mutable_operand(0);
  HloInstruction* filter = backward_conv->mutable_operand(1);

  HloInstruction* new_backward_conv_call =
      computation->AddInstruction(backward_conv->CloneWithNewOperands(
          ShapeUtil::MakeTupleShape(
              {new_backward_conv_shape, ShapeUtil::MakeShape(U8, {0})}),
          {output, filter}));
  new_backward_conv_call->set_window(new_backward_conv_window);

  // The CustomCall created above returns a tuple (conv_result, scratch_memory).
  // Extract out the two elements.
  HloInstruction* new_backward_conv =
      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
          new_backward_conv_shape, new_backward_conv_call, 0));
  HloInstruction* new_backward_conv_scratch =
      computation->AddInstruction(HloInstruction::CreateGetTupleElement(
          new_backward_conv_call->shape().tuple_shapes(1),
          new_backward_conv_call, 1));

  // Slice the new backward convolution.
  //
  // Initialize start_indices and limit_indices as no slicing.
  std::vector<int64_t> start_indices(
      new_backward_conv->shape().dimensions_size(), 0LL);
  std::vector<int64_t> limit_indices(
      new_backward_conv->shape().dimensions().begin(),
      new_backward_conv->shape().dimensions().end());
  std::vector<int64_t> strides(new_backward_conv->shape().dimensions_size(),
                               1LL);
  for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
    int64_t padding_low = backward_conv->window().dimensions(i).padding_low();
    int64_t padding_high = backward_conv->window().dimensions(i).padding_high();
    // Note that we have swapped input spatial dimensions with output spatial
    // dimensions to be compatible with the cuDNN API, so
    // input_spatial_dimensions(i) gives the i-th spatial dimension of the
    // output.
    int64_t dim = backward_conv_dnums.input_spatial_dimensions(i);
    if (padding_low > padding_high) {
      // If the amount of low padding (of the old backward convolution) is
      // larger, we internally pad the low end of the activations and slice
      // internal padding out here.
      start_indices[dim] += padding_low - padding_high;
    } else if (padding_low < padding_high) {
      // If the amount of high padding is larger, we slice out the internal
      // padding on the high end.
      limit_indices[dim] -= padding_high - padding_low;
    }
  }

  // Replace the old backward convolution with the slice.
  Shape slice_shape =
      ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices,
                                      limit_indices, strides)
          .value();
  CHECK(ShapeUtil::Compatible(slice_shape, backward_conv_shape))
      << ShapeUtil::HumanString(slice_shape) << " vs "
      << ShapeUtil::HumanString(backward_conv_shape);

  HloInstruction* slice = computation->AddInstruction(
      HloInstruction::CreateSlice(backward_conv_shape, new_backward_conv,
                                  start_indices, limit_indices, strides));
  HloInstruction* new_tuple = computation->AddInstruction(
      HloInstruction::CreateTuple({slice, new_backward_conv_scratch}));

  VLOG(1) << "Canonicalizing backward input conv";
  VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n "
          << new_tuple->ToString();

  TF_CHECK_OK(computation->ReplaceInstruction(backward_conv, new_tuple));
  return true;
}

StatusOr<bool> GpuConvPaddingLegalization::RunOnComputation(
    HloComputation* computation) {
  bool changed = false;
  std::vector<HloCustomCallInstruction*> convs;
  for (auto* instr : computation->instructions()) {
    if (IsCustomCallToDnnConvolution(*instr)) {
      convs.push_back(Cast<HloCustomCallInstruction>(instr));
    }
  }
  for (HloCustomCallInstruction* instruction : convs) {
    TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instruction));
    changed |= [&] {
      switch (kind) {
        case CudnnConvKind::kForward:
        case CudnnConvKind::kForwardActivation:
          return CanonicalizeForwardConvolution(instruction);
        case CudnnConvKind::kBackwardInput:
          return CanonicalizeBackwardInputConvolution(instruction);
        case CudnnConvKind::kBackwardFilter:
          return CanonicalizeBackwardFilterConvolution(instruction);
      }
    }();
  }
  return changed;
}

StatusOr<bool> GpuConvPaddingLegalization::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  bool changed = false;
  for (HloComputation* computation :
       module->MakeNonfusionComputations(execution_threads)) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
    changed |= result;
  }
  return changed;
}

}  // namespace gpu
}  // namespace xla