/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_SIMPLIFY_PADDING_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_SIMPLIFY_PADDING_H_

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"

22 namespace xla::gpu {
23 
24 // Simplifies or eliminates padding introduced by CudnnPadForConvolutions and
25 // CudnnVectorizeConvolutions.
26 //
27 // CudnnVectorizeConvolutions will generate code that does the following.
28 //  - pad input and output features to a multiple of 32 (or 4),
29 //  - reshape input from [N,C,H,W] to [N,C/32,H,W,32] and reshape kernel from
30 //    [I,O,H,W] to [I/32,32,O,H,W],
31 //  - run the conv,
32 //  - reshape output from [N,C/32,H,W,32] to [N,C,H,W], and finally
33 //  - slice off the padding on the C channel.
34 //
35 // But if this is followed by another convolution (very common), then the slice
36 // is immediately followed by another pad. This may be redundant; we know that
37 // the trailing channels sliced off from the first conv are 0.
38 //
39 // Ideally we can eliminate the whole reshape+slice+pad+reshape sequence between
40 // the two convolutions.
41 //
42 // Specifically, this pass tries to merge the slice at the end of the sequence
43 // above into the pad from the next convolution (when we can prove that the
44 // sliced-off elements are all 0). We then rely on algsimp to remove the pad if
45 // it's a nop and then to merge and eliminate the remaining reshapes.
46 class CudnnSimplifyPadding : public HloModulePass {
47  public:
48   CudnnSimplifyPadding() = default;
49 
name()50   absl::string_view name() const override { return "cudnn_simplify_padding"; }
51 
52   using HloPassInterface::Run;
53   StatusOr<bool> Run(
54       HloModule* module,
55       const absl::flat_hash_set<absl::string_view>& execution_threads) override;
56 };
57 
58 }  // namespace xla::gpu
59 
60 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_SIMPLIFY_PADDING_H_
61