/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {
namespace gpu {
namespace {

using se::DeviceMemory;
using se::DeviceMemoryBase;
using se::Stream;
using se::dnn::AlgorithmConfig;
using se::dnn::BatchDescriptor;
using se::dnn::ConvolutionDescriptor;
using se::dnn::DataLayout;
using se::dnn::DimIndex;
using se::dnn::FilterDescriptor;
using se::dnn::FilterLayout;
using se::dnn::ProfileResult;

// A StreamExecutor ScratchAllocator that wraps a single XLA allocation,
// returning it (in its entirety) the first time AllocateBytes() is called.
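//
// Minimal usage sketch (the surrounding buffer names are illustrative, not
// part of this file): wrap the scratch buffer XLA already allocated and pass
// the allocator to one of the *WithAlgorithm calls below.
//
//   se::DeviceMemoryBase scratch = /* buffer assigned by XLA */;
//   ScratchBufAllocator scratch_allocator(scratch);
//   // ... pass &scratch_allocator as the se::ScratchAllocator* argument.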
class ScratchBufAllocator : public se::ScratchAllocator {
 public:
  explicit ScratchBufAllocator(se::DeviceMemoryBase scratch)
      : scratch_(scratch) {}

  ~ScratchBufAllocator() override = default;

  int64 GetMemoryLimitInBytes() override { return scratch_.size(); }

  se::port::StatusOr<DeviceMemory<uint8>> AllocateBytes(
      int64_t byte_size) override {
    if (allocated_) {
      return se::port::InternalError(
          "Can't allocate twice from a ScratchBufAllocator.");
    }
    if (byte_size > scratch_.size()) {
      return se::port::InternalError(absl::StrCat(
          "Can't allocate ", byte_size,
          " bytes from a ScratchBufAllocator of size ", scratch_.size()));
    }

    allocated_ = true;
    return se::DeviceMemory<uint8>(scratch_);
  }

 private:
  se::DeviceMemoryBase scratch_;
  bool allocated_ = false;
};

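// Runs a plain (unfused) forward convolution by forwarding the descriptors and
// buffers to StreamExecutor's ConvolveWithAlgorithm.  Scaled convolutions
// (conv_result_scale != 1) are rejected because this path does not support
// them.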
template <typename ElementType, typename OutputType>
Status RunGpuConvForward(GpuConvParams params,
                         se::ScratchAllocator* scratch_allocator,
                         se::Stream* stream, RunConvOptions options,
                         DeviceMemory<ElementType> input_buf,
                         DeviceMemory<ElementType> filter_buf,
                         DeviceMemory<OutputType> output_buf,
                         AlgorithmConfig algorithm) {
  if (params.config.conv_result_scale != 1) {
    return InternalError(
        "StreamExecutor doesn't support scaled convolution: %lf.",
        params.config.conv_result_scale);
  }
  return stream->ConvolveWithAlgorithm(
      params.config.input_descriptor, input_buf,
      params.config.filter_descriptor, filter_buf, params.config.conv_desc,
      params.config.output_descriptor, &output_buf, scratch_allocator,
      algorithm, options.profile_result);
}

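// Runs a forward convolution fused with bias addition, an optional scaled
// side input, and an activation function, via StreamExecutor's
// FusedConvolveWithAlgorithm.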
template <typename ElementType, typename BiasType, typename OutputType>
Status RunGpuConvForwardActivation(GpuConvParams params,
                                   se::ScratchAllocator* scratch_allocator,
                                   se::Stream* stream, RunConvOptions options,
                                   DeviceMemory<ElementType> input_buf,
                                   DeviceMemory<ElementType> filter_buf,
                                   DeviceMemory<OutputType> output_buf,
                                   AlgorithmConfig algorithm) {
  BatchDescriptor bias_desc;
  bias_desc.set_count(1)
      .set_height(1)
      .set_width(1)
      .set_feature_map_count(
          params.config.output_descriptor.feature_map_count())
      .set_layout([&] {
        // Normalize NCHW_VECT_C to NCHW for the layout of `bias`.  The two are
        // effectively the same here (because `bias` only has one meaningful
        // dimension), but cudnn does not accept NCHW_VECT_C for `bias`.
        DataLayout layout = params.config.output_descriptor.layout();
        switch (layout) {
          case DataLayout::kBatchDepthYX4:
          case DataLayout::kBatchDepthYX32:
            return DataLayout::kBatchDepthYX;
          default:
            return layout;
        }
      }());

  se::DeviceMemory<OutputType> side_input(params.fusion->side_input_buf);
  // If there is no side input, use output as the side input.
  if (side_input.is_null()) {
    if (params.config.fusion->side_input_scale != 0) {
      return InternalError(
          "Side input scale is not 0, yet no side input buffer is "
          "provided");
    }
    // Since side-input scale is 0, the values in the side input don't
    // matter.  The simplest thing to do would be to pass in a null buffer
    // for the side input, but cudnn doesn't allow this.  cudnn does promise
    // that if side-input-scale is 0 the side input won't be read, so we
    // just pass in the output buffer, since it's handy and has the correct
    // size.
    side_input = output_buf;
  }

  return stream->FusedConvolveWithAlgorithm(
      params.config.input_descriptor, input_buf,
      params.config.conv_result_scale, params.config.filter_descriptor,
      filter_buf, params.config.conv_desc, side_input,
      params.config.fusion->side_input_scale, bias_desc,
      DeviceMemory<BiasType>(params.fusion->bias_buf),
      params.config.fusion->mode, params.config.output_descriptor, &output_buf,
      scratch_allocator, algorithm, options.profile_result);
}

// StreamExecutor supports various data types via overloading, and that support
// is added on demand.  To avoid calling into overloads that don't exist, we
// use enable_if so that only valid combinations are instantiated.
// TODO(timshen): Ideally, to avoid this complication in the runner, we could
// turn the StreamExecutor overloads into template functions and return runtime
// errors for unsupported data types.
// This is the specialization for double, float, and half types.  All kinds of
// convolutions are supported here.
template <typename ElementType, typename BiasType, typename OutputType,
          typename std::enable_if<
              !std::is_integral<ElementType>::value>::type* = nullptr>
Status RunGpuConvInternalImpl(GpuConvParams params,
                              se::ScratchAllocator* scratch_allocator,
                              se::Stream* stream, RunConvOptions options,
                              DeviceMemory<ElementType> input_buf,
                              DeviceMemory<ElementType> filter_buf,
                              DeviceMemory<OutputType> output_buf,
                              AlgorithmConfig algorithm) {
  switch (params.config.kind) {
    case CudnnConvKind::kForward:
      return RunGpuConvForward(params, scratch_allocator, stream, options,
                               input_buf, filter_buf, output_buf, algorithm);
    case CudnnConvKind::kBackwardInput:
      if (params.config.conv_result_scale != 1) {
        return InternalError(
            "StreamExecutor doesn't support scaled convolution: %lf.",
            params.config.conv_result_scale);
      }
      return stream->ConvolveBackwardDataWithAlgorithm(
          params.config.filter_descriptor, filter_buf,
          params.config.output_descriptor, output_buf, params.config.conv_desc,
          params.config.input_descriptor, &input_buf, scratch_allocator,
          algorithm, options.profile_result);
    case CudnnConvKind::kBackwardFilter:
      if (params.config.conv_result_scale != 1) {
        return InternalError(
            "StreamExecutor doesn't support scaled convolution: %lf.",
            params.config.conv_result_scale);
      }
      return stream->ConvolveBackwardFilterWithAlgorithm(
          params.config.input_descriptor, input_buf,
          params.config.output_descriptor, output_buf, params.config.conv_desc,
          params.config.filter_descriptor, &filter_buf, scratch_allocator,
          algorithm, options.profile_result);
    case CudnnConvKind::kForwardActivation: {
      return RunGpuConvForwardActivation<ElementType, BiasType, OutputType>(
          params, scratch_allocator, stream, options, input_buf, filter_buf,
          output_buf, algorithm);
    }
  }
  return Status::OK();
}

// Specialization for integer element types.  Only the two forward convolution
// kinds (kForward and kForwardActivation) are allowed.
template <typename ElementType, typename BiasType, typename OutputType,
          typename std::enable_if<std::is_integral<ElementType>::value>::type* =
              nullptr>
Status RunGpuConvInternalImpl(GpuConvParams params,
                              se::ScratchAllocator* scratch_allocator,
                              se::Stream* stream, RunConvOptions options,
                              DeviceMemory<ElementType> input_buf,
                              DeviceMemory<ElementType> filter_buf,
                              DeviceMemory<OutputType> output_buf,
                              AlgorithmConfig algorithm) {
  switch (params.config.kind) {
    case CudnnConvKind::kForward:
      return RunGpuConvForward(params, scratch_allocator, stream, options,
                               input_buf, filter_buf, output_buf, algorithm);
    case CudnnConvKind::kForwardActivation:
      return RunGpuConvForwardActivation<ElementType, BiasType, OutputType>(
          params, scratch_allocator, stream, options, input_buf, filter_buf,
          output_buf, algorithm);
    default:
      return InternalError(
          "Only convolution kinds kForward and kForwardActivation are "
          "supported for integer types");
  }
  return Status::OK();
}

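// Wraps the raw device buffers in typed DeviceMemory handles, applies any
// algorithm / scratch-size overrides from RunConvOptions, dispatches to the
// appropriate RunGpuConvInternalImpl specialization, and finally checks that
// the stream is still in a valid state.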
template <typename ElementType, typename BiasType, typename OutputType>
Status RunGpuConvImpl(const GpuConvParams& params,
                      se::ScratchAllocator* scratch_allocator,
                      se::Stream* stream, RunConvOptions options) {
  auto input_buf = se::DeviceMemory<ElementType>(params.input_buf);
  auto filter_buf = se::DeviceMemory<ElementType>(params.filter_buf);
  auto output_buf = se::DeviceMemory<OutputType>(params.output_buf);
  AlgorithmConfig algorithm = params.config.algorithm;

  if (options.algo_override.has_value()) {
    algorithm = AlgorithmConfig(*options.algo_override);
    if (options.scratch_size_override.has_value()) {
      algorithm.set_scratch_size(*options.scratch_size_override);
    }
  }

  Status run_status = RunGpuConvInternalImpl<ElementType, BiasType, OutputType>(
      params, scratch_allocator, stream, options, input_buf, filter_buf,
      output_buf, algorithm);

  if (run_status != Status::OK()) {
    return run_status;
  }

  if (!stream->ok()) {
    return InternalError(
        "Unable to launch convolution with type %s and algorithm (%d, %s)",
        CudnnConvKindToString(params.config.kind),
        algorithm.algorithm()->algo_id(),
        algorithm.algorithm_no_scratch().has_value()
            ? absl::StrCat(algorithm.algorithm_no_scratch()->algo_id())
            : "none");
  }
  return Status::OK();
}

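// Returns the channel vectorization width implied by a vectorized
// (VECT_C-style) cuDNN layout, or 1 for non-vectorized layouts.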
int64 GetVectCSize(DataLayout layout) {
  switch (layout) {
    case DataLayout::kBatchDepthYX4:
      return 4;
    case DataLayout::kBatchDepthYX32:
      return 32;
    default:
      return 1;
  }
}

int64 GetVectCSize(FilterLayout layout) {
  switch (layout) {
    case FilterLayout::kOutputInputYX4:
      return 4;
    case FilterLayout::kOutputInputYX32:
      return 32;
    default:
      return 1;
  }
}

}  // anonymous namespace

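// Builds a GpuConvConfig from a GpuConvDescriptor: picks the algorithm,
// assigns the logical input/filter/output shapes according to the convolution
// kind, and fills in the StreamExecutor batch/filter/convolution descriptors.
// `inst_as_string` is only used to make CHECK-failure messages more
// informative.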
StatusOr<GpuConvConfig> GetGpuConvConfig(
    const GpuConvDescriptor& desc, const absl::string_view inst_as_string) {
  GpuConvConfig config;

  const Shape& operand0_shape = desc.operand0_shape;
  const Shape& operand1_shape = desc.operand1_shape;
  const Shape& result_shape = desc.result_shape;
  const CudnnConvBackendConfig& backend_config = desc.backend_config;

  config.input_type = operand0_shape.element_type();
  config.output_type = result_shape.element_type();
  config.kind = desc.kind;

  // The third argument is the scratch size recorded by the conv algorithm
  // picker; the scratch operand is added to the shape of the conv instruction
  // in GpuConvAlgorithmPicker::RunOnInstruction().
  config.algorithm = se::dnn::AlgorithmConfig(
      se::dnn::AlgorithmDesc(backend_config.algorithm(),
                             backend_config.tensor_ops_enabled()),
      desc.scratch_size);
  config.conv_result_scale = backend_config.conv_result_scale();

  switch (config.kind) {
    case CudnnConvKind::kForward:
    case CudnnConvKind::kForwardActivation:
      config.input_shape = operand0_shape;
      config.filter_shape = operand1_shape;
      config.output_shape = result_shape;
      break;
    case CudnnConvKind::kBackwardInput:
      config.input_shape = result_shape;
      config.filter_shape = operand1_shape;
      config.output_shape = operand0_shape;
      break;
    case CudnnConvKind::kBackwardFilter:
      config.input_shape = operand0_shape;
      config.filter_shape = result_shape;
      config.output_shape = operand1_shape;
      break;
    default:
      return InternalError("Unknown convolution kind");
  }

  if (config.kind == CudnnConvKind::kForwardActivation) {
    config.fusion.emplace();
    GpuConvConfig::FusionConfig& fusion = *config.fusion;
    if (!se::dnn::ActivationMode_IsValid(backend_config.activation_mode())) {
      return InternalError("Bad activation mode: %s",
                           backend_config.ShortDebugString());
    }
    fusion.mode =
        static_cast<se::dnn::ActivationMode>(backend_config.activation_mode());
    fusion.side_input_scale = backend_config.side_input_scale();
  }

  const Window& window = desc.window;
  const ConvolutionDimensionNumbers& dnums = desc.dnums;

  VLOG(3) << "Convolution Algorithm: "
          << config.algorithm.algorithm()->algo_id();
  VLOG(3) << "tensor_ops_enabled: "
          << config.algorithm.algorithm()->tensor_ops_enabled();
  VLOG(3) << "Convolution kind: " << CudnnConvKindToString(config.kind);
  VLOG(3) << "input shape: "
          << ShapeUtil::HumanStringWithLayout(config.input_shape);
  VLOG(3) << "filter shape: "
          << ShapeUtil::HumanStringWithLayout(config.filter_shape);
  VLOG(3) << "Output shape: "
          << ShapeUtil::HumanStringWithLayout(config.output_shape);
  VLOG(3) << "Window: { " << window.ShortDebugString() << " }";
  VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }";

  const int num_dimensions = window.dimensions_size();
  CHECK_LE(num_dimensions, 3) << inst_as_string;

  // cuDNN does not support 1D convolutions. We therefore express 1D
  // convolutions as 2D convolutions where the first spatial dimension is 1.
  // This matches the behavior of TF (see definition of conv1d in
  // tensorflow/python/ops/nn_ops.py).
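  // For example, a conv over an input of shape [N, C, W] is handled as a 2D
  // conv over [N, C, 1, W]: the extra leading spatial dimension gets size 1,
  // zero padding, and stride 1 in the singleton-dimension loop at the end of
  // this function.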
  const int effective_num_dimensions = std::max(2, num_dimensions);

  // If one dimension is reversed, we need to have all dimensions reversed (so
  // we're doing convolution not cross correlation).
  const bool dims_reversed =
      window.dimensions_size() > 0 && window.dimensions()[0].window_reversal();

  CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size())
      << inst_as_string;
  CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size())
      << inst_as_string;
  CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size())
      << inst_as_string;
  for (const WindowDimension& dim : window.dimensions()) {
    CHECK_EQ(dims_reversed, dim.window_reversal()) << inst_as_string;
    CHECK_EQ(dim.padding_low(), dim.padding_high()) << inst_as_string;
    CHECK_EQ(dim.base_dilation(), 1)
        << "cudnn does not support base dilation; it "
           "must be made explicit with a kPad: "
        << inst_as_string;
  }

  // cuDNN's convolution APIs support the BDYX layout for activations/output
  // and the OIYX layout for weights.
  DataLayout input_dl;
  FilterLayout filter_dl;
  DataLayout output_dl;

  const Shape& input_shape = config.input_shape;
  const Shape& filter_shape = config.filter_shape;
  const Shape& output_shape = config.output_shape;

  TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl),
                      XlaConvShapesToStreamExecutorLayouts(
                          dnums, input_shape, filter_shape, output_shape));

  BatchDescriptor& input_descriptor = config.input_descriptor;
  input_descriptor = BatchDescriptor(effective_num_dimensions);
  input_descriptor.set_layout(input_dl)
      .set_feature_map_count(
          GetVectCSize(input_dl) *
          input_shape.dimensions(dnums.input_feature_dimension()))
      .set_count(input_shape.dimensions(dnums.input_batch_dimension()));
  for (int dim = 0; dim < num_dimensions; ++dim) {
    // Note that the dimensions are reversed. The same holds below.
    input_descriptor.set_spatial_dim(
        static_cast<DimIndex>(effective_num_dimensions - dim - 1),
        input_shape.dimensions(dnums.input_spatial_dimensions(dim)));
  }

  FilterDescriptor& filter_descriptor = config.filter_descriptor;
  filter_descriptor = FilterDescriptor(effective_num_dimensions);
  filter_descriptor.set_layout(filter_dl)
      .set_input_feature_map_count(
          GetVectCSize(filter_dl) *
          filter_shape.dimensions(dnums.kernel_input_feature_dimension()))
      .set_output_feature_map_count(
          filter_shape.dimensions(dnums.kernel_output_feature_dimension()));
  for (int dim = 0; dim < num_dimensions; ++dim) {
    filter_descriptor.set_spatial_dim(
        static_cast<DimIndex>(effective_num_dimensions - dim - 1),
        filter_shape.dimensions(dnums.kernel_spatial_dimensions(dim)));
  }

  config.conv_desc = ConvolutionDescriptor(effective_num_dimensions);
  config.conv_desc.set_group_count(desc.feature_group_count);
  config.conv_desc.set_convolution_not_crosscorr(dims_reversed);
  for (int dim = 0; dim < num_dimensions; ++dim) {
    config.conv_desc
        .set_zero_padding(
            static_cast<DimIndex>(effective_num_dimensions - dim - 1),
            window.dimensions(dim).padding_low())
        .set_filter_stride(
            static_cast<DimIndex>(effective_num_dimensions - dim - 1),
            window.dimensions(dim).stride())
        .set_dilation_rate(
            static_cast<DimIndex>(effective_num_dimensions - dim - 1),
            window.dimensions(dim).window_dilation());
  }

  BatchDescriptor& output_descriptor = config.output_descriptor;
  output_descriptor = BatchDescriptor(effective_num_dimensions);
  output_descriptor.set_layout(output_dl)
      .set_feature_map_count(
          GetVectCSize(output_dl) *
          output_shape.dimensions(dnums.output_feature_dimension()))
      .set_count(output_shape.dimensions(dnums.output_batch_dimension()));
  for (int dim = 0; dim < num_dimensions; ++dim) {
    output_descriptor.set_spatial_dim(
        static_cast<DimIndex>(effective_num_dimensions - dim - 1),
        output_shape.dimensions(dnums.output_spatial_dimensions(dim)));
  }

  // Add a singleton dimension in the 1D convolution case.
  for (int dim = 0; dim < effective_num_dimensions - num_dimensions; dim++) {
    input_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1);
    output_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1);
    filter_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1);
    config.conv_desc.set_zero_padding(static_cast<DimIndex>(dim), 0)
        .set_filter_stride(static_cast<DimIndex>(dim), 1);
  }

  return config;
}

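// Builds a GpuConvDescriptor from a cuDNN convolution custom call and forwards
// it to the overload above.  The custom call returns a tuple of
// (conv result, scratch buffer), hence tuple_shapes(0) / tuple_shapes(1).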
StatusOr<GpuConvConfig> GetGpuConvConfig(
    const HloCustomCallInstruction* cudnn_call) {
  GpuConvDescriptor descriptor;

  TF_ASSIGN_OR_RETURN(descriptor.kind, GetCudnnConvKind(cudnn_call));
  TF_ASSIGN_OR_RETURN(descriptor.backend_config,
                      cudnn_call->backend_config<CudnnConvBackendConfig>());
  descriptor.operand0_shape = cudnn_call->operand(0)->shape();
  descriptor.operand1_shape = cudnn_call->operand(1)->shape();
  descriptor.result_shape = cudnn_call->shape().tuple_shapes(0);
  descriptor.scratch_size = cudnn_call->shape().tuple_shapes(1).dimensions(0);
  descriptor.window = cudnn_call->window();
  descriptor.dnums = cudnn_call->convolution_dimension_numbers();
  descriptor.feature_group_count = cudnn_call->feature_group_count();
  return GetGpuConvConfig(descriptor, cudnn_call->ToString());
}

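// Binds the operand and result buffers to the logical roles (input, filter,
// output) implied by the convolution kind; for backward convolutions the
// result buffer plays the role of the input or the filter, respectively.  For
// fused convolutions, operand 2 is the bias and operand 3 (if present) is the
// side input.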
StatusOr<GpuConvParams> GetGpuConvParams(
    const GpuConvConfig& config,
    absl::Span<se::DeviceMemoryBase> operand_buffers,
    se::DeviceMemoryBase result_buffer) {
  GpuConvParams params;
  params.config = config;

  switch (config.kind) {
    case CudnnConvKind::kForward:
    case CudnnConvKind::kForwardActivation:
      params.input_buf = operand_buffers[0];
      params.filter_buf = operand_buffers[1];
      params.output_buf = result_buffer;
      break;
    case CudnnConvKind::kBackwardInput:
      params.input_buf = result_buffer;
      params.filter_buf = operand_buffers[1];
      params.output_buf = operand_buffers[0];
      break;
    case CudnnConvKind::kBackwardFilter:
      params.input_buf = operand_buffers[0];
      params.filter_buf = result_buffer;
      params.output_buf = operand_buffers[1];
      break;
  }

  if (config.kind == CudnnConvKind::kForwardActivation) {
    params.fusion.emplace();
    GpuConvParams::FusionParams& fusion = *params.fusion;
    fusion.bias_buf = operand_buffers[2];
    if (operand_buffers.size() >= 4) {
      fusion.side_input_buf = operand_buffers[3];
    }
  }

  return params;
}

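// Convenience overload that wraps a caller-provided scratch buffer in a
// ScratchBufAllocator and forwards to the allocator-based overload below.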
Status RunGpuConv(const gpu::GpuConvConfig& config,
                  absl::Span<se::DeviceMemoryBase> operand_buffers,
                  se::DeviceMemoryBase result_buffer,
                  se::DeviceMemoryBase scratch_buf, se::Stream* stream,
                  RunConvOptions options) {
  ScratchBufAllocator scratch_allocator(scratch_buf);
  return RunGpuConv(config, operand_buffers, result_buffer, &scratch_allocator,
                    stream, options);
}

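// Binds the buffers to GpuConvParams and dispatches on the input element type
// (and, for S8, on the output type) to the matching RunGpuConvImpl
// instantiation.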
Status RunGpuConv(const gpu::GpuConvConfig& config,
                  absl::Span<se::DeviceMemoryBase> operand_buffers,
                  se::DeviceMemoryBase result_buffer,
                  se::ScratchAllocator* scratch_allocator, se::Stream* stream,
                  RunConvOptions options) {
  TF_ASSIGN_OR_RETURN(GpuConvParams params,
                      GetGpuConvParams(config, operand_buffers, result_buffer));

  PrimitiveType input_primitive_type = config.input_type;
  switch (input_primitive_type) {
    case F16:
      return RunGpuConvImpl<Eigen::half, Eigen::half, Eigen::half>(
          params, scratch_allocator, stream, options);
    case BF16:
      return RunGpuConvImpl<Eigen::bfloat16, Eigen::bfloat16, Eigen::bfloat16>(
          params, scratch_allocator, stream, options);
    case F32:
      return RunGpuConvImpl<float, float, float>(params, scratch_allocator,
                                                 stream, options);
    case F64:
      return RunGpuConvImpl<double, double, double>(params, scratch_allocator,
                                                    stream, options);
    case S8: {
      PrimitiveType output_primitive_type = config.output_type;
      switch (output_primitive_type) {
        case F32:
          return RunGpuConvImpl<int8, float, float>(params, scratch_allocator,
                                                    stream, options);
        case S8:
          return RunGpuConvImpl<int8, float, int8>(params, scratch_allocator,
                                                   stream, options);
        default:
          return Unimplemented("Unimplemented convolution");
      }
    }
    default:
      return Unimplemented("Unimplemented convolution");
  }
}

}  // namespace gpu
}  // namespace xla