syntax = "proto3";

package xla.gpu;

import "tensorflow/compiler/xla/xla_data.proto";

// Backend configs for XLA:GPU.
//
// These are metadata that the GPU backend attaches to HloInstructions and later
// uses during e.g. codegen.
//
// Remember that proto3 doesn't give clients a way to tell the difference
// between a field not being present and a field having the default value.
// Choose your defaults carefully.
//
// No guarantee is made about the stability of these protos.
//
// See HloInstruction::backend_config() for more info.
// Configuration attached to a convolution instruction that is lowered to a
// cudnn call.
message CudnnConvBackendConfig {
  // Opaque cudnn algorithm number selected for this convolution.
  int64 algorithm = 1;

  // Whether tensor cores MAY be used for this conv. Even when true, cudnn is
  // free to fall back to a non-tensor-core path, e.g. if the GPU or the
  // selected algorithm does not support tensor ops.
  bool tensor_ops_enabled = 2;

  // Scale factor applied (multiplicatively) to the convolution result.
  double conv_result_scale = 4;

  // The remaining fields configure cudnn's fused convolution; see
  // GpuConvParams for their precise meanings.

  // Activation (e.g. relu) applied after the convolution, encoded as a
  // stream_executor::dnn::ActivationMode value.
  int64 activation_mode = 3;

  // Scale factor applied (multiplicatively) to the side input. Must be 0 when
  // no side-input buffer is provided.
  double side_input_scale = 5;
}
// Configuration for a GEMM executed via cuBLAS.
message GemmBackendConfig {
  // Opaque, optional cuBLAS algorithm number. When unset, a different cuBLAS
  // API that does not accept an algorithm choice is used instead.
  oneof algorithm {
    int64 selected_algorithm = 1;
  }

  // Real and imaginary parts of the GEMM alpha scaling factor. The imaginary
  // part is presumably meaningful only for complex element types — confirm
  // against the emitter.
  double alpha_real = 2;
  double alpha_imag = 9;

  // GEMM beta scaling factor (applied to the output operand).
  double beta = 3;

  // Contraction/batch dimension numbers of the dot operation this GEMM
  // implements (see xla.DotDimensionNumbers).
  xla.DotDimensionNumbers dot_dimension_numbers = 7;

  // Number of batched GEMMs represented by this instruction.
  int64 batch_size = 8;
}