/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_RUNNER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_RUNNER_H_

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/dnn.h"

namespace xla {
namespace gpu {

struct RunConvOptions {
  // Nullable output-parameter pointer for profiling results.
  se::dnn::ProfileResult* profile_result = nullptr;

  // Use this algorithm, instead of the one from the instruction.
  absl::optional<se::dnn::AlgorithmDesc> algo_override;

  // Use this scratch_bytes size, instead of the one from the instruction.
  absl::optional<size_t> scratch_size_override;
};

// Structure to describe static properties of a GPU convolution.
struct GpuConvConfig {
  // Fields related to cuDNN's fused convolution are in the FusionConfig &
  // FusionParams structures. The result is thus defined as:
  //   activation(conv_result_scale * conv(x, w) +
  //              side_input_scale * side_input + broadcast(bias))
  //
  // The most common fused conv is conv forward + relu/identity, for example.
  //
  // bias_buf is a single-dimensional array, with the length equal to the
  // number of output features. It'll be broadcast to the output shape in
  // order to be added to the final results.
  //
  // side_input_buf, if valid, must have the same shape as the output buffer.
  struct FusionConfig {
    se::dnn::ActivationMode mode;
    double side_input_scale;
  };

  PrimitiveType input_type;
  PrimitiveType output_type;
  CudnnConvKind kind;
  se::dnn::AlgorithmConfig algorithm;
  double conv_result_scale;

  se::dnn::BatchDescriptor input_descriptor;
  se::dnn::FilterDescriptor filter_descriptor;
  se::dnn::BatchDescriptor output_descriptor;
  se::dnn::ConvolutionDescriptor conv_desc;

  Shape input_shape;
  Shape filter_shape;
  Shape output_shape;
  absl::optional<FusionConfig> fusion;
};

// Implementation struct exposed for debugging and log analysis.
struct GpuConvParams {
  GpuConvConfig config;
  struct FusionParams {
    se::DeviceMemoryBase bias_buf;
    se::DeviceMemoryBase side_input_buf;  // nullable
  };

  se::DeviceMemoryBase input_buf;
  se::DeviceMemoryBase filter_buf;
  se::DeviceMemoryBase output_buf;

  absl::optional<FusionParams> fusion;
};

// This file contains low-level routines for running cudnn convolutions.

// Calls into cudnn to run the specified convolution.
//
// We provide one overload which takes a scratch buffer, and another which
// takes an allocator that is responsible for allocating the scratch space. In
// theory the second one shouldn't be necessary -- users of this function could
// just ask cudnn how much scratch space it needs for a particular convolution.
// But in practice, StreamExecutor does not expose such an API, and in the name
// of parsimony, perhaps it's better not to add it. Instead, the first time you
// call a convolution, you should call the version that takes a scratch
// allocator and take note of how much memory is used. The next time you call
// the same conv, you can provide an explicitly preallocated scratch buffer of
// that size, if you like.
Status RunGpuConv(const GpuConvConfig& conv_config,
                  absl::Span<se::DeviceMemoryBase> operand_buffers,
                  se::DeviceMemoryBase result_buffer,
                  se::DeviceMemoryBase scratch_buf, se::Stream* stream,
                  RunConvOptions = {});

Status RunGpuConv(const GpuConvConfig& conv_config,
                  absl::Span<se::DeviceMemoryBase> operand_buffers,
                  se::DeviceMemoryBase result_buffer,
                  se::ScratchAllocator* scratch_allocator, se::Stream* stream,
                  RunConvOptions = {});
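
// A rough usage sketch of the allocate-then-preallocate pattern described
// above. The names MyScratchAllocator, scratch_ptr, and scratch_bytes are
// hypothetical stand-ins (any concrete se::ScratchAllocator subclass and a
// caller-owned buffer); they are not declared in this header.
//
//   // First run of this conv: let the allocator size the scratch space and
//   // note how much was handed out.
//   MyScratchAllocator scratch_allocator(/*...*/);
//   TF_RETURN_IF_ERROR(RunGpuConv(config, operand_buffers, result_buffer,
//                                 &scratch_allocator, stream));
//
//   // Later runs of the same conv: reuse a preallocated buffer of the size
//   // observed on the first run.
//   se::DeviceMemoryBase scratch_buf(scratch_ptr, scratch_bytes);
//   TF_RETURN_IF_ERROR(RunGpuConv(config, operand_buffers, result_buffer,
//                                 scratch_buf, stream));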

// Struct to describe properties of a convolution without being tied to
// specific IR. Will be used to help build convolution thunks from either XLA
// HLO or LHLO GPU dialect in MLIR.
struct GpuConvDescriptor {
  CudnnConvKind kind;
  CudnnConvBackendConfig backend_config;
  Shape operand0_shape;
  Shape operand1_shape;
  Shape result_shape;
  size_t scratch_size;
  Window window;
  ConvolutionDimensionNumbers dnums;
  int64 feature_group_count;
};

// Returns the convolution configuration given an XLA HLO instruction.
StatusOr<GpuConvConfig> GetGpuConvConfig(
    const HloCustomCallInstruction* cudnn_call);

// Returns the convolution configuration given a convolution descriptor `desc`
// and a string representation of the convolution instruction `inst_as_string`
// (for error reporting).
StatusOr<GpuConvConfig> GetGpuConvConfig(const GpuConvDescriptor& desc,
                                         absl::string_view inst_as_string);

// Implementation details exposed for debugging and log analysis.
StatusOr<GpuConvParams> GetGpuConvParams(
    const GpuConvConfig& conv_config,
    absl::Span<se::DeviceMemoryBase> operand_buffers,
    se::DeviceMemoryBase result_buffer);

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_RUNNER_H_