/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {
namespace gpu {
namespace {

using se::DeviceMemory;
using se::DeviceMemoryBase;
using se::Stream;
using se::dnn::AlgorithmConfig;
using se::dnn::BatchDescriptor;
using se::dnn::ConvolutionDescriptor;
using se::dnn::DataLayout;
using se::dnn::DimIndex;
using se::dnn::FilterDescriptor;
using se::dnn::FilterLayout;
using se::dnn::ProfileResult;

// A StreamExecutor ScratchAllocator that wraps a single XLA allocation,
// returning it (in its entirety) the first time Allocate() is called.
class ScratchBufAllocator : public se::ScratchAllocator {
 public:
  explicit ScratchBufAllocator(se::DeviceMemoryBase scratch)
      : scratch_(scratch) {}

  ~ScratchBufAllocator() override = default;

  int64 GetMemoryLimitInBytes() override { return scratch_.size(); }

  se::port::StatusOr<DeviceMemory<uint8>> AllocateBytes(
      int64_t byte_size) override {
    if (allocated_) {
      return se::port::InternalError(
          "Can't allocate twice from a ScratchBufAllocator.");
    }
    if (byte_size > scratch_.size()) {
      return se::port::InternalError(absl::StrCat(
          "Can't allocate ", byte_size,
          " bytes from a ScratchBufAllocator of size ", scratch_.size()));
    }

    allocated_ = true;
    return se::DeviceMemory<uint8>(scratch_);
  }

 private:
  se::DeviceMemoryBase scratch_;
  bool allocated_ = false;
};

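// Runs a plain (unfused) forward convolution by forwarding the descriptors
// and device buffers from `params` to StreamExecutor's
// ConvolveWithAlgorithm. A conv_result_scale other than 1 is rejected here.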
template <typename ElementType, typename OutputType>
Status RunGpuConvForward(GpuConvParams params,
                         se::ScratchAllocator* scratch_allocator,
                         se::Stream* stream, RunConvOptions options,
                         DeviceMemory<ElementType> input_buf,
                         DeviceMemory<ElementType> filter_buf,
                         DeviceMemory<OutputType> output_buf,
                         AlgorithmConfig algorithm) {
  if (params.config.conv_result_scale != 1) {
    return InternalError(
        "StreamExecutor doesn't support scaled convolution: %lf.",
        params.config.conv_result_scale);
  }
  return stream->ConvolveWithAlgorithm(
      params.config.input_descriptor, input_buf,
      params.config.filter_descriptor, filter_buf, params.config.conv_desc,
      params.config.output_descriptor, &output_buf, scratch_allocator,
      algorithm, options.profile_result);
}

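// Runs a fused forward convolution (conv + scaled side input + bias +
// activation) via StreamExecutor's FusedConvolveWithAlgorithm. The bias is
// described by a BatchDescriptor whose only non-unit dimension is the
// output's feature-map count.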
template <typename ElementType, typename BiasType, typename OutputType>
Status RunGpuConvForwardActivation(GpuConvParams params,
                                   se::ScratchAllocator* scratch_allocator,
                                   se::Stream* stream, RunConvOptions options,
                                   DeviceMemory<ElementType> input_buf,
                                   DeviceMemory<ElementType> filter_buf,
                                   DeviceMemory<OutputType> output_buf,
                                   AlgorithmConfig algorithm) {
  BatchDescriptor bias_desc;
  bias_desc.set_count(1)
      .set_height(1)
      .set_width(1)
      .set_feature_map_count(
          params.config.output_descriptor.feature_map_count())
      .set_layout([&] {
        // Normalize NCHW_VECT_C to NCHW for the layout of `bias`. The two are
        // equivalent here because `bias` only has one meaningful dimension,
        // but cudnn does not accept NCHW_VECT_C for `bias`.
        DataLayout layout = params.config.output_descriptor.layout();
        switch (layout) {
          case DataLayout::kBatchDepthYX4:
          case DataLayout::kBatchDepthYX32:
            return DataLayout::kBatchDepthYX;
          default:
            return layout;
        }
      }());

  se::DeviceMemory<OutputType> side_input(params.fusion->side_input_buf);
  // If there is no side input, use the output as the side input.
  if (side_input.is_null()) {
    if (params.config.fusion->side_input_scale != 0) {
      return InternalError(
          "Side input scale is not 0, yet no side input buffer is "
          "provided");
    }
    // Since the side-input scale is 0, the values in the side input don't
    // matter. The simplest thing to do would be to pass in a null buffer
    // for the side input, but cudnn doesn't allow this. cudnn does promise
    // that if side-input-scale is 0 the side input won't be read, so we
    // just pass in the output buffer, since it's handy and has the correct
    // size.
    side_input = output_buf;
  }

  return stream->FusedConvolveWithAlgorithm(
      params.config.input_descriptor, input_buf,
      params.config.conv_result_scale, params.config.filter_descriptor,
      filter_buf, params.config.conv_desc, side_input,
      params.config.fusion->side_input_scale, bias_desc,
      DeviceMemory<BiasType>(params.fusion->bias_buf),
      params.config.fusion->mode, params.config.output_descriptor, &output_buf,
      scratch_allocator, algorithm, options.profile_result);
}

// StreamExecutor supports various data types via overloading, and the
// overloads are added on demand. To avoid calling into overloads that don't
// exist, we guard the calls with enable_if.
// TODO(timshen): Ideally, to avoid this complication in the runner, we could
// turn the StreamExecutor overloads into template functions and return
// runtime errors for unsupported data types.
// This is the specialization for double, float, and half types. All kinds of
// convolutions are supported here.
template <typename ElementType, typename BiasType, typename OutputType,
          typename std::enable_if<
              !std::is_integral<ElementType>::value>::type* = nullptr>
Status RunGpuConvInternalImpl(GpuConvParams params,
                              se::ScratchAllocator* scratch_allocator,
                              se::Stream* stream, RunConvOptions options,
                              DeviceMemory<ElementType> input_buf,
                              DeviceMemory<ElementType> filter_buf,
                              DeviceMemory<OutputType> output_buf,
                              AlgorithmConfig algorithm) {
  switch (params.config.kind) {
    case CudnnConvKind::kForward:
      return RunGpuConvForward(params, scratch_allocator, stream, options,
                               input_buf, filter_buf, output_buf, algorithm);
    case CudnnConvKind::kBackwardInput:
      if (params.config.conv_result_scale != 1) {
        return InternalError(
            "StreamExecutor doesn't support scaled convolution: %lf.",
            params.config.conv_result_scale);
      }
      return stream->ConvolveBackwardDataWithAlgorithm(
          params.config.filter_descriptor, filter_buf,
          params.config.output_descriptor, output_buf, params.config.conv_desc,
          params.config.input_descriptor, &input_buf, scratch_allocator,
          algorithm, options.profile_result);
    case CudnnConvKind::kBackwardFilter:
      if (params.config.conv_result_scale != 1) {
        return InternalError(
            "StreamExecutor doesn't support scaled convolution: %lf.",
            params.config.conv_result_scale);
      }
      return stream->ConvolveBackwardFilterWithAlgorithm(
          params.config.input_descriptor, input_buf,
          params.config.output_descriptor, output_buf, params.config.conv_desc,
          params.config.filter_descriptor, &filter_buf, scratch_allocator,
          algorithm, options.profile_result);
    case CudnnConvKind::kForwardActivation: {
      return RunGpuConvForwardActivation<ElementType, BiasType, OutputType>(
          params, scratch_allocator, stream, options, input_buf, filter_buf,
          output_buf, algorithm);
    }
  }
  return Status::OK();
}

// Specialization for integer types. Only the forward and forward-activation
// convolution kinds are supported.
template <typename ElementType, typename BiasType, typename OutputType,
          typename std::enable_if<std::is_integral<ElementType>::value>::type* =
              nullptr>
Status RunGpuConvInternalImpl(GpuConvParams params,
                              se::ScratchAllocator* scratch_allocator,
                              se::Stream* stream, RunConvOptions options,
                              DeviceMemory<ElementType> input_buf,
                              DeviceMemory<ElementType> filter_buf,
                              DeviceMemory<OutputType> output_buf,
                              AlgorithmConfig algorithm) {
  switch (params.config.kind) {
    case CudnnConvKind::kForward:
      return RunGpuConvForward(params, scratch_allocator, stream, options,
                               input_buf, filter_buf, output_buf, algorithm);
    case CudnnConvKind::kForwardActivation:
      return RunGpuConvForwardActivation<ElementType, BiasType, OutputType>(
          params, scratch_allocator, stream, options, input_buf, filter_buf,
          output_buf, algorithm);
    default:
      return InternalError(
          "Only convolution kinds kForward and kForwardActivation are "
          "supported for integer types");
  }
  return Status::OK();
}

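// Wraps the untyped buffers in typed DeviceMemory handles, applies any
// algorithm/scratch-size overrides from `options`, runs the convolution, and
// verifies that the stream is still in a good state afterwards.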
template <typename ElementType, typename BiasType, typename OutputType>
Status RunGpuConvImpl(const GpuConvParams& params,
                      se::ScratchAllocator* scratch_allocator,
                      se::Stream* stream, RunConvOptions options) {
  auto input_buf = se::DeviceMemory<ElementType>(params.input_buf);
  auto filter_buf = se::DeviceMemory<ElementType>(params.filter_buf);
  auto output_buf = se::DeviceMemory<OutputType>(params.output_buf);
  AlgorithmConfig algorithm = params.config.algorithm;

  if (options.algo_override.has_value()) {
    algorithm = AlgorithmConfig(*options.algo_override);
    if (options.scratch_size_override.has_value()) {
      algorithm.set_scratch_size(*options.scratch_size_override);
    }
  }

  Status run_status = RunGpuConvInternalImpl<ElementType, BiasType, OutputType>(
      params, scratch_allocator, stream, options, input_buf, filter_buf,
      output_buf, algorithm);

  if (run_status != Status::OK()) {
    return run_status;
  }

  if (!stream->ok()) {
    return InternalError(
        "Unable to launch convolution with type %s and algorithm (%d, %s)",
        CudnnConvKindToString(params.config.kind),
        algorithm.algorithm()->algo_id(),
        algorithm.algorithm_no_scratch().has_value()
            ? absl::StrCat(algorithm.algorithm_no_scratch()->algo_id())
            : "none");
  }
  return Status::OK();
}

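// Returns the channel vectorization factor implied by a vectorized (VECT_C)
// layout: 4 for kBatchDepthYX4 / kOutputInputYX4, 32 for kBatchDepthYX32 /
// kOutputInputYX32, and 1 for all other layouts.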
int64 GetVectCSize(DataLayout layout) {
  switch (layout) {
    case DataLayout::kBatchDepthYX4:
      return 4;
    case DataLayout::kBatchDepthYX32:
      return 32;
    default:
      return 1;
  }
}

int64 GetVectCSize(FilterLayout layout) {
  switch (layout) {
    case FilterLayout::kOutputInputYX4:
      return 4;
    case FilterLayout::kOutputInputYX32:
      return 32;
    default:
      return 1;
  }
}

}  // anonymous namespace

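// Builds a GpuConvConfig (element types, shapes, cudnn descriptors, algorithm
// and fusion parameters) from a backend-independent GpuConvDescriptor.
// `inst_as_string` is only used to annotate error and CHECK messages.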
StatusOr<GpuConvConfig> GetGpuConvConfig(
    const GpuConvDescriptor& desc, const absl::string_view inst_as_string) {
  GpuConvConfig config;

  const Shape& operand0_shape = desc.operand0_shape;
  const Shape& operand1_shape = desc.operand1_shape;
  const Shape& result_shape = desc.result_shape;
  const CudnnConvBackendConfig& backend_config = desc.backend_config;

  config.input_type = operand0_shape.element_type();
  config.output_type = result_shape.element_type();
  config.kind = desc.kind;

  // The third field is the scratch size picked by conv_algorithm_picker; it
  // is recorded in the shape of the conv custom call by
  // GpuConvAlgorithmPicker::RunOnInstruction().
  config.algorithm = se::dnn::AlgorithmConfig(
      se::dnn::AlgorithmDesc(backend_config.algorithm(),
                             backend_config.tensor_ops_enabled()),
      desc.scratch_size);
  config.conv_result_scale = backend_config.conv_result_scale();

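  // Map the XLA operands/result onto the roles of the underlying forward
  // convolution: for the backward kinds, the custom call's result plays the
  // part of the forward convolution's input (kBackwardInput) or filter
  // (kBackwardFilter).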
  switch (config.kind) {
    case CudnnConvKind::kForward:
    case CudnnConvKind::kForwardActivation:
      config.input_shape = operand0_shape;
      config.filter_shape = operand1_shape;
      config.output_shape = result_shape;
      break;
    case CudnnConvKind::kBackwardInput:
      config.input_shape = result_shape;
      config.filter_shape = operand1_shape;
      config.output_shape = operand0_shape;
      break;
    case CudnnConvKind::kBackwardFilter:
      config.input_shape = operand0_shape;
      config.filter_shape = result_shape;
      config.output_shape = operand1_shape;
      break;
    default:
      return InternalError("Unknown convolution kind");
  }

  if (config.kind == CudnnConvKind::kForwardActivation) {
    config.fusion.emplace();
    GpuConvConfig::FusionConfig& fusion = *config.fusion;
    if (!se::dnn::ActivationMode_IsValid(backend_config.activation_mode())) {
      return InternalError("Bad activation mode: %s",
                           backend_config.ShortDebugString());
    }
    fusion.mode =
        static_cast<se::dnn::ActivationMode>(backend_config.activation_mode());
    fusion.side_input_scale = backend_config.side_input_scale();
  }

  const Window& window = desc.window;
  const ConvolutionDimensionNumbers& dnums = desc.dnums;

  VLOG(3) << "Convolution Algorithm: "
          << config.algorithm.algorithm()->algo_id();
  VLOG(3) << "tensor_ops_enabled: "
          << config.algorithm.algorithm()->tensor_ops_enabled();
  VLOG(3) << "Convolution kind: " << CudnnConvKindToString(config.kind);
  VLOG(3) << "input shape: "
          << ShapeUtil::HumanStringWithLayout(config.input_shape);
  VLOG(3) << "filter shape: "
          << ShapeUtil::HumanStringWithLayout(config.filter_shape);
  VLOG(3) << "Output shape: "
          << ShapeUtil::HumanStringWithLayout(config.output_shape);
  VLOG(3) << "Window: { " << window.ShortDebugString() << " }";
  VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }";

  const int num_dimensions = window.dimensions_size();
  CHECK_LE(num_dimensions, 3) << inst_as_string;

  // cuDNN does not support 1D convolutions. We therefore express 1D
  // convolutions as 2D convolutions where the first spatial dimension is 1.
  // This matches the behavior of TF (see definition of conv1d in
  // tensorflow/python/ops/nn_ops.py).
  const int effective_num_dimensions = std::max(2, num_dimensions);

  // If one dimension is reversed, we need to have all dimensions reversed (so
  // that we're doing convolution rather than cross-correlation).
  const bool dims_reversed =
      window.dimensions_size() > 0 && window.dimensions()[0].window_reversal();

  CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size())
      << inst_as_string;
  CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size())
      << inst_as_string;
  CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size())
      << inst_as_string;
  for (const WindowDimension& dim : window.dimensions()) {
    CHECK_EQ(dims_reversed, dim.window_reversal()) << inst_as_string;
    CHECK_EQ(dim.padding_low(), dim.padding_high()) << inst_as_string;
    CHECK_EQ(dim.base_dilation(), 1)
        << "cudnn does not support base dilation; it "
           "must be made explicit with a kPad: "
        << inst_as_string;
  }

  // cuDNN's convolution APIs support the BDYX layout for activations/output
  // and the OIYX layout for weights.
  DataLayout input_dl;
  FilterLayout filter_dl;
  DataLayout output_dl;

  const Shape& input_shape = config.input_shape;
  const Shape& filter_shape = config.filter_shape;
  const Shape& output_shape = config.output_shape;

  TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl),
                      XlaConvShapesToStreamExecutorLayouts(
                          dnums, input_shape, filter_shape, output_shape));

  BatchDescriptor& input_descriptor = config.input_descriptor;
  input_descriptor = BatchDescriptor(effective_num_dimensions);
  input_descriptor.set_layout(input_dl)
      .set_feature_map_count(
          GetVectCSize(input_dl) *
          input_shape.dimensions(dnums.input_feature_dimension()))
      .set_count(input_shape.dimensions(dnums.input_batch_dimension()));
  for (int dim = 0; dim < num_dimensions; ++dim) {
    // Note that the dimensions are reversed. The same holds below.
    input_descriptor.set_spatial_dim(
        static_cast<DimIndex>(effective_num_dimensions - dim - 1),
        input_shape.dimensions(dnums.input_spatial_dimensions(dim)));
  }

  FilterDescriptor& filter_descriptor = config.filter_descriptor;
  filter_descriptor = FilterDescriptor(effective_num_dimensions);
  filter_descriptor.set_layout(filter_dl)
      .set_input_feature_map_count(
          GetVectCSize(filter_dl) *
          filter_shape.dimensions(dnums.kernel_input_feature_dimension()))
      .set_output_feature_map_count(
          filter_shape.dimensions(dnums.kernel_output_feature_dimension()));
  for (int dim = 0; dim < num_dimensions; ++dim) {
    filter_descriptor.set_spatial_dim(
        static_cast<DimIndex>(effective_num_dimensions - dim - 1),
        filter_shape.dimensions(dnums.kernel_spatial_dimensions(dim)));
  }

  config.conv_desc = ConvolutionDescriptor(effective_num_dimensions);
  config.conv_desc.set_group_count(desc.feature_group_count);
  config.conv_desc.set_convolution_not_crosscorr(dims_reversed);
  for (int dim = 0; dim < num_dimensions; ++dim) {
    config.conv_desc
        .set_zero_padding(
            static_cast<DimIndex>(effective_num_dimensions - dim - 1),
            window.dimensions(dim).padding_low())
        .set_filter_stride(
            static_cast<DimIndex>(effective_num_dimensions - dim - 1),
            window.dimensions(dim).stride())
        .set_dilation_rate(
            static_cast<DimIndex>(effective_num_dimensions - dim - 1),
            window.dimensions(dim).window_dilation());
  }

  BatchDescriptor& output_descriptor = config.output_descriptor;
  output_descriptor = BatchDescriptor(effective_num_dimensions);
  output_descriptor.set_layout(output_dl)
      .set_feature_map_count(
          GetVectCSize(output_dl) *
          output_shape.dimensions(dnums.output_feature_dimension()))
      .set_count(output_shape.dimensions(dnums.output_batch_dimension()));
  for (int dim = 0; dim < num_dimensions; ++dim) {
    output_descriptor.set_spatial_dim(
        static_cast<DimIndex>(effective_num_dimensions - dim - 1),
        output_shape.dimensions(dnums.output_spatial_dimensions(dim)));
  }

  // Add a singleton dimension in the 1D convolution case.
  for (int dim = 0; dim < effective_num_dimensions - num_dimensions; dim++) {
    input_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1);
    output_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1);
    filter_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1);
    config.conv_desc.set_zero_padding(static_cast<DimIndex>(dim), 0)
        .set_filter_stride(static_cast<DimIndex>(dim), 1);
  }

  return config;
}

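// Builds a GpuConvDescriptor from a cudnn convolution custom call and
// delegates to the overload above. The custom call's shape is a tuple of
// (conv result, scratch buffer); the scratch buffer's length is the scratch
// size recorded by the algorithm picker.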
StatusOr<GpuConvConfig> GetGpuConvConfig(
    const HloCustomCallInstruction* cudnn_call) {
  GpuConvDescriptor descriptor;

  TF_ASSIGN_OR_RETURN(descriptor.kind, GetCudnnConvKind(cudnn_call));
  TF_ASSIGN_OR_RETURN(descriptor.backend_config,
                      cudnn_call->backend_config<CudnnConvBackendConfig>());
  descriptor.operand0_shape = cudnn_call->operand(0)->shape();
  descriptor.operand1_shape = cudnn_call->operand(1)->shape();
  descriptor.result_shape = cudnn_call->shape().tuple_shapes(0);
  descriptor.scratch_size = cudnn_call->shape().tuple_shapes(1).dimensions(0);
  descriptor.window = cudnn_call->window();
  descriptor.dnums = cudnn_call->convolution_dimension_numbers();
  descriptor.feature_group_count = cudnn_call->feature_group_count();
  return GetGpuConvConfig(descriptor, cudnn_call->ToString());
}

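// Binds the operand and result buffers to the logical input/filter/output
// slots according to the convolution kind. For fused convolutions, operand 2
// is the bias and operand 3, if present, is the side input.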
StatusOr<GpuConvParams> GetGpuConvParams(
    const GpuConvConfig& config,
    absl::Span<se::DeviceMemoryBase> operand_buffers,
    se::DeviceMemoryBase result_buffer) {
  GpuConvParams params;
  params.config = config;

  switch (config.kind) {
    case CudnnConvKind::kForward:
    case CudnnConvKind::kForwardActivation:
      params.input_buf = operand_buffers[0];
      params.filter_buf = operand_buffers[1];
      params.output_buf = result_buffer;
      break;
    case CudnnConvKind::kBackwardInput:
      params.input_buf = result_buffer;
      params.filter_buf = operand_buffers[1];
      params.output_buf = operand_buffers[0];
      break;
    case CudnnConvKind::kBackwardFilter:
      params.input_buf = operand_buffers[0];
      params.filter_buf = result_buffer;
      params.output_buf = operand_buffers[1];
      break;
  }

  if (config.kind == CudnnConvKind::kForwardActivation) {
    params.fusion.emplace();
    GpuConvParams::FusionParams& fusion = *params.fusion;
    fusion.bias_buf = operand_buffers[2];
    if (operand_buffers.size() >= 4) {
      fusion.side_input_buf = operand_buffers[3];
    }
  }

  return params;
}

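// Convenience overload that wraps the caller-provided scratch buffer in a
// ScratchBufAllocator and forwards to the allocator-based overload below.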
Status RunGpuConv(const gpu::GpuConvConfig& config,
                  absl::Span<se::DeviceMemoryBase> operand_buffers,
                  se::DeviceMemoryBase result_buffer,
                  se::DeviceMemoryBase scratch_buf, se::Stream* stream,
                  RunConvOptions options) {
  ScratchBufAllocator scratch_allocator(scratch_buf);
  return RunGpuConv(config, operand_buffers, result_buffer, &scratch_allocator,
                    stream, options);
}

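// Dispatches on the input element type (and, for S8 inputs, on the output
// type) to the matching RunGpuConvImpl instantiation.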
Status RunGpuConv(const gpu::GpuConvConfig& config,
                  absl::Span<se::DeviceMemoryBase> operand_buffers,
                  se::DeviceMemoryBase result_buffer,
                  se::ScratchAllocator* scratch_allocator, se::Stream* stream,
                  RunConvOptions options) {
  TF_ASSIGN_OR_RETURN(GpuConvParams params,
                      GetGpuConvParams(config, operand_buffers, result_buffer));

  PrimitiveType input_primitive_type = config.input_type;
  switch (input_primitive_type) {
    case F16:
      return RunGpuConvImpl<Eigen::half, Eigen::half, Eigen::half>(
          params, scratch_allocator, stream, options);
    case BF16:
      return RunGpuConvImpl<Eigen::bfloat16, Eigen::bfloat16, Eigen::bfloat16>(
          params, scratch_allocator, stream, options);
    case F32:
      return RunGpuConvImpl<float, float, float>(params, scratch_allocator,
                                                 stream, options);
    case F64:
      return RunGpuConvImpl<double, double, double>(params, scratch_allocator,
                                                    stream, options);
    case S8: {
      PrimitiveType output_primitive_type = config.output_type;
      switch (output_primitive_type) {
        case F32:
          return RunGpuConvImpl<int8, float, float>(params, scratch_allocator,
                                                    stream, options);
        case S8:
          return RunGpuConvImpl<int8, float, int8>(params, scratch_allocator,
                                                   stream, options);
        default:
          return Unimplemented("Unimplemented convolution");
      }
    }
    default:
      return Unimplemented("Unimplemented convolution");
  }
}

}  // namespace gpu
}  // namespace xla