• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
syntax = "proto3";

package tensorflow.boosted_trees;

// Code generation options.
// C++: generate message classes with arena allocation support.
option cc_enable_arenas = true;
// Java: emit one source file per top-level message, with the file-level
// descriptor in BoostedTreesProtos, all under org.tensorflow.framework.
option java_multiple_files = true;
option java_outer_classname = "BoostedTreesProtos";
option java_package = "org.tensorflow.framework";
9
// A single node of a tree. At most one member of the `node` oneof is set;
// it determines whether this node is a leaf or which kind of split it is.
message Node {
  oneof node {
    // Leaf payload (dense, sparse, or scalar value).
    Leaf leaf = 1;
    // Split on a bucketized feature against an integer threshold.
    BucketizedSplit bucketized_split = 2;
    // Split on a categorical feature value (equality rule).
    CategoricalSplit categorical_split = 3;
    // Split on a float feature against a float threshold.
    DenseSplit dense_split = 4;
  }

  // Extra bookkeeping for this node (gain, leaf before splitting). The
  // out-of-sequence number 777 is part of the wire format; keep it.
  NodeMetadata metadata = 777;
}
20
// Per-node metadata stored alongside each node of a tree.
message NodeMetadata {
  // Gain recorded for this node.
  float gain = 1;

  // The leaf this node was before it was turned into a split.
  Leaf original_leaf = 2;
}
29
// Value payload of a leaf node.
// The `leaf` oneof holds either a dense `Vector` or a `SparseVector`.
// `scalar` is a separate field declared outside that oneof, so setting it
// does not clear `vector`/`sparse_vector` (and vice versa).
// NOTE(review): presumably `scalar` is used for single-valued leaves instead
// of the oneof; confirm which combinations readers accept.
message Leaf {
  oneof leaf {
    // Dense list of values.
    Vector vector = 1;
    // Sparse values given as index/value lists.
    SparseVector sparse_vector = 2;
  }
  // Single float value, independent of the `leaf` oneof. As a plain proto3
  // scalar it has implicit presence: a value of 0 is indistinguishable from
  // "not set".
  float scalar = 3;
}
38
// Dense vector of float values.
message Vector {
  // Values in order (a repeated scalar, packed on the wire by default in
  // proto3).
  repeated float value = 1;
}
42
// Sparse vector of float values, stored as two lists.
message SparseVector {
  // Positions of the stored entries.
  // NOTE(review): presumably index[i] pairs with value[i] and both lists have
  // the same length; confirm against writers.
  repeated int32 index = 1;
  // Values of the stored entries.
  repeated float value = 2;
}
47
// Kind of split combined with its default direction (the direction used for
// missing values, as for BucketizedSplit.default_direction).
// NOTE(review): no message in this file has a field of this type; it is
// presumably consumed by op code. Value 2 is unassigned (there is no
// EQUALITY_DEFAULT_LEFT); do not assume the values are contiguous.
enum SplitTypeWithDefault {
  // Inequality split, default direction left. Zero value, so this is what
  // an unset field of this type reads as.
  INEQUALITY_DEFAULT_LEFT = 0;
  // Inequality split, default direction right.
  INEQUALITY_DEFAULT_RIGHT = 1;
  // Equality split, default direction right.
  EQUALITY_DEFAULT_RIGHT = 3;
}
53
// Direction a split sends examples whose feature value is missing.
enum DefaultDirection {
  // Towards the left child. As the zero value, this is also what an unset
  // DefaultDirection field reads as.
  DEFAULT_LEFT = 0;
  // Towards the right child.
  DEFAULT_RIGHT = 1;
}
59
// Split on a bucketized float feature column, encoding the rule
// `feature <= threshold`. The threshold is an int32 (presumably a bucket id
// rather than a raw float value).
message BucketizedSplit {
  // Feature column the rule is evaluated on.
  int32 feature_id = 1;
  // Integer threshold of the rule `feature <= threshold`.
  int32 threshold = 2;
  // For a multivalent feature column, index of the dimension used by the
  // split; 0 when unset.
  int32 dimension_id = 5;
  // Direction taken by examples whose value for this feature is missing.
  DefaultDirection default_direction = 6;

  // Children of this node, as indices into the tree's contiguous list of
  // nodes counted from the root.
  int32 left_id = 3;
  int32 right_id = 4;
}
76
// Split on a categorical feature column, encoding the rule
// `feature value == value`.
message CategoricalSplit {
  // Categorical feature column the rule is evaluated on.
  int32 feature_id = 1;
  // Category value the feature is compared against for equality.
  int32 value = 2;
  // For a multivalent feature column, index of the dimension used by the
  // split; 0 when unset.
  int32 dimension_id = 5;

  // Children of this node, as indices into the tree's contiguous list of
  // nodes counted from the root.
  int32 left_id = 3;
  int32 right_id = 4;
}
91
// TODO(nponomareva): move this file out of boosted_trees and rename it to
// trees.proto. (This note concerns the whole file, not just DenseSplit.)

// Split on a dense float feature column, encoding the rule
// `feature <= threshold` against a float threshold. Unlike BucketizedSplit,
// this message has no dimension_id or default_direction field.
message DenseSplit {
  // Float feature column the rule is evaluated on.
  int32 feature_id = 1;
  // Float threshold of the rule `feature <= threshold`.
  float threshold = 2;

  // Children of this node, as indices into the tree's contiguous list of
  // nodes counted from the root.
  int32 left_id = 3;
  int32 right_id = 4;
}
104
// A tree stored as a flat list of connected nodes. Node ids are not stored
// explicitly: a node's id is its position in `nodes`. Node 0 must be the
// root; it may carry any payload, including a leaf (e.g. when the tree only
// represents the bias).
message Tree {
  repeated Node nodes = 1;
}
112
// Per-tree growing and post-pruning state.
message TreeMetadata {
  // Field number 1 is not used by this message (numbering starts at 2).
  // Reserve it so the tag cannot later be given a new meaning, which would
  // misinterpret any older data that may have used it.
  reserved 1;

  // Number of layers grown for this tree.
  int32 num_layers_grown = 2;

  // Whether the tree is finalized, i.e. no more layers can be grown.
  bool is_finalized = 3;

  // If the tree was finalized and post-pruning happened, cached state may
  // still refer to nodes that were deleted or whose ids changed (e.g. node
  // id 5 became node id 2 because the other branch was pruned). This list
  // maps old ids to their new ids and says how cached values must be
  // adjusted for post-pruning.
  // Its size equals the number of nodes in the tree before post-pruning.
  // A pruned node maps to the id of the node it was collapsed into; a node
  // that survived maps to its (possibly changed) id in the pruned tree.
  // The list is empty if post-pruning did not happen or had no effect (no
  // nodes were pruned).
  repeated PostPruneNodeUpdate post_pruned_nodes_meta = 4;

  // Remapping entry for one pre-pruning node id.
  message PostPruneNodeUpdate {
    // Id of the corresponding node in the pruned tree.
    int32 new_node_id = 1;
    // Adjustment to apply to values cached for the old node id so they agree
    // with the pruned tree.
    // NOTE(review): presumably one entry per logit dimension; confirm.
    repeated float logit_change = 2;
  }
}
140
// Training-time state describing how far growing has progressed.
message GrowingMetadata {
  // How many trees construction has been attempted for; pruning may since
  // have removed some of them.
  int64 num_trees_attempted = 1;
  // How many layers construction has been attempted for; pruning may since
  // have removed some of them.
  int64 num_layers_attempted = 2;
  // Node-id range [start, end) of the most recent layer of the most recent
  // tree: start is inclusive, end is exclusive.
  int32 last_layer_node_start = 3;
  int32 last_layer_node_end = 4;
}
153
// An ensemble of decision trees plus per-tree and training-time state.
message TreeEnsemble {
  // Trees that make up the ensemble.
  repeated Tree trees = 1;
  // Per-tree weights.
  // NOTE(review): presumably tree_weights[i] and tree_metadata[i] describe
  // trees[i] (parallel lists); confirm against readers and writers.
  repeated float tree_weights = 2;

  // Per-tree metadata (see the NOTE on tree_weights).
  repeated TreeMetadata tree_metadata = 3;
  // State used while the ensemble is being trained.
  GrowingMetadata growing_metadata = 4;
}
163
// DebugOutput holds per-example outputs useful for debugging and model
// interpretation. Outputs currently available to the user:
//   1) Directional feature contributions (DFCs), via `feature_ids` and
//      `logits_path`.
//   2) Leaf node ids, via `leaf_node_ids`.
// Node ids along the ensemble prediction path are not returned yet (see the
// TODO at the end of this message).
message DebugOutput {
  // Logits and associated feature splits across the prediction path of each
  // tree, for every example, at predict time. DFCs are computed from these
  // in Python by subtracting each child prediction from its parent
  // prediction and associating that change with the respective feature id.
  // NOTE(review): these two lists are meant to be read together; confirm
  // their exact length/alignment against the op that fills them.
  repeated int32 feature_ids = 1;
  repeated float logits_path = 2;
  // Node id of each leaf node reached in the prediction path.
  repeated int32 leaf_node_ids = 3;
  // TODO(crawles): also return node ids along the ensemble prediction path.
}
180