// LINT: LEGACY_NAMES
syntax = "proto3";

package stream_executor.dnn;

option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/stream_executor";

// Specifies the data type used by an operation.
enum DataType {
  kFloat = 0;
  kDouble = 1;
  kHalf = 2;
  kInt8 = 3;
  kInt32 = 4;
  kComplexFloat = 5;
  kComplexDouble = 6;
}

// Describes how a convolution input or output layer's data is formatted.
enum DataLayout {
  // Naming convention:
  // Y <-> row or height
  // X <-> column or width
  // Batch <-> batch, or N
  // Depth <-> feature, or channel
  // TODO(timshen): turn them into cuDNN names, e.g. kNCHW.
  kYXDepthBatch = 0;
  kYXBatchDepth = 1;
  kBatchYXDepth = 2;   // cuDNN's NHWC layout
  kBatchDepthYX = 3;   // cuDNN's NCHW layout
  kBatchDepthYX4 = 4;  // cuDNN's NCHW_VECT_C layout
}
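
// For example, a batch of 32 RGB images of height 224 and width 224 has
// dimensions [32, 3, 224, 224] under kBatchDepthYX (NCHW) and
// [32, 224, 224, 3] under kBatchYXDepth (NHWC). kBatchDepthYX4
// (NCHW_VECT_C) additionally packs groups of 4 channels into a single
// vector element (e.g. int8x4).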

// Describes how a convolution filter is laid out in memory.
enum FilterLayout {
  // Naming convention:
  // Y <-> row or height
  // X <-> column or width
  // Output <-> output feature, or N
  // Input <-> input feature, or C
  // TODO(timshen): turn them into cuDNN names, e.g. kNCHW.
  kOutputInputYX = 0;   // cuDNN's NCHW layout
  kOutputYXInput = 1;   // cuDNN's NHWC layout
  kOutputInputYX4 = 2;  // cuDNN's NCHW_VECT_C layout
  kInputYXOutput = 3;
  kYXInputOutput = 4;
}
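
// For example, 16 filters of spatial size 3x3 over 8 input features have
// dimensions [16, 8, 3, 3] under kOutputInputYX and [16, 3, 3, 8] under
// kOutputYXInput.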

// Describes a kind of non-linearity (threshold-like mathematical function).
enum ActivationMode {
  kNone = 0;
  kSigmoid = 1;
  // Rectified linear activation: f(x) = x < 0 ? 0 : x
  kRelu = 2;
  // Rectified linear activation, clipped at an upper bound of 6.0.
  kRelu6 = 3;
  // Rectified linear activation, clipped at an upper bound given by
  // BatchDescriptor::value_max().
  kReluX = 4;
  kTanh = 5;
  // Like kReluX, but passes all values in the range [-X, X] unchanged.
  kBandPass = 6;
}
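
// In the notation above (with X = BatchDescriptor::value_max()), the clipped
// variants read as follows, taking the comments above at face value:
//   kRelu6:    f(x) = min(max(x, 0), 6)
//   kReluX:    f(x) = min(max(x, 0), X)
//   kBandPass: f(x) = min(max(x, -X), X)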

// Describes the math definition of the conv op. The popular behavior is
// actually called cross-correlation in math, even though the operation is
// commonly referred to as convolution. See cuDNN's cudnnConvolutionMode_t.
enum ConvolutionMode {
  CROSS_CORRELATION = 0;
  CONVOLUTION = 1;
}
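
// A minimal sketch of the distinction, for a 1-D input x and a filter w of
// length K: cross-correlation slides the filter as-is, while convolution
// flips it first.
//   CROSS_CORRELATION: y[i] = sum over k of x[i + k] * w[k]
//   CONVOLUTION:       y[i] = sum over k of x[i + k] * w[K - 1 - k]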

enum ConvolutionKind {
  INVALID = 0;
  FORWARD = 1;
  BACKWARD_FILTER = 2;
  BACKWARD_DATA = 3;
  FORWARD_BIAS_ACTIVATION = 4;
}
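
// These kinds map onto cuDNN's convolution entry points: FORWARD to
// cudnnConvolutionForward, BACKWARD_FILTER to cudnnConvolutionBackwardFilter,
// BACKWARD_DATA to cudnnConvolutionBackwardData, and FORWARD_BIAS_ACTIVATION
// to cudnnConvolutionBiasActivationForward (fused conv + bias + activation).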

// Generic tensor representation.
message TensorDescriptorProto {
  repeated int64 dimensions = 1;
  DataType data_type = 2;
  oneof layout_oneof {
    DataLayout data_layout = 3;
    FilterLayout filter_layout = 4;
  }
}
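
// For illustration, an NCHW float activation tensor of shape
// [64, 3, 224, 224] could be encoded as the textproto:
//   dimensions: [64, 3, 224, 224]
//   data_type: kFloat
//   data_layout: kBatchDepthYX
// A filter tensor would set filter_layout instead; the oneof ensures a
// descriptor carries exactly one of the two layout kinds.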

// Generic algorithm representation.
message AlgorithmProto {
  enum MathType {
    DEFAULT_MATH = 0;
    // The GPU may use 4x4-matrix fused multiply-add (Tensor Core) operations.
    // See cuDNN's documentation for CUDNN_TENSOR_OP_MATH.
    TENSOR_OP_MATH = 1;
  }
  int64 algo_id = 1;
  MathType math_type = 2;
}
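
// For illustration, algo_id typically carries the numeric value of a cuDNN
// algorithm enum such as cudnnConvolutionFwdAlgo_t. A descriptor selecting
// algorithm 1 with Tensor Core math would look like:
//   algo_id: 1
//   math_type: TENSOR_OP_MATH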

// Convolution-specific parameters.
message ConvolutionDescriptorProto {
  repeated int64 paddings = 1;
  repeated int64 strides = 2;
  repeated int64 dilations = 3;
  // The "accumulator" type. For example, use F32 as an accumulator for F16
  // convolutions.
  // See the computeType argument of cuDNN's convolution descriptor.
  DataType compute_mode = 4;
  // See cuDNN's group count.
  int32 group_count = 5;
  ConvolutionMode convolution_mode = 6;
  // TensorFlow node name, same as in NodeDef; for debugging purposes.
  string name = 7;
}
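
// For illustration, a 2-D 3x3 convolution with padding 1 on each spatial
// dimension, unit strides, no dilation, a float accumulator, and the usual
// cross-correlation semantics could be encoded as the textproto:
//   paddings: [1, 1]
//   strides: [1, 1]
//   dilations: [1, 1]
//   compute_mode: kFloat
//   group_count: 1
//   convolution_mode: CROSS_CORRELATION
//   name: "conv2d_1"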