// LINT: LEGACY_NAMES
syntax = "proto3";

// Descriptors for DNN operations: element types, tensor and filter layouts,
// activation modes, algorithms and convolution parameters.
package stream_executor.dnn;

// Specifies the data type used by an operation.
//
// kFloat is the zero value, so it is what a DataType field reads as when it
// has not been set.
enum DataType {
  kFloat = 0;
  kDouble = 1;
  kHalf = 2;
  kInt8 = 3;
  kInt32 = 4;
}

// Describes how a convolution input or output layer's data is formatted.
//
// kYXDepthBatch is the zero value, so it is what a DataLayout field reads as
// when it has not been set.
enum DataLayout {
  // Naming convention:
  // Y <-> row or height
  // X <-> column or width
  // Batch <-> batch, or N
  // Depth <-> feature, or channel
  // TODO(timshen): turn them into cuDNN names, e.g. kNCHW.
  kYXDepthBatch = 0;
  kYXBatchDepth = 1;
  kBatchYXDepth = 2;   // cuDNN's NHWC layout
  kBatchDepthYX = 3;   // cuDNN's NCHW layout
  kBatchDepthYX4 = 4;  // cuDNN's NCHW_VECT_C layout
}

// Describes how a convolution filter is laid out in memory.
//
// kOutputInputYX is the zero value, so it is what a FilterLayout field reads
// as when it has not been set.
enum FilterLayout {
  // Naming convention:
  // Y <-> row or height
  // X <-> column or width
  // Output <-> output feature, or N
  // Input <-> input feature, or C
  // TODO(timshen): turn them into cuDNN names, e.g. kNCHW.
  kOutputInputYX = 0;   // cuDNN's NCHW layout
  kOutputYXInput = 1;   // cuDNN's NHWC layout
  kOutputInputYX4 = 2;  // cuDNN's NCHW_VECT_C layout
  kInputYXOutput = 3;
  kYXInputOutput = 4;
}

// Describes a kind of non-linearity (threshold-like mathematical function).
enum ActivationMode {
  // No non-linearity. This is the zero value, so it is also what an unset
  // ActivationMode field reads as.
  kNone = 0;
  kSigmoid = 1;
  // Rectified linear activation: f(x) = x < 0 ? 0 : x
  kRelu = 2;
  // Rectified linear activation, where the upper maximum is 6.0.
  kRelu6 = 3;
  // Rectified linear activation, where the upper maximum is specified by
  // BatchDescriptor::value_max().
  kReluX = 4;
  kTanh = 5;
  // Like ReluX, but passes all values in the range [-X,X].
  kBandPass = 6;
}

// Describes the mathematical definition of the conv op. The popular behavior
// is actually called cross-correlation in math, even though the operation is
// often referred to as convolution. See cuDNN's cudnnConvolutionMode_t.
//
// CROSS_CORRELATION is the zero value, so it is what an unset ConvolutionMode
// field reads as.
enum ConvolutionMode {
  CROSS_CORRELATION = 0;
  CONVOLUTION = 1;
}

// Identifies the kind of convolution operation.
//
// INVALID is the zero value, so a ConvolutionKind field that has not been set
// reads as INVALID rather than as one of the real kinds.
// NOTE(review): FORWARD / BACKWARD_FILTER / BACKWARD_DATA presumably name the
// forward pass and the gradients with respect to the filter and the input
// data — confirm against the callers.
enum ConvolutionKind {
  INVALID = 0;
  FORWARD = 1;
  BACKWARD_FILTER = 2;
  BACKWARD_DATA = 3;
}

// Generic tensor representation.
message TensorDescriptorProto {
  // Sizes of the tensor's dimensions.
  // NOTE(review): the dimension order presumably follows the chosen layout —
  // confirm against the code that fills this in.
  repeated int64 dimensions = 1;
  // Element type of the tensor.
  DataType data_type = 2;
  // Layout of the tensor. At most one of the two members is set.
  oneof layout_oneof {
    // Layout when the tensor is a convolution input or output.
    DataLayout data_layout = 3;
    // Layout when the tensor is a convolution filter.
    FilterLayout filter_layout = 4;
  }
}

// Generic algorithm representation.
message AlgorithmProto {
  // Kind of math the algorithm is allowed to use. DEFAULT_MATH is the zero
  // value, so it is what an unset math_type reads as.
  enum MathType {
    DEFAULT_MATH = 0;
    // The GPU may operate 4x4 matrix FMA.
    // See cuDNN's documentation for CUDNN_TENSOR_OP_MATH.
    TENSOR_OP_MATH = 1;
  }
  // Identifier of the algorithm.
  // NOTE(review): presumably a backend-specific (e.g. cuDNN) algorithm
  // number — confirm.
  int64 algo_id = 1;
  // Math mode to use with this algorithm.
  MathType math_type = 2;
}

// Convolution-specific parameters.
message ConvolutionDescriptorProto {
  // Padding, stride and dilation values of the convolution.
  // NOTE(review): presumably one entry per spatial dimension in each list —
  // confirm against the code that fills these in.
  repeated int64 paddings = 1;
  repeated int64 strides = 2;
  repeated int64 dilations = 3;
  // The "accumulator" type. For example, use F32 as an accumulator for F16
  // convolutions.
  // See the compute type of cuDNN's convolution descriptor (the computeType
  // argument of cudnnSetConvolutionNdDescriptor).
  DataType compute_mode = 4;
  // See cuDNN's group count.
  int32 group_count = 5;
  // Whether the op is a true convolution or a cross-correlation.
  // See cuDNN's cudnnConvolutionMode_t.
  ConvolutionMode convolution_mode = 6;
}