syntax = "proto3";

package tensorflow.boosted_trees;

option cc_enable_arenas = true;
option java_outer_classname = "BoostedTreesProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

// Node describes a node in a tree.
// At most one member of the `node` oneof can be set: either a leaf payload or
// one of the split types. Split nodes reference their children by index into
// Tree.nodes (see left_id/right_id on each split message).
message Node {
  oneof node {
    Leaf leaf = 1;
    BucketizedSplit bucketized_split = 2;
    CategoricalSplit categorical_split = 3;
    DenseSplit dense_split = 4;
  }
  // Metadata for this node. Declared outside the oneof, so it can accompany
  // any node type.
  NodeMetadata metadata = 777;
}

// NodeMetadata encodes metadata associated with each node in a tree.
message NodeMetadata {
  // The gain associated with this node.
  float gain = 1;

  // The original leaf node before this node was split.
  Leaf original_leaf = 2;
}

// Leaves can hold dense (`vector`) or sparse (`sparse_vector`) information;
// at most one of the two can be set, since they share the `leaf` oneof.
// `scalar` is declared outside the oneof, so it can be set independently of
// either.
// NOTE(review): `scalar` has proto3 implicit presence, so a value of 0 cannot
// be distinguished from "unset" on the wire.
message Leaf {
  oneof leaf {
    Vector vector = 1;
    SparseVector sparse_vector = 2;
  }
  float scalar = 3;
}

// A dense list of float values.
message Vector {
  repeated float value = 1;
}

// A sparse vector stored as two parallel lists.
// NOTE(review): presumably `index[i]` is the position of `value[i]` and both
// lists have the same length — confirm against the readers/writers.
message SparseVector {
  repeated int32 index = 1;
  repeated float value = 2;
}

// Judging by the value names, combines the split kind (inequality/equality)
// with the direction taken by default.
// NOTE(review): value 2 is not defined. If it was used in the past, consider
// reserving it so it cannot be reused with a different meaning.
enum SplitTypeWithDefault {
  INEQUALITY_DEFAULT_LEFT = 0;
  INEQUALITY_DEFAULT_RIGHT = 1;
  EQUALITY_DEFAULT_RIGHT = 3;
}

// Direction taken at a split for missing values (see
// BucketizedSplit.default_direction).
enum DefaultDirection {
  // Left is the default direction.
  DEFAULT_LEFT = 0;
  DEFAULT_RIGHT = 1;
}

message BucketizedSplit {
  // Float feature column and split threshold describing
  // the rule feature <= threshold.
  // NOTE(review): `threshold` is an int32, so it presumably refers to a bucket
  // id of the bucketized float feature rather than a raw float value — confirm.
  int32 feature_id = 1;
  int32 threshold = 2;
  // If feature column is multivalent, this holds the index of the dimension
  // for the split. Defaults to 0.
  int32 dimension_id = 5;
  // Default direction for missing values.
  DefaultDirection default_direction = 6;

  // Node children indexing into a contiguous
  // vector of nodes starting from the root.
  int32 left_id = 3;
  int32 right_id = 4;
}

message CategoricalSplit {
  // Categorical feature column and split describing the rule
  // feature value == value.
  int32 feature_id = 1;
  int32 value = 2;
  // If feature column is multivalent, this holds the index of the dimension
  // for the split. Defaults to 0.
  int32 dimension_id = 5;

  // Node children indexing into a contiguous
  // vector of nodes starting from the root.
  int32 left_id = 3;
  int32 right_id = 4;
}

// TODO(nponomareva): move out of boosted_trees and rename to trees.proto
message DenseSplit {
  // Float feature column and split threshold describing
  // the rule feature <= threshold.
  int32 feature_id = 1;
  float threshold = 2;

  // Node children indexing into a contiguous
  // vector of nodes starting from the root.
  int32 left_id = 3;
  int32 right_id = 4;
}

// Tree describes a list of connected nodes.
// Node 0 must be the root and can carry any payload including a leaf
// in the case of representing the bias.
// Note that each node id is implicitly its index in the list of nodes.
message Tree {
  repeated Node nodes = 1;
}

// NOTE(review): field number 1 is not used. If it was removed in the past,
// consider reserving it so it cannot be reused with a different meaning.
message TreeMetadata {
  // Number of layers grown for this tree.
  int32 num_layers_grown = 2;

  // Whether the tree is finalized in that no more layers can be grown.
  bool is_finalized = 3;

  // If the tree was finalized and post-pruning happened, the cache may still
  // refer to nodes that were deleted, or to node ids that changed
  // (e.g. node id 5 became node id 2 due to pruning of the other branch).
  // The mapping below records where the old ids now map to and how the values
  // should be adjusted due to post-pruning.
  // The size of the list should be equal to the number of nodes in the tree
  // before post-pruning happened.
  // If the node was pruned, its new_node_id is the id of the node it was
  // collapsed into. For a node that didn't get pruned, its id may still have
  // changed, so new_node_id holds the corresponding id in the pruned tree.
  // If post-pruning didn't happen, or it did and had no effect (e.g. no
  // nodes got pruned), this list will be empty.
  repeated PostPruneNodeUpdate post_pruned_nodes_meta = 4;

  // Mapping entry for one node of the tree as it was before post-pruning.
  message PostPruneNodeUpdate {
    // Id of the node in the pruned tree that this node now corresponds to.
    int32 new_node_id = 1;
    // How this node's values should be adjusted due to post-pruning.
    // NOTE(review): presumably one entry per logit dimension — confirm.
    repeated float logit_change = 2;
  }
}

message GrowingMetadata {
  // Number of trees that we have attempted to build. After pruning, these
  // trees might have been removed.
  int64 num_trees_attempted = 1;
  // Number of layers that we have attempted to build. After pruning, these
  // layers might have been removed.
  int64 num_layers_attempted = 2;
  // The start (inclusive) and end (exclusive) ids of the nodes in the latest
  // layer of the latest tree.
  int32 last_layer_node_start = 3;
  int32 last_layer_node_end = 4;
}

// TreeEnsemble describes an ensemble of decision trees.
message TreeEnsemble {
  repeated Tree trees = 1;
  // NOTE(review): presumably index-aligned with `trees` (one weight per tree)
  // — confirm.
  repeated float tree_weights = 2;

  // NOTE(review): presumably index-aligned with `trees` — confirm.
  repeated TreeMetadata tree_metadata = 3;
  // Metadata that is used during the training.
  GrowingMetadata growing_metadata = 4;
}

// DebugOutput contains outputs useful for debugging/model interpretation, at
// the individual example-level. Debug outputs intended for the user are:
// 1) Directional feature contributions (DFCs), 2) Node IDs for ensemble
// prediction paths, and 3) Leaf node IDs. Items 1) and 3) are provided by the
// fields below; 2) is not returned yet (see the TODO at the end of this
// message).
message DebugOutput {
  // Return the logits and associated feature splits across prediction paths for
  // each tree, for every example, at predict time. We will use these values to
  // compute DFCs in Python, by subtracting each child prediction from its
  // parent prediction and associating this change with its respective feature
  // id.
  // NOTE(review): the exact positional correspondence between `feature_ids`
  // and `logits_path` (e.g. whether `logits_path` carries an extra root/bias
  // entry) is not specified here — confirm against the producer.
  repeated int32 feature_ids = 1;
  repeated float logits_path = 2;
  // Return the node_id for each leaf node we reach in our prediction path.
  repeated int32 leaf_node_ids = 3;
  // TODO(crawles): return 2) Node IDs for ensemble prediction path.
}