// LINT: LEGACY_NAMES
syntax = "proto3";

package stream_executor.dnn;

option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/stream_executor";

// Specifies the data type used by an operation.
//
// kFloat is the zero value, so a DataType field that is left unset reads back
// as kFloat.
enum DataType {
  kFloat = 0;
  kDouble = 1;
  kHalf = 2;
  kInt8 = 3;
  kInt32 = 4;
}

// Describes how a convolution input or output layer's data is formatted.
enum DataLayout {
  // Naming convention:
  //   Y     <-> row or height
  //   X     <-> column or width
  //   Batch <-> batch, or N
  //   Depth <-> feature, or channel
  // TODO(timshen): turn them into cuDNN names, e.g. kNCHW.
  kYXDepthBatch = 0;
  kYXBatchDepth = 1;
  kBatchYXDepth = 2;   // cuDNN's NHWC layout
  kBatchDepthYX = 3;   // cuDNN's NCHW layout
  kBatchDepthYX4 = 4;  // cuDNN's NCHW_VECT_C layout
}

// Describes how a convolution filter is laid out in memory.
enum FilterLayout {
  // Naming convention:
  //   Y      <-> row or height
  //   X      <-> column or width
  //   Output <-> output feature, or N
  //   Input  <-> input feature, or C
  // TODO(timshen): turn them into cuDNN names, e.g. kNCHW.
  kOutputInputYX = 0;   // cuDNN's NCHW layout
  kOutputYXInput = 1;   // cuDNN's NHWC layout
  kOutputInputYX4 = 2;  // cuDNN's NCHW_VECT_C layout
  kInputYXOutput = 3;
  kYXInputOutput = 4;
}

// Describes a kind of non-linearity (threshold-like mathematical function).
enum ActivationMode {
  kNone = 0;
  kSigmoid = 1;
  // Rectified linear activation: f(x) = x < 0 ? 0 : x
  kRelu = 2;
  // Rectified linear activation, where the upper maximum is 6.0.
  kRelu6 = 3;
  // Rectified linear activation, where the upper maximum is specified by
  // BatchDescriptor::value_max().
  kReluX = 4;
  kTanh = 5;
  // Like ReluX, but passes all values in the range [-X,X].
  kBandPass = 6;
}

// Describes the mathematical definition of the conv op. The commonly used
// behavior is what mathematics calls cross-correlation, even though the
// operation is often referred to as convolution.
// See cuDNN's cudnnConvolutionMode_t.
enum ConvolutionMode {
  CROSS_CORRELATION = 0;
  CONVOLUTION = 1;
}

// Identifies the kind of convolution operation.
//
// INVALID is the zero value, so a ConvolutionKind field that is left unset
// reads back as INVALID.
// NOTE(review): the remaining values presumably name the forward pass, the
// two backward (gradient) passes, and the fused forward + bias + activation
// variant -- confirm against the code that consumes this enum.
enum ConvolutionKind {
  INVALID = 0;
  FORWARD = 1;
  BACKWARD_FILTER = 2;
  BACKWARD_DATA = 3;
  FORWARD_BIAS_ACTIVATION = 4;
}

// Generic tensor representation.
message TensorDescriptorProto {
  // Sizes of the tensor's dimensions.
  // NOTE(review): the ordering of the entries relative to the layout below is
  // not specified in this file -- confirm with the producer/consumer code.
  repeated int64 dimensions = 1;
  // Element type of the tensor.
  DataType data_type = 2;
  // Memory layout of the tensor. At most one member is set: data_layout for
  // an input/output tensor, filter_layout for a filter tensor.
  oneof layout_oneof {
    DataLayout data_layout = 3;
    FilterLayout filter_layout = 4;
  }
}

// Generic algorithm representation.
message AlgorithmProto {
  enum MathType {
    DEFAULT_MATH = 0;
    // The GPU may operate 4x4 matrix FMA.
    // See cuDNN's documentation for CUDNN_TENSOR_OP_MATH.
    TENSOR_OP_MATH = 1;
  }
  // Identifier of the algorithm.
  // NOTE(review): presumably a backend-specific (e.g. cuDNN) algorithm
  // number -- confirm.
  int64 algo_id = 1;
  // Math mode the algorithm runs with.
  MathType math_type = 2;
}

// Convolution-specific parameters.
message ConvolutionDescriptorProto {
  // NOTE(review): paddings, strides and dilations presumably each hold one
  // entry per spatial dimension -- confirm.
  repeated int64 paddings = 1;
  repeated int64 strides = 2;
  repeated int64 dilations = 3;
  // The "accumulator" type. For example, use F32 as an accumulator for F16
  // convolutions.
  // NOTE(review): this comment used to point at cuDNN's
  // cudnnConvolutionMode_t, which describes convolution_mode below rather
  // than a data type. The accumulator type presumably corresponds to the
  // computeType argument of cuDNN's convolution descriptor setters -- confirm.
  DataType compute_mode = 4;
  // See cuDNN's group count.
  int32 group_count = 5;
  // See cuDNN's cudnnConvolutionMode_t.
  ConvolutionMode convolution_mode = 6;
  // Tensorflow node name, same as in NodeDef, for debugging purposes.
  string name = 7;
}