// LINT: LEGACY_NAMES
syntax = "proto3";

package stream_executor.dnn;

option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/stream_executor";

// Specifies the data type used by an operation.
// NOTE: value names use the legacy k-prefixed style rather than the
// conventional UPPER_SNAKE_CASE (see the LEGACY_NAMES lint waiver); they
// are part of the published contract and must not be renamed.
enum DataType {
  kFloat = 0;   // 32-bit single-precision floating point.
  kDouble = 1;  // 64-bit double-precision floating point.
  kHalf = 2;    // 16-bit half-precision floating point.
  kInt8 = 3;    // 8-bit signed integer.
  kInt32 = 4;   // 32-bit signed integer.
}
16
// Describes how a convolution input or output layer's data is formatted.
enum DataLayout {
  // Naming convention:
  // Y <-> row or height
  // X <-> column or width
  // Batch <-> batch, or N
  // Depth <-> feature, or channel
  // Each value lists dimensions from outermost (slowest-varying) to
  // innermost (fastest-varying) in memory.
  // TODO(timshen): turn them into cuDNN names, e.g. kNCHW.
  kYXDepthBatch = 0;
  kYXBatchDepth = 1;
  kBatchYXDepth = 2;   // cuDNN's NHWC layout
  kBatchDepthYX = 3;   // cuDNN's NCHW layout
  kBatchDepthYX4 = 4;  // cuDNN's NCHW_VECT_C layout (channels vectorized by 4)
}
31
// Describes how a convolution filter is laid out in the memory.
enum FilterLayout {
  // Naming convention:
  // Y <-> row or height
  // X <-> column or width
  // Output <-> output feature, or N
  // Input <-> input feature, or C
  // Each value lists dimensions from outermost (slowest-varying) to
  // innermost (fastest-varying) in memory.
  // TODO(timshen): turn them into cuDNN names, e.g. kNCHW.
  kOutputInputYX = 0;   // cuDNN's NCHW layout
  kOutputYXInput = 1;   // cuDNN's NHWC layout
  kOutputInputYX4 = 2;  // cuDNN's NCHW_VECT_C layout
  kInputYXOutput = 3;
  kYXInputOutput = 4;
}
46
// Describes a kind of non-linearity (threshold-like mathematical function)
// applied to an operation's output.
enum ActivationMode {
  // Identity; no non-linearity is applied.
  kNone = 0;
  // Sigmoid activation: f(x) = 1 / (1 + exp(-x)).
  kSigmoid = 1;
  // Rectified linear activation: f(x) = x < 0 ? 0 : x
  kRelu = 2;
  // Rectified linear activation; where upper maximum is 6.0.
  kRelu6 = 3;
  // Rectified linear activation; where upper maximum specified by
  // BatchDescriptor::value_max().
  kReluX = 4;
  // Hyperbolic tangent activation: f(x) = tanh(x).
  kTanh = 5;
  // Like ReluX; but passes all values in the range [-X,X].
  kBandPass = 6;
}
62
// Describes the math definition for the conv op. The popular behavior is
// actually called cross-correlation in math, even though the operation is
// commonly referred to as convolution. See cuDNN's cudnnConvolutionMode_t.
enum ConvolutionMode {
  // Filter is applied without flipping (what most DNN frameworks compute).
  CROSS_CORRELATION = 0;
  // Filter is flipped before being applied (the strict mathematical
  // definition of convolution).
  CONVOLUTION = 1;
}
70
// Identifies which convolution pass (or fused variant) is being performed.
enum ConvolutionKind {
  // Default zero value; not a valid kind.
  INVALID = 0;
  // Forward pass: computes output from input and filter.
  FORWARD = 1;
  // Backward pass w.r.t. the filter: computes filter gradient.
  BACKWARD_FILTER = 2;
  // Backward pass w.r.t. the input: computes input gradient.
  BACKWARD_DATA = 3;
  // Forward pass fused with bias addition and activation.
  FORWARD_BIAS_ACTIVATION = 4;
}
78
// Generic tensor representation.
message TensorDescriptorProto {
  // Size of each tensor dimension; the order of dimensions is interpreted
  // according to the layout selected below.
  repeated int64 dimensions = 1;
  // Element type of the tensor.
  DataType data_type = 2;
  // Exactly one layout applies: data_layout for convolution input/output
  // tensors, filter_layout for filter (weight) tensors.
  oneof layout_oneof {
    DataLayout data_layout = 3;
    FilterLayout filter_layout = 4;
  }
}
88
// Generic algorithm representation.
message AlgorithmProto {
  // Whether the algorithm may use tensor-core style math.
  enum MathType {
    DEFAULT_MATH = 0;
    // The GPU may operate 4x4 matrix FMA.
    // See cuDNN's documentation for CUDNN_TENSOR_OP_MATH.
    TENSOR_OP_MATH = 1;
  }
  // Opaque backend-specific algorithm identifier (e.g. a cuDNN algo enum
  // value).
  int64 algo_id = 1;
  MathType math_type = 2;
}
100
// Convolution-specific parameters.
// The repeated fields below each carry one entry per spatial dimension.
message ConvolutionDescriptorProto {
  // Zero-padding applied on each spatial dimension.
  repeated int64 paddings = 1;
  // Filter stride on each spatial dimension.
  repeated int64 strides = 2;
  // Filter dilation on each spatial dimension.
  repeated int64 dilations = 3;
  // The "accumulator" type. For example, use F32 as an accumulator for F16
  // convolutions.
  // See the computeType argument of cuDNN's convolution descriptor
  // (cudnnDataType_t).
  DataType compute_mode = 4;
  // See cuDNN's group count.
  int32 group_count = 5;
  // Cross-correlation vs. true convolution; see cuDNN's
  // cudnnConvolutionMode_t.
  ConvolutionMode convolution_mode = 6;
  // Tensorflow node name, same as in NodeDef, for debugging purposes.
  string name = 7;
}