• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_RUNNER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_RUNNER_H_
18 
19 #include "absl/types/optional.h"
20 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
21 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
23 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
24 #include "tensorflow/compiler/xla/status.h"
25 #include "tensorflow/compiler/xla/statusor.h"
26 #include "tensorflow/compiler/xla/types.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
29 #include "tensorflow/stream_executor/dnn.h"
30 
31 namespace xla {
32 namespace gpu {
33 
34 struct RunConvOptions {
35   // Nullable output-parameter pointer for profiling results.
36   se::dnn::ProfileResult* profile_result = nullptr;
37 
38   // Use this algorithm, instead of the one from the instruction.
39   absl::optional<se::dnn::AlgorithmDesc> algo_override;
40 
41   // Use this scratch_bytes size, instead of the one from the instruction.
42   absl::optional<size_t> scratch_size_override;
43 };
44 
45 // Structure to describe static properties of a GPU convolution.
46 struct GpuConvConfig {
47   // Field related to cuDNN's fused convolution are in FusionConfig &
48   // FusionParams structures. The result thus is defined as:
49   //   activation(conv_result_scale * conv(x, w) +
50   //       side_input_scale * side_input + broadcast(bias))
51   //
52   // The most common fused conv is conv forward + relu/identity, for example.
53   //
54   // bias_buf is a single-dimensional array, with the length equal to the number
55   // of output features. It'll be broadcasted to the output shape in order to be
56   // added to the final results.
57   //
58   // side_input_buf, if valid, must have the same shape as the output buffer.
59   struct FusionConfig {
60     se::dnn::ActivationMode mode;
61     double side_input_scale;
62   };
63 
64   PrimitiveType input_type;
65   PrimitiveType output_type;
66   CudnnConvKind kind;
67   se::dnn::AlgorithmConfig algorithm;
68   double conv_result_scale;
69 
70   se::dnn::BatchDescriptor input_descriptor;
71   se::dnn::FilterDescriptor filter_descriptor;
72   se::dnn::BatchDescriptor output_descriptor;
73   se::dnn::ConvolutionDescriptor conv_desc;
74 
75   Shape input_shape;
76   Shape filter_shape;
77   Shape output_shape;
78   absl::optional<FusionConfig> fusion;
79 };
80 
81 // Implementation struct exposed for debugging and log analysis.
82 struct GpuConvParams {
83   GpuConvConfig config;
84   struct FusionParams {
85     se::DeviceMemoryBase bias_buf;
86     se::DeviceMemoryBase side_input_buf;  // nullable
87   };
88 
89   se::DeviceMemoryBase input_buf;
90   se::DeviceMemoryBase filter_buf;
91   se::DeviceMemoryBase output_buf;
92 
93   absl::optional<FusionParams> fusion;
94 };
95 
96 // This file contains low-level routines for running cudnn convolutions.
97 
98 // Calls into cudnn to run the specified convolution.
99 //
100 // We provide one overload which takes a scratch buffer, and another which takes
101 // an allocator which is responsible for allocating the scratch space.  In
102 // theory the second one shouldn't be necessary -- users of this function could
103 // just ask cudnn how much scratch space it needs for a particular convolution.
104 // But in practice, StreamExecutor does not expose such an API, and in the name
105 // of parsimony, perhaps it's better not to add it.  Instead, the first time you
106 // call a convolution, you should call the version that takes a scratch
107 // allocator and take note of how much memory is used.  The next time you call
108 // the same conv, you can provide an explicitly preallocated scratch buffer of
109 // that size, if you like.
110 Status RunGpuConv(const GpuConvConfig& conv_config,
111                   absl::Span<se::DeviceMemoryBase> operand_buffers,
112                   se::DeviceMemoryBase result_buffer,
113                   se::DeviceMemoryBase scratch_buf, se::Stream* stream,
114                   RunConvOptions = {});
115 
116 Status RunGpuConv(const GpuConvConfig& conv_config,
117                   absl::Span<se::DeviceMemoryBase> operand_buffers,
118                   se::DeviceMemoryBase result_buffer,
119                   se::ScratchAllocator* scratch_allocator, se::Stream* stream,
120                   RunConvOptions = {});
121 
122 // Struct to describe properties of a convolution without being tied to specific
123 // IR. Will be used to help build Convolution thunks from either XLA HLO or
124 // LHLO GPU dialect in MLIR.
125 struct GpuConvDescriptor {
126   CudnnConvKind kind;
127   CudnnConvBackendConfig backend_config;
128   Shape operand0_shape;
129   Shape operand1_shape;
130   Shape result_shape;
131   size_t scratch_size;
132   Window window;
133   ConvolutionDimensionNumbers dnums;
134   int64 feature_group_count;
135 };
136 
137 // Returns the convolution configuration given a XLA HLO instruction.
138 StatusOr<GpuConvConfig> GetGpuConvConfig(
139     const HloCustomCallInstruction* cudnn_call);
140 
141 // Returns the convolution configuration given a convolution descriptor `desc`
142 // and a string representation of the convolution instruction `inst_as_string`
143 // (for error reporting).
144 StatusOr<GpuConvConfig> GetGpuConvConfig(const GpuConvDescriptor& desc,
145                                          absl::string_view inst_as_string);
146 
147 // Implementation details exposed for debugging and log analysis.
148 StatusOr<GpuConvParams> GetGpuConvParams(
149     const GpuConvConfig& conv_config,
150     absl::Span<se::DeviceMemoryBase> operand_buffers,
151     se::DeviceMemoryBase result_buffer);
152 
153 }  // namespace gpu
154 }  // namespace xla
155 
156 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_RUNNER_H_
157