1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h"
17 
18 #include <cstdlib>
19 #include <numeric>
20 #include <vector>
21 
22 #include "tensorflow/compiler/xla/literal.h"
23 #include "tensorflow/compiler/xla/permutation_util.h"
24 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
25 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
26 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/compiler/xla/window_util.h"
31 #include "tensorflow/compiler/xla/xla_data.pb.h"
32 #include "tensorflow/core/lib/core/status.h"
33 #include "tensorflow/core/platform/logging.h"
34 
35 namespace xla {
36 namespace gpu {
37 
38 namespace {
39 
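// Builds a convolution custom-call with the given call target (one of the
// cuDNN/MIOpen conv call targets), operands, window, dimension numbers, and
// feature group count. The result shape is a (conv_result, scratch_memory)
// tuple; see the comment below on why the scratch element starts out as u8[0].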
40 HloInstruction* CreateGpuConv(const char* call_target, const Shape& shape,
41                               HloInstruction* lhs, HloInstruction* rhs,
42                               const Window& window,
43                               const ConvolutionDimensionNumbers& dnums,
44                               int64 feature_group_count,
45                               const OpMetadata& metadata) {
46   HloComputation* computation = lhs->parent();
47 
48   // This call returns a tuple of (conv_result, scratch_memory), where
49   // conv_result is the actual result of the convolution, and scratch_memory is
50   // temporary memory used by cudnn.
51   //
52   // At the moment, we don't know how much scratch memory this conv is going to
53   // use, so we put u8[0] in this place.  Later on another pass will choose
54   // which conv algorithm to use, and at that point we'll modify the shape of
55   // this second tuple element.
56   Shape call_shape =
57       ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})});
58 
59   HloInstruction* custom_call = computation->AddInstruction(
60       HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target));
61   custom_call->set_window(window);
62   custom_call->set_convolution_dimension_numbers(dnums);
63   custom_call->set_feature_group_count(feature_group_count);
64   custom_call->set_metadata(metadata);
65   return custom_call;
66 }
67 
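// Rewrites a batch-grouped convolution (batch_group_count = G,
// feature_group_count = 1) into an equivalent feature-grouped convolution:
// the activations' batch dimension N is reshaped into [G, N/G], the group
// axis is transposed next to the feature dimension and merged into it, and
// the cloned convolution then runs with feature_group_count = G.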
68 HloInstruction* ConvertBatchGroupedToFeatureGroupedConvolution(
69     HloInstruction* conv) {
70   CHECK_EQ(conv->feature_group_count(), 1);
71   int64 num_groups = conv->batch_group_count();
72   auto dim_numbers = conv->convolution_dimension_numbers();
73   auto lhs = conv->mutable_operand(0);
74   auto rhs = conv->mutable_operand(1);
75 
76   int64 input_batch_dimension = dim_numbers.input_batch_dimension();
77 
78   Shape output_shape = conv->shape();
79   int64 input_feature_dimension = dim_numbers.input_feature_dimension();
80   int64 input_feature = lhs->shape().dimensions(input_feature_dimension);
81 
82   HloComputation* computation = lhs->parent();
83   auto add = [&](std::unique_ptr<HloInstruction> inst) {
84     return computation->AddInstruction(std::move(inst));
85   };
86   // Reshape batch_dim N -> [G, N/G]
87   std::vector<int64> reshape_dims = SpanToVector(lhs->shape().dimensions());
88   reshape_dims[input_batch_dimension] =
89       reshape_dims[input_batch_dimension] / num_groups;
90   reshape_dims.insert(reshape_dims.begin() + input_batch_dimension, num_groups);
91   lhs = add(HloInstruction::CreateReshape(
92       ShapeUtil::MakeShape(lhs->shape().element_type(), reshape_dims), lhs));
93 
94   // Transpose G to the axis just before C, e.g. [G, N/G, H, W, C] -> [N/G, H,
95   // W, G, C].
96   std::vector<int64> transpose_dims(lhs->shape().dimensions_size());
97   std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
98   transpose_dims.erase(transpose_dims.begin() + input_batch_dimension);
99   transpose_dims.insert(transpose_dims.begin() + input_feature_dimension,
100                         input_batch_dimension);
101   std::vector<int64> transpose_reshape_dims =
102       ComposePermutations(lhs->shape().dimensions(), transpose_dims);
103   lhs = add(HloInstruction::CreateTranspose(
104       ShapeUtil::MakeShape(lhs->shape().element_type(), transpose_reshape_dims),
105       lhs, transpose_dims));
106 
107   // Merge [G,C] -> [C*G]
108   Shape new_shape = lhs->shape();
109   new_shape.DeleteDimension(input_feature_dimension);
110   new_shape.set_dimensions(input_feature_dimension, input_feature * num_groups);
111   lhs = add(HloInstruction::CreateReshape(new_shape, lhs));
112 
113   std::vector<HloInstruction*> new_operands = {lhs, rhs};
114   auto new_conv = conv->CloneWithNewOperands(output_shape, new_operands);
115   new_conv->set_feature_group_count(num_groups);
116   new_conv->set_batch_group_count(1);
117   new_conv->set_convolution_dimension_numbers(dim_numbers);
118   return computation->AddInstruction(std::move(new_conv));
119 }
120 
121 bool CanImplementAsGpuForwardConv(HloInstruction* conv) {
122   const ConvolutionDimensionNumbers& dnums =
123       conv->convolution_dimension_numbers();
124   if (dnums.input_spatial_dimensions_size() > 3) {
125     return false;
126   }
127 
128   // CuDNN does not accept zero-element arguments.
129   if (ShapeUtil::IsZeroElementArray(conv->operand(0)->shape()) ||
130       ShapeUtil::IsZeroElementArray(conv->operand(1)->shape())) {
131     return false;
132   }
133 
134   // CuDNN can perform either cross correlation (no reversal),
135   // or convolution (all dimensions reversed).
136   if (dnums.input_spatial_dimensions_size() == 2
137           ? !window_util::AllOrNoneReversed(conv->window())
138           : window_util::HasWindowReversal(conv->window())) {
139     return false;
140   }
141   return true;
142 }
143 
144 // Try to match a backward filter pattern that contains "conv".
145 // Precondition: "conv" is a kConvolution.
146 std::tuple<bool, Window, ConvolutionDimensionNumbers, HloInstruction*>
147 MatchBackwardFilter(HloInstruction* conv) {
148   VLOG(2) << "Trying to match convolution backward filter.";
149   const auto no_match_result =
150       std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);
151 
152   if (conv->feature_group_count() > 1) {
153     VLOG(1) << conv->ToString()
154             << " is a forward convolution. All grouped backward filters are "
155                "mapped to batch grouped convolutions in tf2xla bridge. Hence "
156                "backward filter "
157                "convolutions cannot have feature groups greater than 1 at this "
158                "point. No need to fold to backward filter.";
159     return no_match_result;
160   }
161 
162   // Step 1: match the instruction pattern without considering the paddings and
163   // dimension numbers just yet. We may need some generic pattern matcher
164   // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h
165   //
166   // Backward filter convolution is implemented in XLA as the forward
167   // convolution of padded activations and dilated gradients. Padding on
168   // activations and dilation on gradients are specified in the "window" field
169   // of the forward convolution.
170   //
171   //        activations  gradients
172   //              \         /
173   //               v       v
174   //              Convolution
175   //                 conv
176   CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
177 
178   // Step 2: match paddings and dimension numbers of the forward convolution.
179   const ConvolutionDimensionNumbers& conv_dnums =
180       conv->convolution_dimension_numbers();
181   auto input_batch_dim = conv_dnums.input_batch_dimension();
182   auto input_feature_dim = conv_dnums.input_feature_dimension();
183   auto input_spatial_dims = conv_dnums.input_spatial_dimensions();
184   auto kernel_input_feature_dim = conv_dnums.kernel_input_feature_dimension();
185   auto kernel_output_feature_dim = conv_dnums.kernel_output_feature_dimension();
186   auto kernel_spatial_dims = conv_dnums.kernel_spatial_dimensions();
187   auto output_batch_dim = conv_dnums.output_batch_dimension();
188   auto output_feature_dim = conv_dnums.output_feature_dimension();
189   auto output_spatial_dims = conv_dnums.output_spatial_dimensions();
190   for (const WindowDimension& window_dim : conv->window().dimensions()) {
191     if (window_dim.stride() != 1) {
192       VLOG(1) << "Forward convolution's window "
193               << conv->window().ShortDebugString()
194               << " should have stride of 1.";
195       return no_match_result;
196     }
197     if (window_dim.base_dilation() != 1) {
198       VLOG(1) << "Forward convolution's window "
199               << conv->window().ShortDebugString()
200               << " should have no base (LHS) dilation.";
201       return no_match_result;
202     }
203     if (window_dim.padding_low() < 0) {
204       VLOG(1) << "Padding low should be non-negative.";
205       return no_match_result;
206     }
207     if (window_dim.window_reversal()) {
208       VLOG(1) << "Window reversal field not supported";
209       return no_match_result;
210     }
211     // Padding high will be checked in Step 3.
212   }
213   // Mathematically, there is no difference between convolution forward vs
214   // backward filter. A backward filter:
215   //   [N, O, H+h-1, W+w-1] x [N, C, H, W] -> [O, C, h, w]
216   // Can be treated as a forward convolution with `N` treated as the new
217   // contracting (feature) dimension, `O` treated as the new batch dimension,
218   // and `C` treated as the new output feature dimension. The only difference is
219   // layouts and performance.
220   //
221   // Since there is no way to precisely tell whether we want a forward conv or
222   // backward filter conv, we have to rely on heuristics. Empirically forward
223   // convolutions have very small kernel dimensions, while in the backward pass
224   // "kernel dimensions" are large. If kernel dimensions are smaller than the
225   // output dimensions, return forward conv; otherwise proceed with backward
226   // filter conv.
227   if ((kernel_spatial_dims.empty() ||
228        conv->operand(1)->shape().dimensions(kernel_spatial_dims[0]) <=
229            conv->shape().dimensions(output_spatial_dims[0])) &&
230       !window_util::HasWindowDilation(conv->window())) {
231     VLOG(1) << conv->ToString()
232             << " is a regular forward convolution. No need "
233                "to fold it to a backward filter convolution....";
234     return no_match_result;
235   }
236 
237   // Step 3: fuse the matched HLOs into a backward convolution instruction.
238   //
239   // Compute the window of the backward convolution.
240   Window backward_conv_window;
241   for (int i = 0; i < input_spatial_dims.size(); ++i) {
242     WindowDimension* dim = backward_conv_window.add_dimensions();
243     // The window size of the backward convolution equals the output size of the
244     // forward convolution.
245     int64 filter_size = conv->shape().dimensions(output_spatial_dims[i]);
246     dim->set_size(filter_size);
247     // The window stride equals the window dilation of the forward convolution.
248     dim->set_stride(conv->window().dimensions(i).window_dilation());
249     // The window's low padding is the same as the low padding of the
250     // activations.
251     dim->set_padding_low(conv->window().dimensions(i).padding_low());
252     dim->set_base_dilation(1);
253     dim->set_window_dilation(1);
254 
255     int64 input_size =
256         conv->operand(0)->shape().dimensions(input_spatial_dims[i]);
257     int64 output_size = conv->window().dimensions(i).size();
258     // Compute the range of the amount of valid high padding. We first compute
259     // min_padding_high, the amount of padding on the right/bottom to ensure the
260     // last patch ends at the border, i.e.,
261     //
262     //   input_size + dim->padding_low() + min_padding_high
263     //     = (output_size - 1) * stride + filter_size
264     //
265     // Because convolution ignores trailing incomplete windows, any amount of
266     // padding high from min_padding_high to min_padding_high+stride-1
267     // (max_padding_high) has the same effect.
268     int64 padded_input_size = filter_size + (output_size - 1) * dim->stride();
269     int64 min_padding_high =
270         padded_input_size - input_size - dim->padding_low();
271     int64 max_padding_high = min_padding_high + dim->stride() - 1;
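    // As a purely illustrative example (numbers chosen only to show the
    // formulas above): with input_size = 5, padding_low = 0, filter_size = 3,
    // output_size = 2 and stride = 2, padded_input_size = 3 + (2 - 1) * 2 = 5,
    // so min_padding_high = 5 - 5 - 0 = 0 and max_padding_high = 0 + 2 - 1 = 1.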
272     CHECK_GE(dim->padding_low(), 0);
273     // In practice, since cuDNN convolution only supports even padding, we make
274     // the amount of high padding the same as the amount of low padding as long
275     // as it is between min_padding_high and max_padding_high. If it is not in
276     // that range, we pick the one that's closest to dim->padding_low() and let
277     // GpuConvPaddingLegalization canonicalize the resultant backward
278     // convolution later. Picking the closest one minimizes the cost of the kPad
279     // instruction to be inserted by GpuConvPaddingLegalization.
280     if (dim->padding_low() >= min_padding_high &&
281         dim->padding_low() <= max_padding_high) {
282       dim->set_padding_high(dim->padding_low());
283     } else {
284       if (dim->padding_low() < min_padding_high) {
285         dim->set_padding_high(min_padding_high);
286       } else {
287         dim->set_padding_high(max_padding_high);
288       }
289     }
290     if (dim->padding_high() < 0) {
291       LOG(WARNING)
292           << "Fusing this pattern to backward filter convolution would cause "
293              "negative padding ("
294           << dim->padding_high()
295           << ") on right/bottom of the weight gradients, which is not "
296              "supported by GpuConvPaddingLegalization (b/32744257). "
297              "Falling back to "
298              "unfused convolution for instruction: "
299           << conv->ToString();
300       return no_match_result;
301     }
302   }
303 
304   // Restore the dimension numbers of the backward convolution from the forward
305   // convolution. The two activation dimensions are reversed (batch and
306   // feature).
307   ConvolutionDimensionNumbers backward_conv_dnums;
308   backward_conv_dnums.set_input_batch_dimension(input_feature_dim);
309   backward_conv_dnums.set_input_feature_dimension(input_batch_dim);
310   for (int i = 0; i < input_spatial_dims.size(); ++i) {
311     backward_conv_dnums.add_input_spatial_dimensions(input_spatial_dims[i]);
312   }
313   backward_conv_dnums.set_output_batch_dimension(kernel_input_feature_dim);
314   backward_conv_dnums.set_output_feature_dimension(kernel_output_feature_dim);
315   for (int i = 0; i < kernel_spatial_dims.size(); ++i) {
316     backward_conv_dnums.add_output_spatial_dimensions(kernel_spatial_dims[i]);
317   }
318   // The dimension numbering of the output of the forward convolution (before
319   // transposition) is the same as that of the activations (according to the
320   // semantics of kConvolution). The batch dimension of the activations should
321   // be treated as the input feature dimension, and the feature dimension should
322   // be treated as the output feature.
323   backward_conv_dnums.set_kernel_input_feature_dimension(output_batch_dim);
324   backward_conv_dnums.set_kernel_output_feature_dimension(output_feature_dim);
325   for (int i = 0; i < output_spatial_dims.size(); ++i) {
326     backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]);
327   }
328 
329   HloInstruction* lhs = conv->mutable_operand(0);
330   return std::make_tuple(true, backward_conv_window, backward_conv_dnums, lhs);
331 }
332 
333 // Try to match a backward input pattern that contains "conv".
334 // Precondition: "conv" is a kConvolution.
335 std::tuple<bool, Window, ConvolutionDimensionNumbers, HloInstruction*>
336 MatchBackwardInput(HloInstruction* conv) {
337   VLOG(2) << "Trying to match convolution backward input.";
338   const auto no_match_result =
339       std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);
340 
341   // Match instruction pattern.
342   CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
343   HloInstruction* reverse_filter = conv->mutable_operand(1);
344   ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
345 
346   // We pattern-match to a backwards input conv if:
347   //
348   //  - all spatial dims of the filter are reversed
349   //
350   // OR
351   //
352   //  - filter is 1x1 or a constant AND
353   //  - conv has base dilation (otherwise this is just a regular forward conv).
354   //
355   // The final criterion above is just for canonicalization; cudnn seems to run
356   // just as fast if we canonicalize 1x1/constant filters without base dilation
357   // to forward or backward convs.  We canonicalize to forward conv because (a)
358   // it's more natural (constant filters usually show up when doing inference,
359   // and having backwards convolutions in inference graphs would be weird), and
360   // (b) cudnn has special fusions for forward conv plus bias and activation,
361   // and we want to pattern-match to that after running this pass.
362   bool is_reversed_filter =
363       reverse_filter->opcode() == HloOpcode::kReverse &&
364       absl::c_is_permutation(dnums.kernel_spatial_dimensions(),
365                              reverse_filter->dimensions());
366   bool is_1x1_filter =
367       absl::c_all_of(conv->window().dimensions(),
368                      [](const WindowDimension& d) { return d.size() == 1; });
369   if (!is_reversed_filter &&
370       !(window_util::HasBaseDilation(conv->window()) &&
371         (reverse_filter->IsConstant() || is_1x1_filter))) {
372     VLOG(1) << "Can't match to backwards convolution. Either filter is not "
373                "kReverse, or it's not a base-dilated conv with a 1x1 or "
374                "constant filter.";
375     return no_match_result;
376   }
377 
378   // Match padding and dilation of the forward convolution.
379   for (const WindowDimension& window_dim : conv->window().dimensions()) {
380     if (window_dim.stride() != 1) {
381       VLOG(1) << "Forward convolution's window "
382               << conv->window().ShortDebugString()
383               << " should have stride of 1.";
384       return no_match_result;
385     }
386     if (window_dim.window_dilation() != 1) {
387       VLOG(1) << "Forward convolution's window "
388               << conv->window().ShortDebugString()
389               << " should have no window dilation.";
390       return no_match_result;
391     }
392     if (window_dim.window_reversal()) {
393       VLOG(1) << "Window reversal field not supported";
394       return no_match_result;
395     }
396   }
397 
398   const auto& input_spatial_dims = dnums.input_spatial_dimensions();
399   const auto& output_spatial_dims = dnums.output_spatial_dimensions();
400   CHECK_EQ(conv->window().dimensions().size(), input_spatial_dims.size());
401   CHECK_EQ(output_spatial_dims.size(), input_spatial_dims.size());
402 
403   const Window& old_window = conv->window();
404   Window new_window = old_window;
405   for (size_t i = 0; i < input_spatial_dims.size(); ++i) {
406     // Restore backward convolution's padding config from the matched pattern.
407     // See the comment in tensorflow/core/kernels/conv_grad_ops.h for how we
408     // convert backward input convolution to a variant of forward convolution.
409     //
410     // The stride of the backward convolution
411     // = the base dilation factor of the forward convolution
412     auto dim = new_window.mutable_dimensions(i);
413     dim->set_stride(old_window.dimensions(i).base_dilation());
414     dim->set_base_dilation(1);
415 
416     // The low padding = kernel_size - 1 - low padding on the gradients
417     // Make sure the low padding is not negative.
418     auto kernel_size = old_window.dimensions(i).size();
419     auto backward_padding_low =
420         kernel_size - 1 - old_window.dimensions(i).padding_low();
421     if (backward_padding_low < 0) {
422       LOG(WARNING)
423           << "The low padding of the backward convolution would be negative ("
424           << backward_padding_low
425           << "), which isn't supported by GpuConvPaddingLegalization "
426              "for now (b/32744257).";
427       return no_match_result;
428     }
429     dim->set_padding_low(backward_padding_low);
430 
431     // Compute the range of the amount of padding on the right/bottom of the
432     // activations. XLA's convolution requires all patches to be within the
433     // padded base. This gives us flexibility to choose the amount of high
434     // padding from a set of values without changing the result of the backward
435     // convolution. The minimum amount (min_padding_high) makes the last patch
436     // end at the border. The maximum amount (max_padding_high) equals
437     // min_padding_high+stride-1 -- max_padding_high+1 would cause the output
438     // size to change.
439     auto unpadded_input_size = conv->shape().dimensions(output_spatial_dims[i]);
440     auto output_size =
441         conv->operand(0)->shape().dimensions(input_spatial_dims[i]);
442     auto padded_input_size = kernel_size + dim->stride() * (output_size - 1);
443     auto total_pad_size = padded_input_size - unpadded_input_size;
444     auto min_padding_high = total_pad_size - backward_padding_low;
445     auto max_padding_high = min_padding_high + dim->stride() - 1;
446 
447     if (backward_padding_low >= min_padding_high &&
448         backward_padding_low <= max_padding_high) {
449       // In the best case (most likely), if backward_padding_low is in the range
450       // of the amounts of valid high padding, we choose backward_padding_low
451       // because cuDNN supports even padding only.
452       dim->set_padding_high(backward_padding_low);
453     } else {
454       // Otherwise, we choose the amount that's closest to backward_padding_low,
455       // and GpuConvPaddingLegalization will later insert kSlice
456       // instructions to enforce even padding.
457       //
458       // For example, consider the backward convolution pattern
459       //
460       //   ab     xy
461       //   | pad  | reverse
462       //  .a.b    yx
463       //     \   /
464       //      ABC
465       //
466       // The amount of low padding on activations (in backward convolution) is
467       //   backward_padding_low = kernel_size - 1 - forward_padding_low
468       //                        = 2 - 1 - 1 = 0
469       //
470       // The amount of padding high must be between 1 and 2, in order to make
471       // Conv(ABC, xy, stride=2) produce exactly 2 elements (ab). 0 is not in
472       // the range of [1,2], so we pick the closest valid amount of padding
473       // high, which is 1 in this case. Therefore, we fuse the above pattern to
474       //
475       //   ABC = BackwardInputConv(ab, xy, stride=2, padding_high=1)
476       if (backward_padding_low < min_padding_high) {
477         dim->set_padding_high(min_padding_high);
478       } else {
479         dim->set_padding_high(max_padding_high);
480       }
481     }
482     // GpuConvPaddingLegalization doesn't handle backward input
483     // convolution with negative padding for now. So fall back to unfused
484     // convolution in case of negative padding. For example,
485     //   ABCD = Conv(abc, reverse(xy), padding_high=2)
486     // could be fused to
487     //   ABCD = BackwardInputConv(abc, xy, padding_low=1, padding_high=-1)
488     // with positive padding low but negative padding high.
489     if (dim->padding_high() < 0) {
490       LOG(WARNING) << "Fusing this pattern to backward convolution would cause "
491                       "negative padding ("
492                    << dim->padding_high()
493                    << ") on right/bottom of the activations, which is not "
494                       "supported by GpuConvPaddingLegalization (b/32744257). "
495                       "Falling back to unfused convolution for instruction: "
496                    << conv->ToString();
497       return no_match_result;
498     }
499   }
500 
501   // OK, it's a match! Switch the input feature dimension with the output
502   // feature dimension. Also switch the output with the input. This is the way
503   // cuDNN expects it to be.
504   auto conv_dnums = conv->convolution_dimension_numbers();
505   dnums.set_kernel_input_feature_dimension(
506       conv_dnums.kernel_output_feature_dimension());
507   dnums.set_kernel_output_feature_dimension(
508       conv_dnums.kernel_input_feature_dimension());
509   for (int i = 0; i < input_spatial_dims.size(); ++i) {
510     dnums.set_input_spatial_dimensions(i,
511                                        conv_dnums.output_spatial_dimensions(i));
512     dnums.set_output_spatial_dimensions(i,
513                                         conv_dnums.input_spatial_dimensions(i));
514   }
515   dnums.set_input_feature_dimension(conv_dnums.output_feature_dimension());
516   dnums.set_input_batch_dimension(conv_dnums.output_batch_dimension());
517   dnums.set_output_feature_dimension(conv_dnums.input_feature_dimension());
518   dnums.set_output_batch_dimension(conv_dnums.input_batch_dimension());
519 
520   // If we matched against a constant, we need to add a reverse op that can be
521   // subsumed by the cuDNN call. algebraic-simplifier will later remove any
522   // unnecessary reverses.
523   if (reverse_filter->opcode() != HloOpcode::kReverse &&
524       reverse_filter->IsConstant()) {
525     // Create a double-reverse, which is a nop.
526     HloComputation* c = conv->parent();
527     reverse_filter = c->AddInstruction(HloInstruction::CreateReverse(
528         reverse_filter->shape(), reverse_filter,
529         AsInt64Slice(dnums.kernel_spatial_dimensions())));
530     reverse_filter = c->AddInstruction(HloInstruction::CreateReverse(
531         reverse_filter->shape(), reverse_filter,
532         AsInt64Slice(dnums.kernel_spatial_dimensions())));
533     TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_num=*/1, reverse_filter));
534   }
535 
536   // Calculate the 'rhs' that goes into the backward input convolution.
537   HloInstruction* rhs = reverse_filter;
538   // One reverse is subsumed by the cuDNN call.
539   if (rhs->opcode() == HloOpcode::kReverse) {
540     rhs = rhs->mutable_operand(0);
541   }
542   if (conv->feature_group_count() == 1) {
543     return std::make_tuple(true, new_window, dnums, rhs);
544   }
545 
546   // Handle grouped convolutions. Because we swapped the input feature dimension
547   // with the output feature dimension, we need to also reshape the kernel so
548   // that the 'feature_group_count' parameter still makes sense. The
549   // 'feature_group_count' parameter essentially specifies how often the
550   // 'kernel_input_feature_dimension' is repeated. So when we swap these
551   // dimensions, we need to divide the new 'kernel_input_feature_dimension' by
552   // 'feature_group_count' and multiply the new
553   // 'kernel_output_feature_dimension' by 'feature_group_count'.
554   int64 input_feature_dimension = dnums.kernel_input_feature_dimension();
555   int64 output_feature_dimension = dnums.kernel_output_feature_dimension();
556   // The following code assumes that input_feature_dimension and
557   // output_feature_dimension are adjacent.
558   if (std::abs(input_feature_dimension - output_feature_dimension) != 1) {
559     return no_match_result;
560   }
561 
562   int64 input_features = rhs->shape().dimensions(input_feature_dimension);
563   int64 output_features = rhs->shape().dimensions(output_feature_dimension);
564 
565   // Reshape [H, W, ..., in_depth, out_depth / G] -> [H, W, ..., G, in_depth/G,
566   // out_depth / G]
567   std::vector<int64> reshape_dims = SpanToVector(rhs->shape().dimensions());
568   auto num_groups = conv->feature_group_count();
569   CHECK_EQ(input_features % num_groups, 0)
570       << "Input feature count should be an exact multiple of feature group "
571          "count";
572   reshape_dims[input_feature_dimension] =
573       reshape_dims[input_feature_dimension] / num_groups;
574   reshape_dims.insert(reshape_dims.begin() + input_feature_dimension,
575                       num_groups);
576 
577   HloComputation* c = conv->parent();
578   rhs = c->AddInstruction(HloInstruction::CreateReshape(
579       ShapeUtil::MakeShape(rhs->shape().element_type(), reshape_dims), rhs));
580 
581   // Transpose [H, W, ..., G, in_depth/G, out_depth / G] -> [H, W, ...,
582   // in_depth/G, G, out_depth / G]
583   std::vector<int64> transpose_dims(rhs->shape().dimensions_size());
584   std::iota(transpose_dims.begin(), transpose_dims.end(), 0);
585   transpose_dims.erase(transpose_dims.begin() + input_feature_dimension);
586   transpose_dims.insert(transpose_dims.begin() + output_feature_dimension,
587                         input_feature_dimension);
588   std::vector<int64> transpose_reshape_dims =
589       SpanToVector(rhs->shape().dimensions());
590   transpose_reshape_dims.erase(transpose_reshape_dims.begin() +
591                                input_feature_dimension);
592   transpose_reshape_dims.insert(
593       transpose_reshape_dims.begin() + output_feature_dimension, num_groups);
594   rhs = c->AddInstruction(HloInstruction::CreateTranspose(
595       ShapeUtil::MakeShape(rhs->shape().element_type(), transpose_reshape_dims),
596       rhs, transpose_dims));
597 
598   // Reshape [H, W, ..., in_depth/G, G, out_depth / G] -> [H, W, ...,
599   // in_depth/G, out_depth]
600   Shape new_shape = rhs->shape();
601   new_shape.DeleteDimension(output_feature_dimension);
602   new_shape.set_dimensions(output_feature_dimension,
603                            output_features * num_groups);
604   rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs));
605   return std::make_tuple(true, new_window, dnums, rhs);
606 }
607 
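// Returns the backend config attached to newly created conv custom-calls.
// Only conv_result_scale is set explicitly (to 1, i.e. no result scaling).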
608 CudnnConvBackendConfig GetDefaultBackendConfig() {
609   CudnnConvBackendConfig config;
610   config.set_conv_result_scale(1);
611   return config;
612 }
613 
614 // Helper function to create a custom_call instruction to replace the given
615 // conv instruction.
616 static StatusOr<HloInstruction*> CreateCustomCallHelper(HloInstruction* conv) {
617   bool match;
618   Window window;
619   ConvolutionDimensionNumbers dnums;
620   HloInstruction* rhs;
621   HloInstruction* lhs;
622 
623   std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv);
624   if (match) {
625     return CreateGpuConv(kCudnnConvBackwardInputCallTarget, conv->shape(),
626                          conv->mutable_operand(0), rhs, window, dnums,
627                          conv->feature_group_count(), conv->metadata());
628   }
629 
630   std::tie(match, window, dnums, lhs) = MatchBackwardFilter(conv);
631   if (match) {
632     return CreateGpuConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), lhs,
633                          conv->mutable_operand(1), window, dnums,
634                          conv->batch_group_count(), conv->metadata());
635   }
636 
637   // If all else fails, try a forward convolution.
638   if (CanImplementAsGpuForwardConv(conv)) {
639     if (primitive_util::IsIntegralType(
640             conv->operand(0)->shape().element_type())) {
641       // In addition to replacing a convolution instruction with
642       // a custom call, integer convolutions must have this pattern to match
643       // CuDNN semantics:
644       // conv<InputT=int32, ResultT=int32>(
645       //   convert<int32>(int8_x), convert<int32>(int8_y))
646       // We transform it to:
647       // custom_call<int32>(int8_x, int8_y, target=cudnnConvolutionForward)
648       //
649       // We will error out, if the pattern is not found for integer
650       // convolution.
651       const auto is_int8_to_int32_cast =
652           [](const HloInstruction* instr) -> bool {
653         return (instr->opcode() == HloOpcode::kConvert &&
654                 instr->operand(0)->shape().element_type() == S8 &&
655                 instr->shape().element_type() == S32);
656       };
657       HloInstruction* input_convert = conv->mutable_operand(0);
658       HloInstruction* kernel_convert = conv->mutable_operand(1);
659       if (conv->shape().element_type() != S32 ||
660           !is_int8_to_int32_cast(input_convert) ||
661           !is_int8_to_int32_cast(kernel_convert)) {
662         return Unimplemented(
663             "Integer convolutions for CuDNN must have this pattern: "
664             "conv<InputT=int32, ResultT=int32>(convert<int32>(int8_x), "
665             "convert<int32>(int8_y))");
666       }
667       // Bypass the convert<int32> for both inputs.
668       TF_RETURN_IF_ERROR(conv->ReplaceOperandWithDifferentShape(
669           0, input_convert->mutable_operand(0)));
670       TF_RETURN_IF_ERROR(
671           conv->parent()->RemoveInstructionAndUnusedOperands(input_convert));
672       TF_RETURN_IF_ERROR(conv->ReplaceOperandWithDifferentShape(
673           1, kernel_convert->mutable_operand(0)));
674       TF_RETURN_IF_ERROR(
675           conv->parent()->RemoveInstructionAndUnusedOperands(kernel_convert));
676     }
677 
678     if (conv->batch_group_count() > 1) {
679       conv = ConvertBatchGroupedToFeatureGroupedConvolution(conv);
680     }
681 
682     return CreateGpuConv(kCudnnConvForwardCallTarget, conv->shape(),
683                          conv->mutable_operand(0), conv->mutable_operand(1),
684                          conv->window(), conv->convolution_dimension_numbers(),
685                          conv->feature_group_count(), conv->metadata());
686   }
687 
688   return nullptr;
689 }
690 
691 // Tries to rewrite a single convolution into a call to cudnn/miopen.
692 StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
693   CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
694 
695   TF_ASSIGN_OR_RETURN(HloInstruction * custom_call,
696                       CreateCustomCallHelper(conv));
697   if (custom_call == nullptr) {
698     return false;
699   }
700 
701   TF_RETURN_IF_ERROR(
702       custom_call->set_backend_config(GetDefaultBackendConfig()));
703 
704   VLOG(1) << "Replacing convolution " << conv->ToString() << " with "
705           << custom_call->ToString();
706 
707   // The CustomCall returns a tuple (conv_result, scratch_memory).  Extract
708   // out the conv result and replace `conv` with it.
709   TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
710       conv,
711       HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0)));
712   return true;
713 }
714 
715 // Rewrites the convolutions in the given computation into calls to
716 // cudnn/miopen.
717 // Returns true if it made any changes.
718 StatusOr<bool> RunOnComputation(HloComputation* computation) {
719   std::vector<HloInstruction*> convs;
720   for (auto* hlo : computation->instructions()) {
721     if (hlo->opcode() == HloOpcode::kConvolution) {
722       convs.push_back(hlo);
723     }
724   }
725 
726   bool changed = false;
727   for (HloInstruction* conv : convs) {
728     TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(conv));
729     changed |= result;
730   }
731   return changed;
732 }
733 }  // namespace
734 
735 StatusOr<bool> GpuConvRewriter::Run(HloModule* module) {
736   XLA_VLOG_LINES(2, "GpuConvRewriter::Run(), before:\n" + module->ToString());
737   bool changed = false;
738   for (HloComputation* computation : module->MakeNonfusionComputations()) {
739     TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
740     changed |= result;
741   }
742   XLA_VLOG_LINES(2, "GpuConvRewriter::Run(), after:\n" + module->ToString());
743   return changed;
744 }
745 
746 }  // namespace gpu
747 }  // namespace xla
748