// LINT: LEGACY_NAMES
syntax = "proto3";

package stream_executor.dnn;

option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/stream_executor";

// Specifies the data type used by an operation.
enum DataType {
  kFloat = 0;
  kDouble = 1;
  kHalf = 2;
  kInt8 = 3;
  kInt32 = 4;
  kComplexFloat = 5;
  kComplexDouble = 6;
  kBF16 = 7;
}
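
// For reference, the element sizes of these types in bytes: kFloat = 4,
// kDouble = 8, kHalf = 2, kInt8 = 1, kInt32 = 4, kComplexFloat = 8,
// kComplexDouble = 16, kBF16 = 2.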

// Describes how a convolution input or output layer's data is formatted.
enum DataLayout {
  // Naming convention:
  // Y <-> row or height
  // X <-> column or width
  // Batch <-> batch, or N
  // Depth <-> feature, or channel
  // TODO(timshen): turn them into cuDNN names, e.g. kNCHW.
  //
  // Note: In cuDNN, kBatchDepthYX4 and kBatchDepthYX32 are the same layout
  // (namely, NCHW_VECT_C); cuDNN differentiates between the two via the
  // data type (int8x4 vs. int8x32). StreamExecutor uses distinct layouts
  // for them because we don't usually pass an explicit data type to
  // StreamExecutor functions.
  kYXDepthBatch = 0;
  kYXBatchDepth = 1;
  kBatchYXDepth = 2;    // cuDNN's NHWC layout
  kBatchDepthYX = 3;    // cuDNN's NCHW layout
  kBatchDepthYX4 = 4;   // cuDNN's NCHW_VECT_C with 4-elem vectors (e.g. int8x4)
  kBatchDepthYX32 = 5;  // cuDNN's NCHW_VECT_C with 32-elem vectors (e.g. int8x32)
}
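
// Illustrative example: for a batch of 32 RGB images of size 224x224, the
// logical dimensions are N = 32, C = 3, H = 224, W = 224. Under
// kBatchDepthYX (NCHW) the fastest-varying index is W; under kBatchYXDepth
// (NHWC) it is C. kBatchDepthYX4 additionally packs each group of 4
// consecutive channels into one vector element, so C must be a multiple of
// 4 (respectively 32 for kBatchDepthYX32).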

// Describes how a convolution filter is laid out in memory.
enum FilterLayout {
  // Naming convention:
  // Y <-> row or height
  // X <-> column or width
  // Output <-> output feature, or N
  // Input <-> input feature, or C
  // TODO(timshen): turn them into cuDNN names, e.g. kNCHW.
  kOutputInputYX = 0;    // cuDNN's NCHW layout
  kOutputYXInput = 1;    // cuDNN's NHWC layout
  kOutputInputYX4 = 2;   // cuDNN's NCHW_VECT_C layout with 4-elem vectors
  kOutputInputYX32 = 5;  // cuDNN's NCHW_VECT_C layout with 32-elem vectors
  kInputYXOutput = 3;
  kYXInputOutput = 4;
}
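
// Illustrative example: a bank of 3x3 filters mapping 64 input features to
// 128 output features has Output = 128, Input = 64, Y = 3, X = 3. Under
// kOutputInputYX it is stored as a dense [128, 64, 3, 3] array; under
// kOutputYXInput as [128, 3, 3, 64].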

// Describes a kind of non-linearity (threshold-like mathematical function).
enum ActivationMode {
  kNone = 0;
  kSigmoid = 1;
  // Rectified linear activation: f(x) = x < 0 ? 0 : x
  kRelu = 2;
  // Rectified linear activation with an upper bound of 6.0.
  kRelu6 = 3;
  // Rectified linear activation whose upper bound is given by
  // BatchDescriptor::value_max().
  kReluX = 4;
  kTanh = 5;
  // Like kReluX, but passes all values in the range [-X, X].
  kBandPass = 6;
}
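
// As closed-form expressions, with X = BatchDescriptor::value_max():
//   kRelu:     f(x) = max(x, 0)
//   kRelu6:    f(x) = min(max(x, 0), 6)
//   kReluX:    f(x) = min(max(x, 0), X)
//   kBandPass: f(x) = min(max(x, -X), X)  (assuming values outside the band
//   are clamped)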

// Describes the mathematical definition of the conv op. The popular behavior
// is actually called cross-correlation in mathematics, even though the
// operation is commonly referred to as convolution. See cuDNN's
// cudnnConvolutionMode_t.
enum ConvolutionMode {
  CROSS_CORRELATION = 0;
  CONVOLUTION = 1;
}
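
// In 1-D, with a filter w of length K and input x, the two modes are:
//   CROSS_CORRELATION: y[i] = sum_{k=0..K-1} w[k] * x[i + k]
//   CONVOLUTION:       y[i] = sum_{k=0..K-1} w[k] * x[i - k]
// True convolution flips the filter; most deep-learning frameworks compute
// cross-correlation and still call it "convolution".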

enum ConvolutionKind {
  INVALID = 0;
  FORWARD = 1;
  BACKWARD_FILTER = 2;
  BACKWARD_DATA = 3;
  FORWARD_BIAS_ACTIVATION = 4;
}
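
// For a forward convolution y = conv(x, w):
//   FORWARD computes y from x and w;
//   BACKWARD_DATA computes dx from dy and w (gradient w.r.t. the input);
//   BACKWARD_FILTER computes dw from dy and x (gradient w.r.t. the filter);
//   FORWARD_BIAS_ACTIVATION additionally fuses a bias add and an activation
//   into the forward pass.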

// Generic tensor representation.
message TensorDescriptorProto {
  repeated int64 dimensions = 1;
  DataType data_type = 2;
  oneof layout_oneof {
    DataLayout data_layout = 3;
    FilterLayout filter_layout = 4;
  }
}
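
// Example in text format (a sketch with made-up dimensions): a float NCHW
// tensor of shape [32, 64, 28, 28] would be described as
//   dimensions: [32, 64, 28, 28]
//   data_type: kFloat
//   data_layout: kBatchDepthYX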

// Generic algorithm representation.
message AlgorithmProto {
  enum MathType {
    DEFAULT_MATH = 0;
    // The GPU may perform 4x4 matrix FMA (fused multiply-add) operations.
    // See cuDNN's documentation for CUDNN_TENSOR_OP_MATH.
    TENSOR_OP_MATH = 1;
  }
  int64 algo_id = 1;
  MathType math_type = 2;
  // cuDNN v8 uses a string to uniquely represent the backend plan.
  string exec_plan_id = 3;
}
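
// Example in text format (a sketch; the id is made up): a legacy-API
// algorithm choice that enables tensor ops would look like
//   algo_id: 1
//   math_type: TENSOR_OP_MATH
// whereas a cuDNN v8 plan is identified by exec_plan_id instead, an opaque
// string produced by the backend.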

// Proto definition of AlgorithmConfig in "dnn.h".
// TODO(ruochengw): After cl/380702564 is submitted, add support for algorithm
// configs with cuDNN Frontend APIs.
message AlgorithmConfigProto {
  // Use oneof to emulate optional semantics in proto2, since older versions
  // of proto3 cannot distinguish between an unset field and a field set to
  // its default value.
  oneof optional_algorithm {
    AlgorithmProto algorithm = 1;
  }
  oneof optional_algorithm_no_scratch {
    AlgorithmProto algorithm_no_scratch = 2;
  }
  oneof optional_scratch_size {
    int64 scratch_size = 3;
  }
}
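
// Example in text format (a sketch): because each field sits in its own
// oneof, an explicitly zero scratch size is distinguishable from an unset
// one.
//   algorithm { algo_id: 1 math_type: DEFAULT_MATH }
//   scratch_size: 0
// Here scratch_size is present (and zero), while algorithm_no_scratch is
// genuinely unset.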

// Convolution-specific parameters.
message ConvolutionDescriptorProto {
  repeated int64 paddings = 1;
  repeated int64 strides = 2;
  repeated int64 dilations = 3;
  // The "accumulator" type. For example, use F32 as an accumulator for F16
  // convolutions.
  // See the computeType argument of cuDNN's cudnnSetConvolutionNdDescriptor.
  DataType compute_mode = 4;
  // See cuDNN's group count.
  int32 group_count = 5;
  ConvolutionMode convolution_mode = 6;
  // TensorFlow node name, same as in NodeDef; used for debugging.
  string name = 7;
}
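
// Example in text format (a sketch; the node name is hypothetical): a 2-D
// convolution with stride 1 and padding 1 in each spatial dimension, using
// F16 data accumulated in F32:
//   paddings: [1, 1]
//   strides: [1, 1]
//   dilations: [1, 1]
//   compute_mode: kFloat
//   group_count: 1
//   convolution_mode: CROSS_CORRELATION
//   name: "model/conv2d_1/Conv2D"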