/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {
namespace gpu {
namespace {

using se::DeviceMemory;
using se::DeviceMemoryBase;
using se::Stream;
using se::dnn::AlgorithmConfig;
using se::dnn::BatchDescriptor;
using se::dnn::ConvolutionDescriptor;
using se::dnn::DataLayout;
using se::dnn::DimIndex;
using se::dnn::FilterDescriptor;
using se::dnn::FilterLayout;
using se::dnn::ProfileResult;

// A StreamExecutor ScratchAllocator that wraps a single XLA allocation,
// returning it (in its entirety) the first time Allocate() is called.
class ScratchBufAllocator : public se::ScratchAllocator {
 public:
  explicit ScratchBufAllocator(se::DeviceMemoryBase scratch)
      : scratch_(scratch) {}

  ~ScratchBufAllocator() override = default;

  int64 GetMemoryLimitInBytes() override { return scratch_.size(); }

  se::port::StatusOr<DeviceMemory<uint8>> AllocateBytes(
      int64 byte_size) override {
    if (allocated_) {
      return se::port::InternalError(
          "Can't allocate twice from a ScratchBufAllocator.");
    }
    if (byte_size > scratch_.size()) {
      return se::port::InternalError(absl::StrCat(
          "Can't allocate ", byte_size,
          " bytes from a ScratchBufAllocator of size ", scratch_.size()));
    }

    allocated_ = true;
    return se::DeviceMemory<uint8>(scratch_);
  }

 private:
  se::DeviceMemoryBase scratch_;
  bool allocated_ = false;
};
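
// A minimal usage sketch (mirroring the RunGpuConv overload near the end of
// this file, with the names used there): wrap the caller's pre-allocated
// scratch buffer and hand the allocator to the convolution runner.
//
//   ScratchBufAllocator scratch_allocator(scratch_buf);
//   return RunGpuConv(config, operand_buffers, result_buffer,
//                     &scratch_allocator, stream, options);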

template <typename ElementType, typename OutputType>
Status RunGpuConvForward(GpuConvParams params,
                         se::ScratchAllocator* scratch_allocator,
                         se::Stream* stream, RunConvOptions options,
                         DeviceMemory<ElementType> input_buf,
                         DeviceMemory<ElementType> filter_buf,
                         DeviceMemory<OutputType> output_buf,
                         AlgorithmConfig algorithm) {
  if (params.config.conv_result_scale != 1) {
    return InternalError(
        "StreamExecutor doesn't support scaled convolution: %lf.",
        params.config.conv_result_scale);
  }
  return stream->ConvolveWithAlgorithm(
      params.config.input_descriptor, input_buf,
      params.config.filter_descriptor, filter_buf, params.config.conv_desc,
      params.config.output_descriptor, &output_buf, scratch_allocator,
      algorithm, options.profile_result);
}

template <typename ElementType, typename BiasType, typename OutputType>
Status RunGpuConvForwardActivation(GpuConvParams params,
                                   se::ScratchAllocator* scratch_allocator,
                                   se::Stream* stream, RunConvOptions options,
                                   DeviceMemory<ElementType> input_buf,
                                   DeviceMemory<ElementType> filter_buf,
                                   DeviceMemory<OutputType> output_buf,
                                   AlgorithmConfig algorithm) {
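  // The bias input to the fused convolution has one element per output feature
  // map. Describe it to cudnn as a batch-1 tensor with 1x1 spatial extent in
  // the output layout, so it is broadcast across the batch and spatial
  // dimensions of the output.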
  BatchDescriptor bias_desc;
  bias_desc.set_count(1)
      .set_height(1)
      .set_width(1)
      .set_feature_map_count(
          params.config.output_descriptor.feature_map_count())
      .set_layout(params.config.output_descriptor.layout());

  se::DeviceMemory<OutputType> side_input(params.fusion->side_input_buf);
  // If there is no side input, use output as the side input.
  if (side_input.is_null()) {
    if (params.config.fusion->side_input_scale != 0) {
      return InternalError(
          "Side input scale is not 0, yet no side input buffer is "
          "provided");
    }
    // Since side-input scale is 0, the values in the side input don't
    // matter. The simplest thing to do would be to pass in a null buffer
    // for the side input, but cudnn doesn't allow this. cudnn does promise
    // that if side-input-scale is 0 the side input won't be read, so we
    // just pass in the output buffer, since it's handy and has the correct
    // size.
    side_input = output_buf;
  }

  return stream->FusedConvolveWithAlgorithm(
      params.config.input_descriptor, input_buf,
      params.config.conv_result_scale, params.config.filter_descriptor,
      filter_buf, params.config.conv_desc, side_input,
      params.config.fusion->side_input_scale, bias_desc,
      DeviceMemory<BiasType>(params.fusion->bias_buf),
      params.config.fusion->mode, params.config.output_descriptor, &output_buf,
      scratch_allocator, algorithm, options.profile_result);
}

// StreamExecutor supports various data types via overloading, and the set of
// overloads grows on demand. To avoid calling into overloads that don't exist,
// we guard each instantiation with enable_if.
// TODO(timshen): Ideally, to avoid this complication in the runner, we could
// turn the StreamExecutor overloads into template functions and return runtime
// errors for unsupported data types.
// This is the specialization for double, float, and half types. All kinds of
// convolutions are supported here.
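// For instance, RunGpuConvImpl below instantiates
// RunGpuConvInternalImpl<float, float, float> for F32 convolutions, which
// resolves to this overload, while RunGpuConvInternalImpl<int8, float, float>
// for S8 convolutions resolves to the integer-only overload further down.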
template <typename ElementType, typename BiasType, typename OutputType,
          typename std::enable_if<
              !std::is_integral<ElementType>::value>::type* = nullptr>
Status RunGpuConvInternalImpl(GpuConvParams params,
                              se::ScratchAllocator* scratch_allocator,
                              se::Stream* stream, RunConvOptions options,
                              DeviceMemory<ElementType> input_buf,
                              DeviceMemory<ElementType> filter_buf,
                              DeviceMemory<OutputType> output_buf,
                              AlgorithmConfig algorithm) {
  switch (params.config.kind) {
    case CudnnConvKind::kForward:
      return RunGpuConvForward(params, scratch_allocator, stream, options,
                               input_buf, filter_buf, output_buf, algorithm);
    case CudnnConvKind::kBackwardInput:
      if (params.config.conv_result_scale != 1) {
        return InternalError(
            "StreamExecutor doesn't support scaled convolution: %lf.",
            params.config.conv_result_scale);
      }
      return stream->ConvolveBackwardDataWithAlgorithm(
          params.config.filter_descriptor, filter_buf,
          params.config.output_descriptor, output_buf, params.config.conv_desc,
          params.config.input_descriptor, &input_buf, scratch_allocator,
          algorithm, options.profile_result);
      break;
    case CudnnConvKind::kBackwardFilter:
      if (params.config.conv_result_scale != 1) {
        return InternalError(
            "StreamExecutor doesn't support scaled convolution: %lf.",
            params.config.conv_result_scale);
      }
      return stream->ConvolveBackwardFilterWithAlgorithm(
          params.config.input_descriptor, input_buf,
          params.config.output_descriptor, output_buf, params.config.conv_desc,
          params.config.filter_descriptor, &filter_buf, scratch_allocator,
          algorithm, options.profile_result);
      break;
    case CudnnConvKind::kForwardActivation: {
      return RunGpuConvForwardActivation<ElementType, BiasType, OutputType>(
          params, scratch_allocator, stream, options, input_buf, filter_buf,
          output_buf, algorithm);
    }
  }
  return Status::OK();
}

// Specialization for integer types. Only the two forward convolution kinds,
// kForward and kForwardActivation, are supported.
template <typename ElementType, typename BiasType, typename OutputType,
          typename std::enable_if<std::is_integral<ElementType>::value>::type* =
              nullptr>
Status RunGpuConvInternalImpl(GpuConvParams params,
                              se::ScratchAllocator* scratch_allocator,
                              se::Stream* stream, RunConvOptions options,
                              DeviceMemory<ElementType> input_buf,
                              DeviceMemory<ElementType> filter_buf,
                              DeviceMemory<OutputType> output_buf,
                              AlgorithmConfig algorithm) {
  switch (params.config.kind) {
    case CudnnConvKind::kForward:
      return RunGpuConvForward(params, scratch_allocator, stream, options,
                               input_buf, filter_buf, output_buf, algorithm);
    case CudnnConvKind::kForwardActivation:
      return RunGpuConvForwardActivation<ElementType, BiasType, OutputType>(
          params, scratch_allocator, stream, options, input_buf, filter_buf,
          output_buf, algorithm);
    default:
      return InternalError(
          "Only convolution kinds kForward and kForwardActivation are "
          "supported for integer types");
  }
  return Status::OK();
}

template <typename ElementType, typename BiasType, typename OutputType>
Status RunGpuConvImpl(const GpuConvParams& params,
                      se::ScratchAllocator* scratch_allocator,
                      se::Stream* stream, RunConvOptions options) {
  auto input_buf = se::DeviceMemory<ElementType>(params.input_buf);
  auto filter_buf = se::DeviceMemory<ElementType>(params.filter_buf);
  auto output_buf = se::DeviceMemory<OutputType>(params.output_buf);
  AlgorithmConfig algorithm = params.config.algorithm;
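
  // options.algo_override is typically set by the convolution autotuner
  // (GpuConvAlgorithmPicker) to try a specific candidate algorithm instead of
  // the one recorded in the backend config.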
  if (options.algo_override.has_value()) {
    algorithm = AlgorithmConfig(*options.algo_override);
    if (options.scratch_size_override.has_value()) {
      algorithm.set_scratch_size(*options.scratch_size_override);
    }
  }

  Status run_status = RunGpuConvInternalImpl<ElementType, BiasType, OutputType>(
      params, scratch_allocator, stream, options, input_buf, filter_buf,
      output_buf, algorithm);

  if (run_status != Status::OK()) {
    return run_status;
  }

  if (!stream->ok()) {
    return InternalError(
        "Unable to launch convolution with type %s and algorithm (%d, %s)",
        CudnnConvKindToString(params.config.kind),
        algorithm.algorithm()->algo_id(),
        algorithm.algorithm_no_scratch().has_value()
            ? absl::StrCat(algorithm.algorithm_no_scratch()->algo_id())
            : "none");
  }
  return Status::OK();
}

}  // anonymous namespace

StatusOr<GpuConvConfig> GetGpuConvConfig(
    const GpuConvDescriptor& desc, const absl::string_view inst_as_string) {
  GpuConvConfig config;

  const Shape& operand0_shape = desc.operand0_shape;
  const Shape& operand1_shape = desc.operand1_shape;
  const Shape& result_shape = desc.result_shape;
  const CudnnConvBackendConfig& backend_config = desc.backend_config;

  config.input_type = operand0_shape.element_type();
  config.output_type = result_shape.element_type();
  config.kind = desc.kind;

  // The third field is the scratch size stored by conv_algorithm_picker. It is
  // recorded in the shape of the conv instruction by
  // GpuConvAlgorithmPicker::RunOnInstruction().
  config.algorithm = se::dnn::AlgorithmConfig(
      se::dnn::AlgorithmDesc(backend_config.algorithm(),
                             backend_config.tensor_ops_enabled()),
      desc.scratch_size);
  config.conv_result_scale = backend_config.conv_result_scale();

  switch (config.kind) {
    case CudnnConvKind::kForward:
    case CudnnConvKind::kForwardActivation:
      config.input_shape = operand0_shape;
      config.filter_shape = operand1_shape;
      config.output_shape = result_shape;
      break;
    case CudnnConvKind::kBackwardInput:
      config.input_shape = result_shape;
      config.filter_shape = operand1_shape;
      config.output_shape = operand0_shape;
      break;
    case CudnnConvKind::kBackwardFilter:
      config.input_shape = operand0_shape;
      config.filter_shape = result_shape;
      config.output_shape = operand1_shape;
      break;
    default:
      return InternalError("Unknown convolution kind");
  }

  if (config.kind == CudnnConvKind::kForwardActivation) {
    config.fusion.emplace();
    GpuConvConfig::FusionConfig& fusion = *config.fusion;
    if (!se::dnn::ActivationMode_IsValid(backend_config.activation_mode())) {
      return InternalError("Bad activation mode: %s",
                           backend_config.ShortDebugString());
    }
    fusion.mode =
        static_cast<se::dnn::ActivationMode>(backend_config.activation_mode());
    fusion.side_input_scale = backend_config.side_input_scale();
  }

  const Window& window = desc.window;
  const ConvolutionDimensionNumbers& dnums = desc.dnums;

  VLOG(3) << "Convolution Algorithm: "
          << config.algorithm.algorithm()->algo_id();
  VLOG(3) << "tensor_ops_enabled: "
          << config.algorithm.algorithm()->tensor_ops_enabled();
  VLOG(3) << "Convolution kind: " << CudnnConvKindToString(config.kind);
  VLOG(3) << "input shape: "
          << ShapeUtil::HumanStringWithLayout(config.input_shape);
  VLOG(3) << "filter shape: "
          << ShapeUtil::HumanStringWithLayout(config.filter_shape);
  VLOG(3) << "Output shape: "
          << ShapeUtil::HumanStringWithLayout(config.output_shape);
  VLOG(3) << "Window: { " << window.ShortDebugString() << " }";
  VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }";

  const int num_dimensions = window.dimensions_size();
  CHECK_LE(num_dimensions, 3) << inst_as_string;

  // cuDNN does not support 1D convolutions. We therefore express 1D
  // convolutions as 2D convolutions where the first spatial dimension is 1.
  // This matches the behavior of TF (see definition of conv1d in
  // tensorflow/python/ops/nn_ops.py).
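  // For example, a conv over an input of logical shape [batch, width, features]
  // with a size-3 window is run as a 2D conv over [batch, 1, width, features]
  // with a 1x3 window; the loop at the end of this function fills in the extra
  // spatial dimension with size 1, zero padding, and stride 1.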
  const int effective_num_dimensions = std::max(2, num_dimensions);

  // If one dimension is reversed, we need to have all dimensions reversed (so
  // we're doing a convolution, not a cross-correlation).
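  // For example, convolving with a 1D filter [a, b, c] is the same as
  // cross-correlating with the reversed filter [c, b, a], which is why the
  // window reversal must be all-or-nothing across the spatial dimensions.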
  const bool dims_reversed =
      window.dimensions_size() > 0 && window.dimensions()[0].window_reversal();

  CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size())
      << inst_as_string;
  CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size())
      << inst_as_string;
  CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size())
      << inst_as_string;
  for (const WindowDimension& dim : window.dimensions()) {
    CHECK_EQ(dims_reversed, dim.window_reversal()) << inst_as_string;
    CHECK_EQ(dim.padding_low(), dim.padding_high()) << inst_as_string;
    CHECK_EQ(dim.base_dilation(), 1)
        << "cudnn does not support base dilation; it "
           "must be made explicit with a kPad: "
        << inst_as_string;
  }

  // cuDNN's convolution APIs support the BDYX layout for activations/output and
  // the OIYX layout for weights.
  DataLayout input_dl;
  FilterLayout filter_dl;
  DataLayout output_dl;

  const Shape& input_shape = config.input_shape;
  const Shape& filter_shape = config.filter_shape;
  const Shape& output_shape = config.output_shape;

  TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl),
                      XlaConvLayoutsToStreamExecutorLayouts(
                          dnums, input_shape.layout(), filter_shape.layout(),
                          output_shape.layout()));

  BatchDescriptor& input_descriptor = config.input_descriptor;
  input_descriptor = BatchDescriptor(effective_num_dimensions);
  input_descriptor.set_layout(input_dl)
      .set_feature_map_count(
          input_shape.dimensions(dnums.input_feature_dimension()))
      .set_count(input_shape.dimensions(dnums.input_batch_dimension()));
  for (int dim = 0; dim < num_dimensions; ++dim) {
    // Note that the dimensions are reversed. The same holds below.
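    // For a 2D convolution this maps XLA spatial dimension 0 to DimIndex::Y
    // and spatial dimension 1 to DimIndex::X (assuming the usual
    // {height, width} ordering of input_spatial_dimensions).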
    input_descriptor.set_spatial_dim(
        static_cast<DimIndex>(effective_num_dimensions - dim - 1),
        input_shape.dimensions(dnums.input_spatial_dimensions(dim)));
  }

  FilterDescriptor& filter_descriptor = config.filter_descriptor;
  filter_descriptor = FilterDescriptor(effective_num_dimensions);
  filter_descriptor.set_layout(filter_dl)
      .set_input_feature_map_count(
          filter_shape.dimensions(dnums.kernel_input_feature_dimension()))
      .set_output_feature_map_count(
          filter_shape.dimensions(dnums.kernel_output_feature_dimension()));
  for (int dim = 0; dim < num_dimensions; ++dim) {
    filter_descriptor.set_spatial_dim(
        static_cast<DimIndex>(effective_num_dimensions - dim - 1),
        filter_shape.dimensions(dnums.kernel_spatial_dimensions(dim)));
  }

  config.conv_desc = ConvolutionDescriptor(effective_num_dimensions);
  config.conv_desc.set_group_count(desc.feature_group_count);
  config.conv_desc.set_convolution_not_crosscorr(dims_reversed);
  for (int dim = 0; dim < num_dimensions; ++dim) {
    config.conv_desc
        .set_zero_padding(
            static_cast<DimIndex>(effective_num_dimensions - dim - 1),
            window.dimensions(dim).padding_low())
        .set_filter_stride(
            static_cast<DimIndex>(effective_num_dimensions - dim - 1),
            window.dimensions(dim).stride())
        .set_dilation_rate(
            static_cast<DimIndex>(effective_num_dimensions - dim - 1),
            window.dimensions(dim).window_dilation());
  }

  BatchDescriptor& output_descriptor = config.output_descriptor;
  output_descriptor = BatchDescriptor(effective_num_dimensions);
  output_descriptor.set_layout(output_dl)
      .set_feature_map_count(
          output_shape.dimensions(dnums.output_feature_dimension()))
      .set_count(output_shape.dimensions(dnums.output_batch_dimension()));
  for (int dim = 0; dim < num_dimensions; ++dim) {
    output_descriptor.set_spatial_dim(
        static_cast<DimIndex>(effective_num_dimensions - dim - 1),
        output_shape.dimensions(dnums.output_spatial_dimensions(dim)));
  }

  // Add a singleton dimension in the 1D convolution case.
  for (int dim = 0; dim < effective_num_dimensions - num_dimensions; dim++) {
    input_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1);
    output_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1);
    filter_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1);
    config.conv_desc.set_zero_padding(static_cast<DimIndex>(dim), 0)
        .set_filter_stride(static_cast<DimIndex>(dim), 1);
  }

  return config;
}

StatusOr<GpuConvConfig> GetGpuConvConfig(
    const HloCustomCallInstruction* cudnn_call) {
  GpuConvDescriptor descriptor;

  TF_ASSIGN_OR_RETURN(descriptor.kind, GetCudnnConvKind(cudnn_call));
  TF_ASSIGN_OR_RETURN(descriptor.backend_config,
                      cudnn_call->backend_config<CudnnConvBackendConfig>());
  descriptor.operand0_shape = cudnn_call->operand(0)->shape();
  descriptor.operand1_shape = cudnn_call->operand(1)->shape();
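  // The custom call returns a tuple of (conv result, scratch); the scratch
  // buffer is a u8[] whose length was chosen by the conv algorithm picker.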
  descriptor.result_shape = cudnn_call->shape().tuple_shapes(0);
  descriptor.scratch_size = cudnn_call->shape().tuple_shapes(1).dimensions(0);
  descriptor.window = cudnn_call->window();
  descriptor.dnums = cudnn_call->convolution_dimension_numbers();
  descriptor.feature_group_count = cudnn_call->feature_group_count();
  return GetGpuConvConfig(descriptor, cudnn_call->ToString());
}

StatusOr<GpuConvParams> GetGpuConvParams(
    const GpuConvConfig& config,
    absl::Span<se::DeviceMemoryBase> operand_buffers,
    se::DeviceMemoryBase result_buffer) {
  GpuConvParams params;
  params.config = config;

  switch (config.kind) {
    case CudnnConvKind::kForward:
    case CudnnConvKind::kForwardActivation:
      params.input_buf = operand_buffers[0];
      params.filter_buf = operand_buffers[1];
      params.output_buf = result_buffer;
      break;
    case CudnnConvKind::kBackwardInput:
      params.input_buf = result_buffer;
      params.filter_buf = operand_buffers[1];
      params.output_buf = operand_buffers[0];
      break;
    case CudnnConvKind::kBackwardFilter:
      params.input_buf = operand_buffers[0];
      params.filter_buf = result_buffer;
      params.output_buf = operand_buffers[1];
      break;
  }

  if (config.kind == CudnnConvKind::kForwardActivation) {
    params.fusion.emplace();
    GpuConvParams::FusionParams& fusion = *params.fusion;
    fusion.bias_buf = operand_buffers[2];
    if (operand_buffers.size() >= 4) {
      fusion.side_input_buf = operand_buffers[3];
    }
  }

  return params;
}

Status RunGpuConv(const gpu::GpuConvConfig& config,
                  absl::Span<se::DeviceMemoryBase> operand_buffers,
                  se::DeviceMemoryBase result_buffer,
                  se::DeviceMemoryBase scratch_buf, se::Stream* stream,
                  RunConvOptions options) {
  ScratchBufAllocator scratch_allocator(scratch_buf);
  return RunGpuConv(config, operand_buffers, result_buffer, &scratch_allocator,
                    stream, options);
}

Status RunGpuConv(const gpu::GpuConvConfig& config,
                  absl::Span<se::DeviceMemoryBase> operand_buffers,
                  se::DeviceMemoryBase result_buffer,
                  se::ScratchAllocator* scratch_allocator, se::Stream* stream,
                  RunConvOptions options) {
  TF_ASSIGN_OR_RETURN(GpuConvParams params,
                      GetGpuConvParams(config, operand_buffers, result_buffer));

  PrimitiveType input_primitive_type = config.input_type;
  switch (input_primitive_type) {
    case F16:
      return RunGpuConvImpl<Eigen::half, Eigen::half, Eigen::half>(
          params, scratch_allocator, stream, options);
    case F32:
      return RunGpuConvImpl<float, float, float>(params, scratch_allocator,
                                                 stream, options);
    case F64:
      return RunGpuConvImpl<double, double, double>(params, scratch_allocator,
                                                    stream, options);
    case S8: {
      PrimitiveType output_primitive_type = config.output_type;
      switch (output_primitive_type) {
        case F32:
          return RunGpuConvImpl<int8, float, float>(params, scratch_allocator,
                                                    stream, options);
        case S8:
          return RunGpuConvImpl<int8, float, int8>(params, scratch_allocator,
                                                   stream, options);
        default:
          return Unimplemented("Unimplemented convolution");
      }
    }
    default:
      return Unimplemented("Unimplemented convolution");
  }
}

}  // namespace gpu
}  // namespace xla