// LINT: LEGACY_NAMES
syntax = "proto3";

package stream_executor.dnn;

// Specifies the data type used by an operation.
enum DataType {
  kFloat = 0;
  kDouble = 1;
  kHalf = 2;
  kInt8 = 3;
  kInt32 = 4;
}

// Describes how a convolution input or output layer's data is formatted.
enum DataLayout {
  // Naming convention:
  // Y <-> row or height
  // X <-> column or width
  // Batch <-> batch, or N
  // Depth <-> feature, or channel
  // TODO(timshen): turn them into cuDNN names, e.g. kNCHW.
  kYXDepthBatch = 0;
  kYXBatchDepth = 1;
  kBatchYXDepth = 2;   // cuDNN's NHWC layout
  kBatchDepthYX = 3;   // cuDNN's NCHW layout
  kBatchDepthYX4 = 4;  // cuDNN's NCHW_VECT_C layout
}

// Describes how a convolution filter is laid out in memory.
enum FilterLayout {
  // Naming convention:
  // Y <-> row or height
  // X <-> column or width
  // Output <-> output feature, or N
  // Input <-> input feature, or C
  // TODO(timshen): turn them into cuDNN names, e.g. kNCHW.
  kOutputInputYX = 0;   // cuDNN's NCHW layout
  kOutputYXInput = 1;   // cuDNN's NHWC layout
  kOutputInputYX4 = 2;  // cuDNN's NCHW_VECT_C layout
  kInputYXOutput = 3;
  kYXInputOutput = 4;
}

// Describes a kind of non-linearity (threshold-like mathematical function).
enum ActivationMode {
  kNone = 0;
  kSigmoid = 1;
  // Rectified linear activation: f(x) = x < 0 ? 0 : x.
  kRelu = 2;
  // Rectified linear activation, clipped at an upper maximum of 6.0.
  kRelu6 = 3;
  // Rectified linear activation, clipped at an upper maximum specified by
  // BatchDescriptor::value_max().
  kReluX = 4;
  kTanh = 5;
  // Like kReluX, but passes all values in the range [-X, X].
  kBandPass = 6;
}

// Describes the math definition for the conv op. The popular behavior is
// actually called cross-correlation in math, even though the operation is
// often referred to as convolution. See cuDNN's cudnnConvolutionMode_t.
enum ConvolutionMode {
  CROSS_CORRELATION = 0;
  CONVOLUTION = 1;
}

enum ConvolutionKind {
  INVALID = 0;
  FORWARD = 1;
  BACKWARD_FILTER = 2;
  BACKWARD_DATA = 3;
}

// Generic tensor representation.
message TensorDescriptorProto {
  repeated int64 dimensions = 1;
  DataType data_type = 2;
  oneof layout_oneof {
    DataLayout data_layout = 3;
    FilterLayout filter_layout = 4;
  }
}

// Generic algorithm representation.
message AlgorithmProto {
  enum MathType {
    DEFAULT_MATH = 0;
    // The GPU may perform 4x4 matrix FMA operations (e.g. on tensor cores).
    // See cuDNN's documentation for CUDNN_TENSOR_OP_MATH.
    TENSOR_OP_MATH = 1;
  }
  int64 algo_id = 1;
  MathType math_type = 2;
}

// Convolution-specific parameters.
message ConvolutionDescriptorProto {
  repeated int64 paddings = 1;
  repeated int64 strides = 2;
  repeated int64 dilations = 3;
  // The "accumulator" type. For example, use F32 as an accumulator for F16
  // convolutions.
  // See the computeType argument of cuDNN's convolution descriptor.
  DataType compute_mode = 4;
  // See cuDNN's group count.
  int32 group_count = 5;
  ConvolutionMode convolution_mode = 6;
}
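
// Example (illustrative only, not part of the schema): a textproto sketch of
// two TensorDescriptorProto messages, one per arm of layout_oneof. The
// concrete shapes and the (batch, depth, Y, X) dimension ordering are
// assumptions chosen for illustration.
//
// A float input tensor, N=32, C=3, H=224, W=224, in cuDNN's NCHW layout:
//
//   dimensions: [32, 3, 224, 224]
//   data_type: kFloat
//   data_layout: kBatchDepthYX
//
// An F16 filter with 64 output features, 3 input features, and a 7x7 window,
// using the filter arm of the oneof instead:
//
//   dimensions: [64, 3, 7, 7]
//   data_type: kHalf
//   filter_layout: kOutputInputYX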
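
// Example (illustrative only): a ConvolutionDescriptorProto sketch for a 2D
// cross-correlation with 1-element zero padding, stride 2, no dilation, and
// F32 accumulation (the pattern suggested above for F16 convolutions). All
// values are assumptions; a grouped convolution would set group_count > 1.
//
//   paddings: [1, 1]
//   strides: [2, 2]
//   dilations: [1, 1]
//   compute_mode: kFloat
//   group_count: 1
//   convolution_mode: CROSS_CORRELATION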
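
// Example (illustrative only): an AlgorithmProto sketch selecting algorithm
// id 1 with tensor-op math enabled. The algo_id value here is an assumption;
// valid ids are backend-specific (for cuDNN they correspond to the algorithm
// enums such as cudnnConvolutionFwdAlgo_t).
//
//   algo_id: 1
//   math_type: TENSOR_OP_MATH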