syntax = "proto3";

package tensorflow.boosted_trees;

option cc_enable_arenas = true;
option java_outer_classname = "BoostedTreesProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

// Node describes a node in a tree.
message Node {
  oneof node {
    Leaf leaf = 1;
    BucketizedSplit bucketized_split = 2;
    CategoricalSplit categorical_split = 3;
    DenseSplit dense_split = 4;
  }
  // NOTE(review): unusually high field number (costs an extra tag byte on the
  // wire); kept as-is for wire compatibility with existing serialized data.
  NodeMetadata metadata = 777;
}

// NodeMetadata encodes metadata associated with each node in a tree.
message NodeMetadata {
  // The gain associated with this node.
  float gain = 1;

  // The original leaf node before this node was split.
  Leaf original_leaf = 2;
}

// Leaves can either hold dense or sparse information.
message Leaf {
  oneof leaf {
    // See third_party/tensorflow/contrib/decision_trees/
    // proto/generic_tree_model.proto
    // for a description of how vector and sparse_vector might be used.
    Vector vector = 1;
    SparseVector sparse_vector = 2;
  }
  // Scalar prediction value. Intentionally outside the oneof above, so a leaf
  // can carry a scalar regardless of which (if any) vector form is set.
  float scalar = 3;
}

// Dense vector of leaf values.
message Vector {
  repeated float value = 1;
}

// Sparse vector of leaf values; `index` and `value` are parallel arrays of
// equal length (index[i] is the position of value[i]).
message SparseVector {
  repeated int32 index = 1;
  repeated float value = 2;
}

message BucketizedSplit {
  // Float feature column and split threshold describing
  // the rule feature <= threshold.
  int32 feature_id = 1;
  // Bucket index threshold (the feature has already been bucketized, hence
  // an integer rather than a float as in DenseSplit).
  int32 threshold = 2;
  // If feature column is multivalent, this holds the index of the dimension
  // for the split. Defaults to 0.
  int32 dimension_id = 5;
  enum DefaultDirection {
    // Left is the default direction.
    DEFAULT_LEFT = 0;
    DEFAULT_RIGHT = 1;
  }
  // Default direction for missing values.
  DefaultDirection default_direction = 6;

  // Node children indexing into a contiguous
  // vector of nodes starting from the root.
  int32 left_id = 3;
  int32 right_id = 4;
}

message CategoricalSplit {
  // Categorical feature column and split describing the rule
  // feature value == value.
  int32 feature_id = 1;
  int32 value = 2;

  // Node children indexing into a contiguous
  // vector of nodes starting from the root.
  int32 left_id = 3;
  int32 right_id = 4;
}

// TODO(nponomareva): move out of boosted_trees and rename to trees.proto
message DenseSplit {
  // Float feature column and split threshold describing
  // the rule feature <= threshold.
  int32 feature_id = 1;
  float threshold = 2;

  // Node children indexing into a contiguous
  // vector of nodes starting from the root.
  int32 left_id = 3;
  int32 right_id = 4;
}

// Tree describes a list of connected nodes.
// Node 0 must be the root and can carry any payload including a leaf
// in the case of representing the bias.
// Note that each node id is implicitly its index in the list of nodes.
message Tree {
  repeated Node nodes = 1;
}

message TreeMetadata {
  // Field number 1 is skipped — presumably a deleted field. Reserve it so it
  // cannot be accidentally reused with different semantics.
  // TODO(review): confirm against the schema's change history.
  reserved 1;

  // Number of layers grown for this tree.
  int32 num_layers_grown = 2;

  // Whether the tree is finalized in that no more layers can be grown.
  bool is_finalized = 3;

  // If tree was finalized and post pruning happened, it is possible that cache
  // still refers to some nodes that were deleted or that the node ids changed
  // (e.g. node id 5 became node id 2 due to pruning of the other branch).
  // The mapping below allows us to understand where the old ids now map to and
  // how the values should be adjusted due to post-pruning.
  // The size of the list should be equal to the number of nodes in the tree
  // before post-pruning happened.
  // If the node was pruned, it will have new_node_id equal to the id of a node
  // that this node was collapsed into. For a node that didn't get pruned, it is
  // possible that its id still changed, so new_node_id will have the
  // corresponding id in the pruned tree.
  // If post-pruning didn't happen, or it did and it had no effect (e.g. no
  // nodes got pruned), this list will be empty.
  repeated PostPruneNodeUpdate post_pruned_nodes_meta = 4;

  // Remapping entry for one pre-pruning node: its id in the pruned tree and
  // the logit adjustment caused by collapsing its subtree.
  message PostPruneNodeUpdate {
    int32 new_node_id = 1;
    float logit_change = 2;
  }
}

message GrowingMetadata {
  // Number of trees that we have attempted to build. After pruning, these
  // trees might have been removed.
  int64 num_trees_attempted = 1;
  // Number of layers that we have attempted to build. After pruning, these
  // layers might have been removed.
  int64 num_layers_attempted = 2;
  // The start (inclusive) and end (exclusive) ids of the nodes in the latest
  // layer of the latest tree.
  int32 last_layer_node_start = 3;
  int32 last_layer_node_end = 4;
}

// TreeEnsemble describes an ensemble of decision trees.
message TreeEnsemble {
  repeated Tree trees = 1;
  // Per-tree weights; parallel to `trees` (tree_weights[i] weights trees[i]).
  repeated float tree_weights = 2;

  // Per-tree metadata; parallel to `trees`.
  repeated TreeMetadata tree_metadata = 3;
  // Metadata that is used during the training.
  GrowingMetadata growing_metadata = 4;
}

// DebugOutput contains outputs useful for debugging/model interpretation, at
// the individual example-level. Debug outputs that are available to the user
// are: 1) Directional feature contributions (DFCs) 2) Node IDs for ensemble
// prediction path 3) Leaf node IDs.
message DebugOutput {
  // Return the logits and associated feature splits across prediction paths
  // for each tree, for every example, at predict time. We will use these
  // values to compute DFCs in Python, by subtracting each child prediction
  // from its parent prediction and associating this change with its
  // respective feature id.
  // NOTE: these two repeated fields are parallel arrays zipped by index;
  // kept this way for wire compatibility.
  repeated int32 feature_ids = 1;
  repeated float logits_path = 2;

  // TODO(crawles): return 2) Node IDs for ensemble prediction path 3) Leaf
  // node IDs.
}