syntax = "proto3";

package tensorflow.boosted_trees;
option cc_enable_arenas = true;
option java_outer_classname = "BoostedTreesProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

// Node describes a node in a tree.
message Node {
  // Exactly one payload is set: either a terminal leaf or one of the
  // supported split types.
  oneof node {
    Leaf leaf = 1;
    BucketizedSplit bucketized_split = 2;
    CategoricalSplit categorical_split = 3;
    DenseSplit dense_split = 4;
  }
  // Per-node training metadata (gain, pre-split leaf).
  // NOTE(review): field number 777 requires a 2-byte tag; presumably chosen
  // to keep metadata visually apart from the payload numbers — do not
  // renumber, as that would break the wire format.
  NodeMetadata metadata = 777;
}

// NodeMetadata encodes metadata associated with each node in a tree.
message NodeMetadata {
  // The gain associated with this node.
  // NOTE(review): presumably the split-score improvement of the chosen
  // split; exact semantics/units are defined by the training code — confirm.
  float gain = 1;

  // The original leaf node before this node was split.
  Leaf original_leaf = 2;
}

// Leaves can either hold dense or sparse information.
message Leaf {
  oneof leaf {
    // See third_party/tensorflow/contrib/decision_trees/
    // proto/generic_tree_model.proto
    // for a description of how vector and sparse_vector might be used.
    Vector vector = 1;
    SparseVector sparse_vector = 2;
  }
  // Scalar leaf value. Note this field is OUTSIDE the oneof above, so it can
  // be set alongside (or instead of) vector/sparse_vector.
  float scalar = 3;
}

// A dense vector of float values.
message Vector {
  repeated float value = 1;
}

// A sparse vector held as parallel arrays: presumably index[i] gives the
// position of value[i] in the equivalent dense vector — confirm against the
// generic_tree_model.proto description referenced from Leaf.
message SparseVector {
  repeated int32 index = 1;
  repeated float value = 2;
}

// A split on a bucketized (quantized) feature column.
message BucketizedSplit {
  // Float feature column and split threshold describing
  // the rule feature <= threshold.
  // Note: threshold is an int32 here (presumably a bucket index, since the
  // feature has been bucketized), unlike DenseSplit where it is a raw float.
  int32 feature_id = 1;
  int32 threshold = 2;
  // If feature column is multivalent, this holds the index of the dimension
  // for the split. Defaults to 0.
  int32 dimension_id = 5;
  // Which child receives examples whose feature value is missing.
  // NOTE(review): values are not prefixed with the enum name
  // (DEFAULT_DIRECTION_*); renaming now would break generated code, so they
  // are left as-is.
  enum DefaultDirection {
    // Left is the default direction.
    DEFAULT_LEFT = 0;
    DEFAULT_RIGHT = 1;
  }
  // default direction for missing values.
  DefaultDirection default_direction = 6;

  // Node children indexing into a contiguous
  // vector of nodes starting from the root.
  int32 left_id = 3;
  int32 right_id = 4;
}

// A split on a categorical feature column.
message CategoricalSplit {
  // Categorical feature column and split describing the rule feature value ==
  // value.
  int32 feature_id = 1;
  int32 value = 2;

  // Node children indexing into a contiguous
  // vector of nodes starting from the root.
  int32 left_id = 3;
  int32 right_id = 4;
}

// A split on a dense (non-bucketized) float feature column.
// TODO(nponomareva): move out of boosted_trees and rename to trees.proto
message DenseSplit {
  // Float feature column and split threshold describing
  // the rule feature <= threshold.
  // Note: threshold is a raw float here, unlike BucketizedSplit's int32.
  int32 feature_id = 1;
  float threshold = 2;

  // Node children indexing into a contiguous
  // vector of nodes starting from the root.
  int32 left_id = 3;
  int32 right_id = 4;
}

// Tree describes a list of connected nodes.
// Node 0 must be the root and can carry any payload including a leaf
// in the case of representing the bias.
// Note that each node id is implicitly its index in the list of nodes.
message Tree {
  repeated Node nodes = 1;
}

// Per-tree bookkeeping maintained during training.
message TreeMetadata {
  // NOTE(review): field number 1 is unused. If a field was deleted it should
  // be declared `reserved 1;` to prevent accidental reuse — confirm history.

  // Number of layers grown for this tree.
  int32 num_layers_grown = 2;

  // Whether the tree is finalized in that no more layers can be grown.
  bool is_finalized = 3;

  // If tree was finalized and post pruning happened, it is possible that cache
  // still refers to some nodes that were deleted or that the node ids changed
  // (e.g. node id 5 became node id 2 due to pruning of the other branch).
  // The mapping below allows us to understand where the old ids now map to and
  // how the values should be adjusted due to post-pruning.
  // The size of the list should be equal to the number of nodes in the tree
  // before post-pruning happened.
  // If the node was pruned, it will have new_node_id equal to the id of a node
  // that this node was collapsed into. For a node that didn't get pruned, it is
  // possible that its id still changed, so new_node_id will have the
  // corresponding id in the pruned tree.
  // If post-pruning didn't happen, or it did and it had no effect (e.g. no
  // nodes got pruned), this list will be empty.
  repeated PostPruneNodeUpdate post_pruned_nodes_meta = 4;

  // Maps one pre-pruning node to its post-pruning id and logit adjustment.
  message PostPruneNodeUpdate {
    // Id of this node in the post-pruned tree (or of the node it was
    // collapsed into) — see the comment on post_pruned_nodes_meta.
    int32 new_node_id = 1;
    // Adjustment to apply to cached logits for this node after pruning.
    float logit_change = 2;
  }
}

// Ensemble-level counters maintained while the ensemble is being grown.
message GrowingMetadata {
  // Number of trees that we have attempted to build. After pruning, these
  // trees might have been removed.
  int64 num_trees_attempted = 1;
  // Number of layers that we have attempted to build. After pruning, these
  // layers might have been removed.
  int64 num_layers_attempted = 2;
  // The start (inclusive) and end (exclusive) ids of the nodes in the latest
  // layer of the latest tree.
  int32 last_layer_node_start = 3;
  int32 last_layer_node_end = 4;
}

// TreeEnsemble describes an ensemble of decision trees.
message TreeEnsemble {
  // NOTE(review): trees, tree_weights and tree_metadata are presumably
  // parallel arrays indexed by tree — the schema does not enforce equal
  // lengths; confirm against the producing code.
  repeated Tree trees = 1;
  repeated float tree_weights = 2;

  repeated TreeMetadata tree_metadata = 3;
  // Metadata that is used during the training.
  GrowingMetadata growing_metadata = 4;
}

// DebugOutput contains outputs useful for debugging/model interpretation, at
// the individual example-level. Debug outputs that are available to the user
// are: 1) Directional feature contributions (DFCs) 2) Node IDs for ensemble
// prediction path 3) Leaf node IDs.
message DebugOutput {
  // Return the logits and associated feature splits across prediction paths for
  // each tree, for every example, at predict time. We will use these values to
  // compute DFCs in Python, by subtracting each child prediction from its
  // parent prediction and associating this change with its respective feature
  // id.
  // NOTE(review): presumably parallel arrays (feature_ids[i] pairs with
  // logits_path[i]) — confirm against the prediction-path producer.
  repeated int32 feature_ids = 1;
  repeated float logits_path = 2;

  // TODO(crawles): return 2) Node IDs for ensemble prediction path 3) Leaf node
  // IDs.
}
172