/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h"

#include <cstdlib>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace gpu {

namespace {
HloInstruction* CreateGpuConv(const char* call_target, const Shape& shape,
                              HloInstruction* lhs, HloInstruction* rhs,
                              const Window& window,
                              const ConvolutionDimensionNumbers& dnums,
                              int64 feature_group_count,
                              const OpMetadata& metadata) {
  HloComputation* computation = lhs->parent();

  // This call returns a tuple of (conv_result, scratch_memory), where
  // conv_result is the actual result of the convolution, and scratch_memory is
  // temporary memory used by cudnn.
  //
  // At the moment, we don't know how much scratch memory this conv is going to
  // use, so we put u8[0] in this place. Later on another pass will choose
  // which conv algorithm to use, and at that point we'll modify the shape of
  // this second tuple element.
  Shape call_shape =
      ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})});

  HloInstruction* custom_call = computation->AddInstruction(
      HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target));
  custom_call->set_window(window);
  custom_call->set_convolution_dimension_numbers(dnums);
  custom_call->set_feature_group_count(feature_group_count);
  custom_call->set_metadata(metadata);
  return custom_call;
}

HloInstruction* ConvertBatchGroupedToFeatureGroupedConvolution(
    HloInstruction* conv) {
  CHECK_EQ(conv->feature_group_count(), 1);
  int64 num_groups = conv->batch_group_count();
  auto dim_numbers = conv->convolution_dimension_numbers();
  auto lhs = conv->mutable_operand(0);
  auto rhs = conv->mutable_operand(1);

  int64 input_batch_dimension = dim_numbers.input_batch_dimension();

  Shape output_shape = conv->shape();
  int64 input_feature_dimension = dim_numbers.input_feature_dimension();
  int64 input_feature = lhs->shape().dimensions(input_feature_dimension);

  HloComputation* computation = lhs->parent();
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    return computation->AddInstruction(std::move(inst));
  };
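  // Illustrative walkthrough (hypothetical NHWC shapes, not taken from any
  // particular model): with batch_group_count G = 2 and an activation of
  // f32[8,32,32,4], the three steps below produce
  //   reshape:   [8,32,32,4]   -> [2,4,32,32,4]   (split N into [G, N/G])
  //   transpose: [2,4,32,32,4] -> [4,32,32,2,4]   (move G next to C)
  //   reshape:   [4,32,32,2,4] -> [4,32,32,8]     (merge [G, C] into G*C)
  // so the cloned convolution can use feature_group_count = G instead of
  // batch_group_count.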
  // Reshape batch_dim N -> [G, N/G].
  std::vector<int64> reshape_dims = SpanToVector(lhs->shape().dimensions());
  reshape_dims[input_batch_dimension] =
      reshape_dims[input_batch_dimension] / num_groups;
  reshape_dims.insert(reshape_dims.begin() + input_batch_dimension, num_groups);
  lhs = add(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(lhs->shape().element_type(), reshape_dims), lhs));

  // Transpose G to the axis just before C, e.g. [G, N/G, H, W, C] ->
  // [N/G, H, W, G, C].
  std::vector<int64> transpose_dims(lhs->shape().dimensions_size());
  std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
  transpose_dims.erase(transpose_dims.begin() + input_batch_dimension);
  transpose_dims.insert(transpose_dims.begin() + input_feature_dimension,
                        input_batch_dimension);
  std::vector<int64> transpose_reshape_dims =
      ComposePermutations(lhs->shape().dimensions(), transpose_dims);
  lhs = add(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(lhs->shape().element_type(), transpose_reshape_dims),
      lhs, transpose_dims));

  // Merge [G, C] -> [G*C].
  Shape new_shape = lhs->shape();
  new_shape.DeleteDimension(input_feature_dimension);
  new_shape.set_dimensions(input_feature_dimension, input_feature * num_groups);
  lhs = add(HloInstruction::CreateReshape(new_shape, lhs));

  std::vector<HloInstruction*> new_operands = {lhs, rhs};
  auto new_conv = conv->CloneWithNewOperands(output_shape, new_operands);
  new_conv->set_feature_group_count(num_groups);
  new_conv->set_batch_group_count(1);
  new_conv->set_convolution_dimension_numbers(dim_numbers);
  return computation->AddInstruction(std::move(new_conv));
}

bool CanImplementAsGpuForwardConv(HloInstruction* conv) {
  const ConvolutionDimensionNumbers& dnums =
      conv->convolution_dimension_numbers();
  if (dnums.input_spatial_dimensions_size() > 3) {
    return false;
  }

  // CuDNN does not accept zero-element arguments.
  if (ShapeUtil::IsZeroElementArray(conv->operand(0)->shape()) ||
      ShapeUtil::IsZeroElementArray(conv->operand(1)->shape())) {
    return false;
  }

  // CuDNN can perform either cross correlation (no reversal),
  // or convolution (all dimensions reversed).
  if (dnums.input_spatial_dimensions_size() == 2
          ? !window_util::AllOrNoneReversed(conv->window())
          : window_util::HasWindowReversal(conv->window())) {
    return false;
  }
  return true;
}

// Try to match a backward filter pattern that contains "conv".
// Precondition: "conv" is a kConvolution.
std::tuple<bool, Window, ConvolutionDimensionNumbers, HloInstruction*>
MatchBackwardFilter(HloInstruction* conv) {
  VLOG(2) << "Trying to match convolution backward filter.";
  const auto no_match_result =
      std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);

  if (conv->feature_group_count() > 1) {
    VLOG(1) << conv->ToString()
            << " is a forward convolution. All grouped backward filters are "
               "mapped to batch-grouped convolutions in the tf2xla bridge, so "
               "backward filter convolutions cannot have a feature group "
               "count greater than 1 at this point. No need to fold to a "
               "backward filter.";
    return no_match_result;
  }

  // Step 1: match the instruction pattern without considering the paddings and
  // dimension numbers just yet. We may need some generic pattern matcher
  // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h
  //
  // Backward filter convolution is implemented in XLA as the forward
  // convolution of padded activations and dilated gradients. Padding on
  // activations and dilation on gradients are specified in the "window" field
  // of the forward convolution.
  //
  //        activations  gradients
  //              \         /
  //               v       v
  //              Convolution
  //                 conv
  CHECK_EQ(HloOpcode::kConvolution, conv->opcode());

  // Step 2: match paddings and dimension numbers of the forward convolution.
  const ConvolutionDimensionNumbers& conv_dnums =
      conv->convolution_dimension_numbers();
  auto input_batch_dim = conv_dnums.input_batch_dimension();
  auto input_feature_dim = conv_dnums.input_feature_dimension();
  auto input_spatial_dims = conv_dnums.input_spatial_dimensions();
  auto kernel_input_feature_dim = conv_dnums.kernel_input_feature_dimension();
  auto kernel_output_feature_dim = conv_dnums.kernel_output_feature_dimension();
  auto kernel_spatial_dims = conv_dnums.kernel_spatial_dimensions();
  auto output_batch_dim = conv_dnums.output_batch_dimension();
  auto output_feature_dim = conv_dnums.output_feature_dimension();
  auto output_spatial_dims = conv_dnums.output_spatial_dimensions();
  for (const WindowDimension& window_dim : conv->window().dimensions()) {
    if (window_dim.stride() != 1) {
      VLOG(1) << "Forward convolution's window "
              << conv->window().ShortDebugString()
              << " should have stride of 1.";
      return no_match_result;
    }
    if (window_dim.base_dilation() != 1) {
      VLOG(1) << "Forward convolution's window "
              << conv->window().ShortDebugString()
              << " should have no base (LHS) dilation.";
      return no_match_result;
    }
    if (window_dim.padding_low() < 0) {
      VLOG(1) << "Padding low should be non-negative.";
      return no_match_result;
    }
    if (window_dim.window_reversal()) {
      VLOG(1) << "Window reversal field not supported";
      return no_match_result;
    }
    // Padding high will be checked in Step 3.
  }
  // Mathematically, there is no difference between convolution forward vs
  // backward filter. A backward filter:
  //   [N, O, H+h-1, W+w-1] x [N, C, H, W] -> [O, C, h, w]
  // can be treated as a forward convolution with `N` treated as the new
  // contracting (feature) dimension, `O` treated as the new batch dimension,
  // and `C` treated as the new output feature dimension. The only differences
  // are the layouts and performance.
  //
  // Since there is no way to precisely tell whether we want a forward conv or
  // a backward filter conv, we have to rely on heuristics. Empirically,
  // forward convolutions have very small kernel dimensions, while in the
  // backward pass the "kernel dimensions" are large. If the kernel dimensions
  // are smaller than the output dimensions, return forward conv; otherwise
  // proceed with backward filter conv.
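  // For instance (hypothetical sizes, purely illustrative of the heuristic):
  // a 3x3 kernel convolved against a 64x64 output looks like an ordinary
  // forward conv (kernel dim 3 <= output dim 64), whereas a 64x64 "kernel"
  // producing a 3x3 "output" only makes sense as a backward filter
  // computation, so we proceed to Step 3 below and construct the
  // backward-filter match.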
  if ((kernel_spatial_dims.empty() ||
       conv->operand(1)->shape().dimensions(kernel_spatial_dims[0]) <=
           conv->shape().dimensions(output_spatial_dims[0])) &&
      !window_util::HasWindowDilation(conv->window())) {
    VLOG(1) << conv->ToString()
            << " is a regular forward convolution. No need "
               "to fold it to a backward filter convolution.";
    return no_match_result;
  }

  // Step 3: fuse the matched HLOs into a backward convolution instruction.
  //
  // Compute the window of the backward convolution.
  Window backward_conv_window;
  for (int i = 0; i < input_spatial_dims.size(); ++i) {
    WindowDimension* dim = backward_conv_window.add_dimensions();
    // The window size of the backward convolution equals the output size of
    // the forward convolution.
    int64 filter_size = conv->shape().dimensions(output_spatial_dims[i]);
    dim->set_size(filter_size);
    // The window stride equals the window dilation of the forward convolution.
    dim->set_stride(conv->window().dimensions(i).window_dilation());
    // The window's low padding is the same as the low padding of the
    // activations.
    dim->set_padding_low(conv->window().dimensions(i).padding_low());
    dim->set_base_dilation(1);
    dim->set_window_dilation(1);

    int64 input_size =
        conv->operand(0)->shape().dimensions(input_spatial_dims[i]);
    int64 output_size = conv->window().dimensions(i).size();
    // Compute the range of the amount of valid high padding. We first compute
    // min_padding_high, the amount of padding on the right/bottom to ensure
    // the last patch ends at the border, i.e.,
    //
    //   input_size + dim->padding_low() + min_padding_high
    //     = (output_size - 1) * stride + filter_size
    //
    // Because convolution ignores trailing incomplete windows, any amount of
    // padding high from min_padding_high to min_padding_high+stride-1
    // (max_padding_high) has the same effect.
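    // Worked example (hypothetical numbers, not derived from any particular
    // HLO): with filter_size = 3, output_size = 2, stride = 2, input_size = 4,
    // and padding_low = 0, the computations below give
    //   padded_input_size = 3 + (2 - 1) * 2 = 5
    //   min_padding_high  = 5 - 4 - 0       = 1
    //   max_padding_high  = 1 + 2 - 1       = 2
    // so any high padding of 1 or 2 yields the same backward convolution.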
    int64 padded_input_size = filter_size + (output_size - 1) * dim->stride();
    int64 min_padding_high =
        padded_input_size - input_size - dim->padding_low();
    int64 max_padding_high = min_padding_high + dim->stride() - 1;
    CHECK_GE(dim->padding_low(), 0);
    // In practice, since cuDNN convolution only supports even padding, we make
    // the amount of high padding the same as the amount of low padding as long
    // as it is between min_padding_high and max_padding_high. If it is not in
    // that range, we pick the one that's closest to dim->padding_low() and let
    // GpuConvPaddingLegalization canonicalize the resultant backward
    // convolution later. Picking the closest one minimizes the cost of the
    // kPad instruction to be inserted by GpuConvPaddingLegalization.
    if (dim->padding_low() >= min_padding_high &&
        dim->padding_low() <= max_padding_high) {
      dim->set_padding_high(dim->padding_low());
    } else {
      if (dim->padding_low() < min_padding_high) {
        dim->set_padding_high(min_padding_high);
      } else {
        dim->set_padding_high(max_padding_high);
      }
    }
    if (dim->padding_high() < 0) {
      LOG(WARNING)
          << "Fusing this pattern to backward filter convolution would cause "
             "negative padding ("
          << dim->padding_high()
          << ") on right/bottom of the weight gradients, which is not "
             "supported by GpuConvPaddingLegalization (b/32744257). "
             "Falling back to unfused convolution for instruction: "
          << conv->ToString();
      return no_match_result;
    }
  }

  // Restore the dimension numbers of the backward convolution from the forward
  // convolution. The two activation dimensions are reversed (batch and
  // feature).
  ConvolutionDimensionNumbers backward_conv_dnums;
  backward_conv_dnums.set_input_batch_dimension(input_feature_dim);
  backward_conv_dnums.set_input_feature_dimension(input_batch_dim);
  for (int i = 0; i < input_spatial_dims.size(); ++i) {
    backward_conv_dnums.add_input_spatial_dimensions(input_spatial_dims[i]);
  }
  backward_conv_dnums.set_output_batch_dimension(kernel_input_feature_dim);
  backward_conv_dnums.set_output_feature_dimension(kernel_output_feature_dim);
  for (int i = 0; i < kernel_spatial_dims.size(); ++i) {
    backward_conv_dnums.add_output_spatial_dimensions(kernel_spatial_dims[i]);
  }
  // The dimension numbering of the output of the forward convolution (before
  // transposition) is the same as that of the activations (according to the
  // semantics of kConvolution). The batch dimension of the activations should
  // be treated as the input feature dimension, and the feature dimension
  // should be treated as the output feature.
  backward_conv_dnums.set_kernel_input_feature_dimension(output_batch_dim);
  backward_conv_dnums.set_kernel_output_feature_dimension(output_feature_dim);
  for (int i = 0; i < output_spatial_dims.size(); ++i) {
    backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]);
  }

  HloInstruction* lhs = conv->mutable_operand(0);
  return std::make_tuple(true, backward_conv_window, backward_conv_dnums, lhs);
}

// Try to match a backward input pattern that contains "conv".
// Precondition: "conv" is a kConvolution.
std::tuple<bool, Window, ConvolutionDimensionNumbers, HloInstruction*>
MatchBackwardInput(HloInstruction* conv) {
  VLOG(2) << "Trying to match convolution backward input.";
  const auto no_match_result =
      std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);

  // Match instruction pattern.
  CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
  HloInstruction* reverse_filter = conv->mutable_operand(1);
  ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();

  // We pattern-match to a backwards input conv if:
  //
  //  - all spatial dims of the filter are reversed
  //
  // OR
  //
  //  - filter is 1x1 or a constant AND
  //  - conv has base dilation (otherwise this is just a regular forward conv).
  //
  // The final criterion above is just for canonicalization; cudnn seems to run
  // just as fast if we canonicalize 1x1/constant filters without base dilation
  // to forward or backward convs. We canonicalize to forward conv because (a)
  // it's more natural (constant filters usually show up when doing inference,
  // and having backwards convolutions in inference graphs would be weird), and
  // (b) cudnn has special fusions for forward conv plus bias and activation,
  // and we want to pattern-match to that after running this pass.
  bool is_reversed_filter =
      reverse_filter->opcode() == HloOpcode::kReverse &&
      absl::c_is_permutation(dnums.kernel_spatial_dimensions(),
                             reverse_filter->dimensions());
  bool is_1x1_filter =
      absl::c_all_of(conv->window().dimensions(),
                     [](const WindowDimension& d) { return d.size() == 1; });
  if (!is_reversed_filter &&
      !(window_util::HasBaseDilation(conv->window()) &&
        (reverse_filter->IsConstant() || is_1x1_filter))) {
    VLOG(1) << "Can't match to backwards convolution. Either filter is not "
               "kReverse, or it's not a base-dilated conv with a 1x1 or "
               "constant filter.";
    return no_match_result;
  }

  // Match padding and dilation of the forward convolution.
  for (const WindowDimension& window_dim : conv->window().dimensions()) {
    if (window_dim.stride() != 1) {
      VLOG(1) << "Forward convolution's window "
              << conv->window().ShortDebugString()
              << " should have stride of 1.";
      return no_match_result;
    }
    if (window_dim.window_dilation() != 1) {
      VLOG(1) << "Forward convolution's window "
              << conv->window().ShortDebugString()
              << " should have no window dilation.";
      return no_match_result;
    }
    if (window_dim.window_reversal()) {
      VLOG(1) << "Window reversal field not supported";
      return no_match_result;
    }
  }

  const auto& input_spatial_dims = dnums.input_spatial_dimensions();
  const auto& output_spatial_dims = dnums.output_spatial_dimensions();
  CHECK_EQ(conv->window().dimensions().size(), input_spatial_dims.size());
  CHECK_EQ(output_spatial_dims.size(), input_spatial_dims.size());

  const Window& old_window = conv->window();
  Window new_window = old_window;
  for (size_t i = 0; i < input_spatial_dims.size(); ++i) {
    // Restore the backward convolution's padding config from the matched
    // pattern. See the comment in tensorflow/core/kernels/conv_grad_ops.h for
    // how we convert a backward input convolution to a variant of forward
    // convolution.
    //
    // The stride of the backward convolution
    //   = the base dilation factor of the forward convolution.
    auto dim = new_window.mutable_dimensions(i);
    dim->set_stride(old_window.dimensions(i).base_dilation());
    dim->set_base_dilation(1);

    // The low padding = kernel_size - 1 - low padding on the gradients.
    // Make sure the low padding is not negative.
    auto kernel_size = old_window.dimensions(i).size();
    auto backward_padding_low =
        kernel_size - 1 - old_window.dimensions(i).padding_low();
    if (backward_padding_low < 0) {
      LOG(WARNING)
          << "The low padding of the backward convolution would be negative ("
          << backward_padding_low
          << "), which isn't supported by GpuConvPaddingLegalization "
             "for now (b/32744257).";
      return no_match_result;
    }
    dim->set_padding_low(backward_padding_low);

    // Compute the range of the amount of padding on the right/bottom of the
    // activations. XLA's convolution requires all patches to be within the
    // padded base. This gives us flexibility to choose the amount of high
    // padding from a set of values without changing the result of the backward
    // convolution. The minimum amount (min_padding_high) makes the last patch
    // end at the border. The maximum amount (max_padding_high) equals
    // min_padding_high+stride-1 -- max_padding_high+1 would cause the output
    // size to change.
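    // Worked example (hypothetical numbers, purely illustrative): with
    // kernel_size = 2, stride = 2, output_size = 3, unpadded_input_size = 5,
    // and backward_padding_low = 0, the computations below give
    //   padded_input_size = 2 + 2 * (3 - 1) = 6
    //   total_pad_size    = 6 - 5           = 1
    //   min_padding_high  = 1 - 0           = 1
    //   max_padding_high  = 1 + 2 - 1       = 2
    // so any high padding in [1, 2] leaves the backward conv's result
    // unchanged.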
    auto unpadded_input_size = conv->shape().dimensions(output_spatial_dims[i]);
    auto output_size =
        conv->operand(0)->shape().dimensions(input_spatial_dims[i]);
    auto padded_input_size = kernel_size + dim->stride() * (output_size - 1);
    auto total_pad_size = padded_input_size - unpadded_input_size;
    auto min_padding_high = total_pad_size - backward_padding_low;
    auto max_padding_high = min_padding_high + dim->stride() - 1;

    if (backward_padding_low >= min_padding_high &&
        backward_padding_low <= max_padding_high) {
      // In the best (and most likely) case, backward_padding_low falls in the
      // range of valid high padding amounts, and we choose it because cuDNN
      // supports only even padding.
      dim->set_padding_high(backward_padding_low);
    } else {
      // Otherwise, we choose the amount that's closest to
      // backward_padding_low, and GpuConvPaddingLegalization will later insert
      // kSlice instructions to enforce even padding.
      //
      // For example, consider the backward convolution pattern
      //
      //    ab      xy
      //    | pad   | reverse
      //   .a.b     yx
      //      \    /
      //       ABC
      //
      // The amount of low padding on activations (in backward convolution) is
      //   backward_padding_low = kernel_size - 1 - forward_padding_low
      //                        = 2 - 1 - 1 = 0
      //
      // The amount of padding high must be between 1 and 2, in order to make
      // Conv(ABC, xy, stride=2) produce exactly 2 elements (ab). 0 is not in
      // the range of [1,2], so we pick the closest valid amount of padding
      // high, which is 1 in this case. Therefore, we fuse the above pattern to
      //
      //   ABC = BackwardInputConv(ab, xy, stride=2, padding_high=1)
      if (backward_padding_low < min_padding_high) {
        dim->set_padding_high(min_padding_high);
      } else {
        dim->set_padding_high(max_padding_high);
      }
    }
    // GpuConvPaddingLegalization doesn't handle backward input
    // convolution with negative padding for now. So fall back to unfused
    // convolution in case of negative padding. For example,
    //   ABCD = Conv(abc, reverse(xy), padding_high=2)
    // could be fused to
    //   ABCD = BackwardInputConv(abc, xy, padding_low=1, padding_high=-1)
    // with positive padding low but negative padding high.
    if (dim->padding_high() < 0) {
      LOG(WARNING) << "Fusing this pattern to backward convolution would cause "
                      "negative padding ("
                   << dim->padding_high()
                   << ") on right/bottom of the activations, which is not "
                      "supported by GpuConvPaddingLegalization (b/32744257). "
                      "Falling back to unfused convolution for instruction: "
                   << conv->ToString();
      return no_match_result;
    }
  }

  // OK, it's a match! Switch the input feature dimension with the output
  // feature dimension. Also switch the output with the input. This is the way
  // cuDNN expects it to be.
  auto conv_dnums = conv->convolution_dimension_numbers();
  dnums.set_kernel_input_feature_dimension(
      conv_dnums.kernel_output_feature_dimension());
  dnums.set_kernel_output_feature_dimension(
      conv_dnums.kernel_input_feature_dimension());
  for (int i = 0; i < input_spatial_dims.size(); ++i) {
    dnums.set_input_spatial_dimensions(i,
                                       conv_dnums.output_spatial_dimensions(i));
    dnums.set_output_spatial_dimensions(i,
                                        conv_dnums.input_spatial_dimensions(i));
  }
  dnums.set_input_feature_dimension(conv_dnums.output_feature_dimension());
  dnums.set_input_batch_dimension(conv_dnums.output_batch_dimension());
  dnums.set_output_feature_dimension(conv_dnums.input_feature_dimension());
  dnums.set_output_batch_dimension(conv_dnums.input_batch_dimension());

  // If we matched against a constant, we need to add a reverse op that can be
  // subsumed by the cuDNN call. algebraic-simplifier will later remove any
  // unnecessary reverses.
  if (reverse_filter->opcode() != HloOpcode::kReverse &&
      reverse_filter->IsConstant()) {
    // Create a double-reverse, which is a nop.
    HloComputation* c = conv->parent();
    reverse_filter = c->AddInstruction(HloInstruction::CreateReverse(
        reverse_filter->shape(), reverse_filter,
        AsInt64Slice(dnums.kernel_spatial_dimensions())));
    reverse_filter = c->AddInstruction(HloInstruction::CreateReverse(
        reverse_filter->shape(), reverse_filter,
        AsInt64Slice(dnums.kernel_spatial_dimensions())));
    TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_num=*/1, reverse_filter));
  }

  // Calculate the 'rhs' that goes into the backward input convolution.
  HloInstruction* rhs = reverse_filter;
  // One reverse is subsumed by the cuDNN call.
  if (rhs->opcode() == HloOpcode::kReverse) {
    rhs = rhs->mutable_operand(0);
  }
  if (conv->feature_group_count() == 1) {
    return std::make_tuple(true, new_window, dnums, rhs);
  }

  // Handle grouped convolutions. Because we swapped the input feature
  // dimension with the output feature dimension, we need to also reshape the
  // kernel so that the 'feature_group_count' parameter still makes sense. The
  // 'feature_group_count' parameter essentially specifies how often the
  // 'kernel_input_feature_dimension' is repeated. So when we swap these
  // dimensions, we need to divide the new 'kernel_input_feature_dimension' by
  // 'feature_group_count' and multiply the new
  // 'kernel_output_feature_dimension' by 'feature_group_count'.
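  // Illustrative walkthrough (hypothetical kernel layout with G = 2, the new
  // kernel_input_feature_dimension at axis 2 and the new
  // kernel_output_feature_dimension at axis 3): an [H, W, 8, 4] kernel
  // (in_depth = 8, out_depth / G = 4) goes through the three steps below as
  //   reshape:   [H, W, 8, 4]    -> [H, W, 2, 4, 4]
  //   transpose: [H, W, 2, 4, 4] -> [H, W, 4, 2, 4]
  //   reshape:   [H, W, 4, 2, 4] -> [H, W, 4, 8]
  // which divides the new kernel_input_feature_dimension by G and multiplies
  // the new kernel_output_feature_dimension by G.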
  int64 input_feature_dimension = dnums.kernel_input_feature_dimension();
  int64 output_feature_dimension = dnums.kernel_output_feature_dimension();
  // The following code assumes that input_feature_dimension and
  // output_feature_dimension are adjacent.
  if (std::abs(input_feature_dimension - output_feature_dimension) != 1) {
    return no_match_result;
  }

  int64 input_features = rhs->shape().dimensions(input_feature_dimension);
  int64 output_features = rhs->shape().dimensions(output_feature_dimension);

  // Reshape [H, W, ..., in_depth, out_depth / G] -> [H, W, ..., G, in_depth/G,
  // out_depth / G]
  std::vector<int64> reshape_dims = SpanToVector(rhs->shape().dimensions());
  auto num_groups = conv->feature_group_count();
  CHECK_EQ(input_features % num_groups, 0)
      << "Input feature count should be an exact multiple of feature group "
         "count";
  reshape_dims[input_feature_dimension] =
      reshape_dims[input_feature_dimension] / num_groups;
  reshape_dims.insert(reshape_dims.begin() + input_feature_dimension,
                      num_groups);

  HloComputation* c = conv->parent();
  rhs = c->AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(rhs->shape().element_type(), reshape_dims), rhs));

  // Transpose [H, W, ..., G, in_depth/G, out_depth / G] -> [H, W, ...,
  // in_depth/G, G, out_depth / G]
  std::vector<int64> transpose_dims(rhs->shape().dimensions_size());
  std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
  transpose_dims.erase(transpose_dims.begin() + input_feature_dimension);
  transpose_dims.insert(transpose_dims.begin() + output_feature_dimension,
                        input_feature_dimension);
  std::vector<int64> transpose_reshape_dims =
      SpanToVector(rhs->shape().dimensions());
  transpose_reshape_dims.erase(transpose_reshape_dims.begin() +
                               input_feature_dimension);
  transpose_reshape_dims.insert(
      transpose_reshape_dims.begin() + output_feature_dimension, num_groups);
  rhs = c->AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(rhs->shape().element_type(), transpose_reshape_dims),
      rhs, transpose_dims));

  // Reshape [H, W, ..., in_depth/G, G, out_depth / G] -> [H, W, ...,
  // in_depth/G, out_depth]
  Shape new_shape = rhs->shape();
  new_shape.DeleteDimension(output_feature_dimension);
  new_shape.set_dimensions(output_feature_dimension,
                           output_features * num_groups);
  rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs));
  return std::make_tuple(true, new_window, dnums, rhs);
}

CudnnConvBackendConfig GetDefaultBackendConfig() {
  CudnnConvBackendConfig config;
  config.set_conv_result_scale(1);
  return config;
}

// Helper function to create a custom_call instruction to replace the given
// conv instruction.
static StatusOr<HloInstruction*> CreateCustomCallHelper(HloInstruction* conv) {
  bool match;
  Window window;
  ConvolutionDimensionNumbers dnums;
  HloInstruction* rhs;
  HloInstruction* lhs;

  std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv);
  if (match) {
    return CreateGpuConv(kCudnnConvBackwardInputCallTarget, conv->shape(),
                         conv->mutable_operand(0), rhs, window, dnums,
                         conv->feature_group_count(), conv->metadata());
  }

  std::tie(match, window, dnums, lhs) = MatchBackwardFilter(conv);
  if (match) {
    return CreateGpuConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), lhs,
                         conv->mutable_operand(1), window, dnums,
                         conv->batch_group_count(), conv->metadata());
  }

  // If all else fails, try a forward convolution.
  if (CanImplementAsGpuForwardConv(conv)) {
    if (primitive_util::IsIntegralType(
            conv->operand(0)->shape().element_type())) {
      // In addition to replacing a convolution instruction with
      // a custom call, integer convolutions must have this pattern to match
      // CuDNN semantics:
      //   conv<InputT=int32, ResultT=int32>(
      //       convert<int32>(int8_x), convert<int32>(int8_y))
      // We transform it to:
      //   custom_call<int32>(int8_x, int8_y, target=cudnnConvolutionForward)
      //
      // We will error out if the pattern is not found for an integer
      // convolution.
      const auto is_int8_to_int32_cast =
          [](const HloInstruction* instr) -> bool {
        return (instr->opcode() == HloOpcode::kConvert &&
                instr->operand(0)->shape().element_type() == S8 &&
                instr->shape().element_type() == S32);
      };
      HloInstruction* input_convert = conv->mutable_operand(0);
      HloInstruction* kernel_convert = conv->mutable_operand(1);
      if (conv->shape().element_type() != S32 ||
          !is_int8_to_int32_cast(input_convert) ||
          !is_int8_to_int32_cast(kernel_convert)) {
        return Unimplemented(
            "Integer convolutions for CuDNN must have this pattern: "
            "conv<InputT=int32, ResultT=int32>(convert<int32>(int8_x), "
            "convert<int32>(int8_y))");
      }
      // Bypass the convert<int32> for both inputs.
      TF_RETURN_IF_ERROR(conv->ReplaceOperandWithDifferentShape(
          0, input_convert->mutable_operand(0)));
      TF_RETURN_IF_ERROR(
          conv->parent()->RemoveInstructionAndUnusedOperands(input_convert));
      TF_RETURN_IF_ERROR(conv->ReplaceOperandWithDifferentShape(
          1, kernel_convert->mutable_operand(0)));
      TF_RETURN_IF_ERROR(
          conv->parent()->RemoveInstructionAndUnusedOperands(kernel_convert));
    }

    if (conv->batch_group_count() > 1) {
      conv = ConvertBatchGroupedToFeatureGroupedConvolution(conv);
    }

    return CreateGpuConv(kCudnnConvForwardCallTarget, conv->shape(),
                         conv->mutable_operand(0), conv->mutable_operand(1),
                         conv->window(), conv->convolution_dimension_numbers(),
                         conv->feature_group_count(), conv->metadata());
  }

  return nullptr;
}

// Tries to rewrite a single convolution into a call to cudnn/miopen.
StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
  CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);

  TF_ASSIGN_OR_RETURN(HloInstruction * custom_call,
                      CreateCustomCallHelper(conv));
  if (custom_call == nullptr) {
    return false;
  }

  TF_RETURN_IF_ERROR(
      custom_call->set_backend_config(GetDefaultBackendConfig()));

  VLOG(1) << "Replacing convolution " << conv->ToString() << " with "
          << custom_call->ToString();

  // The CustomCall returns a tuple (conv_result, scratch_memory). Extract
  // out the conv result and replace `conv` with it.
  TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
      conv,
      HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0)));
  return true;
}

// Rewrites the convolutions in the given computation into calls to
// cudnn/miopen.
// Returns true if it made any changes.
StatusOr<bool> RunOnComputation(HloComputation* computation) {
  std::vector<HloInstruction*> convs;
  for (auto* hlo : computation->instructions()) {
    if (hlo->opcode() == HloOpcode::kConvolution) {
      convs.push_back(hlo);
    }
  }

  bool changed = false;
  for (HloInstruction* conv : convs) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(conv));
    changed |= result;
  }
  return changed;
}
}  // namespace

StatusOr<bool> GpuConvRewriter::Run(HloModule* module) {
  XLA_VLOG_LINES(2, "GpuConvRewriter::Run(), before:\n" + module->ToString());
  bool changed = false;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
    changed |= result;
  }
  XLA_VLOG_LINES(2, "GpuConvRewriter::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace gpu
}  // namespace xla