1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 #include "tensorflow/core/framework/common_shape_fns.h" 16 #include "tensorflow/core/framework/op.h" 17 #include "tensorflow/core/framework/shape_inference.h" 18 19 namespace tensorflow { 20 21 using shape_inference::DimensionHandle; 22 using shape_inference::InferenceContext; 23 using shape_inference::ShapeHandle; 24 25 REGISTER_OP("BuildDenseInequalitySplits") 26 .Input("num_minibatches: int64") 27 .Input("partition_ids: int32") 28 .Input("bucket_ids: int64") 29 .Input("gradients: float32") 30 .Input("hessians: float32") 31 .Input("bucket_boundaries: float32") 32 .Input("class_id: int32") 33 .Input("feature_column_group_id: int32") 34 .Input("l1_regularization: float") 35 .Input("l2_regularization: float") 36 .Input("tree_complexity_regularization: float") 37 .Input("min_node_weight: float") 38 .Input("multiclass_strategy: int32") 39 .Input("weak_learner_type: int32") 40 .Output("output_partition_ids: int32") 41 .Output("gains: float32") 42 .Output("split_infos: string") __anon295dc3ce0102(InferenceContext* c) 43 .SetShapeFn([](InferenceContext* c) { 44 DimensionHandle unused_dim; 45 ShapeHandle unused_shape; 46 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_shape)); 47 48 ShapeHandle partition_ids_shape; 49 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &partition_ids_shape)); 50 ShapeHandle bucket_ids_shape; 51 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &bucket_ids_shape)); 52 ShapeHandle gradients_shape; 53 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &gradients_shape)); 54 TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), 55 c->Dim(gradients_shape, 0), &unused_dim)); 56 ShapeHandle hessians_shape; 57 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(4), 1, &hessians_shape)); 58 TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), 59 c->Dim(hessians_shape, 0), &unused_dim)); 60 ShapeHandle bucket_boundaries_shape; 61 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &bucket_boundaries_shape)); 62 c->set_output(0, c->Vector(c->UnknownDim())); 63 c->set_output(1, c->Vector(c->UnknownDim())); 64 c->set_output(2, c->Vector(c->UnknownDim())); 65 return Status::OK(); 66 }) 67 .Doc(R"doc( 68 Find the split that has the best gain for the accumulated stats. 69 70 num_minibatches: A scalar, the number of times per example gradients & hessians 71 were accumulated. The stats are divided by this to get per example stats. 72 partition_ids: A rank 1 tensor of partition IDs. 73 bucket_ids: A rank 2 tensor of buckets IDs and dimensions. 74 gradients: A rank 1 tensor of gradients. 75 hessians: A rank 1 tensor of hessians. 76 bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization. 77 class_id: A scalar, the class id for which we're building the splits. 78 feature_column_group_id: A scalar, the index of the feature we are spiltting on. 79 l1_regularization: A scalar, which specifies the l1 regularization term. 80 l2_regularization: A scalar, which specifies the l2 regularization term. 81 tree_complexity_regularization: A scalar, which specifies the tree complexity 82 regularization term. 83 min_node_weight: A scalar, minimum sum of example hessian needed in a child. 84 If a split results in a leaf node with a smaller value, the split will not 85 be considered. 86 multiclass_strategy: A scalar, specifying the multiclass handling strategy. 87 See LearnerConfig.MultiClassStrategy for valid values. 88 weak_learner_type: A scalar, specifying the weak learner type to use. 89 See LearnerConfig.WeakLearnerType for valid values. 90 output_partition_ids: A rank 1 tensor, the partition IDs that we created splits 91 for. 92 gains: A rank 1 tensor, for the computed gain for the created splits. 93 split_infos: A rank 1 tensor of serialized protos which contains the 94 `SplitInfo`s. 95 )doc"); 96 97 REGISTER_OP("BuildSparseInequalitySplits") 98 .Input("num_minibatches: int64") 99 .Input("partition_ids: int32") 100 .Input("bucket_ids: int64") 101 .Input("gradients: float32") 102 .Input("hessians: float32") 103 .Input("bucket_boundaries: float32") 104 .Input("class_id: int32") 105 .Input("feature_column_group_id: int32") 106 .Input("bias_feature_id: int64") 107 .Input("l1_regularization: float") 108 .Input("l2_regularization: float") 109 .Input("tree_complexity_regularization: float") 110 .Input("min_node_weight: float") 111 .Input("multiclass_strategy: int32") 112 .Output("output_partition_ids: int32") 113 .Output("gains: float32") 114 .Output("split_infos: string") __anon295dc3ce0202(InferenceContext* c) 115 .SetShapeFn([](InferenceContext* c) { 116 DimensionHandle unused_dim; 117 ShapeHandle unused_shape; 118 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_shape)); 119 120 ShapeHandle partition_ids_shape; 121 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &partition_ids_shape)); 122 ShapeHandle bucket_ids_shape; 123 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &bucket_ids_shape)); 124 ShapeHandle gradients_shape; 125 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &gradients_shape)); 126 TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), 127 c->Dim(gradients_shape, 0), &unused_dim)); 128 ShapeHandle hessians_shape; 129 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(4), 1, &hessians_shape)); 130 TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), 131 c->Dim(hessians_shape, 0), &unused_dim)); 132 ShapeHandle bucket_boundaries_shape; 133 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &bucket_boundaries_shape)); 134 c->set_output(0, c->Vector(c->UnknownDim())); 135 c->set_output(1, c->Vector(c->UnknownDim())); 136 c->set_output(2, c->Vector(c->UnknownDim())); 137 return Status::OK(); 138 }) 139 .Doc(R"doc( 140 Find the split that has the best gain for the accumulated stats for a particular 141 feature column. 142 143 num_minibatches: A scalar, the number of times per example gradients & hessians 144 were accumulated. The stats are divided by this to get per example stats. 145 partition_ids: A rank 2 tensor of partition IDs for each dimension of feature column. 146 bucket_ids: A rank 2 tensor of buckets IDs and dimensions. 147 gradients: A rank 1 tensor of gradients. 148 hessians: A rank 1 tensor of hessians. 149 bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization. 150 class_id: A scalar, the class id for which we're building the splits. 151 feature_column_group_id: A scalar, the index of the feature we are spiltting on. 152 l1_regularization: A scalar, which specifies the l1 regularization term. 153 l2_regularization: A scalar, which specifies the l2 regularization term. 154 tree_complexity_regularization: A scalar, which specifies the tree complexity 155 regularization term. 156 min_node_weight: A scalar, minimum sum of example hessian needed in a child. 157 If a split results in a leaf node with a smaller value, the split will not 158 be considered. 159 multiclass_strategy: A scalar, specifying the multiclass handling strategy. 160 See LearnerConfig.MultiClassStrategy for valid values. 161 output_partition_ids: A rank 1 tensor, the partition IDs that we created splits 162 for. 163 gains: A rank 1 tensor, for the computed gain for the created splits. 164 split_infos: A rank 1 tensor of serialized protos which contains the 165 `SplitInfo`s. 166 )doc"); 167 168 REGISTER_OP("BuildCategoricalEqualitySplits") 169 .Input("num_minibatches: int64") 170 .Input("partition_ids: int32") 171 .Input("feature_ids: int64") 172 .Input("gradients: float32") 173 .Input("hessians: float32") 174 .Input("class_id: int32") 175 .Input("feature_column_group_id: int32") 176 .Input("bias_feature_id: int64") 177 .Input("l1_regularization: float") 178 .Input("l2_regularization: float") 179 .Input("tree_complexity_regularization: float") 180 .Input("min_node_weight: float") 181 .Input("multiclass_strategy: int32") 182 .Input("weak_learner_type: int32") 183 .Output("output_partition_ids: int32") 184 .Output("gains: float32") 185 .Output("split_infos: string") __anon295dc3ce0302(InferenceContext* c) 186 .SetShapeFn([](InferenceContext* c) { 187 DimensionHandle unused_dim; 188 ShapeHandle unused_shape; 189 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_shape)); 190 191 ShapeHandle partition_ids_shape; 192 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &partition_ids_shape)); 193 ShapeHandle bucket_ids_shape; 194 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &bucket_ids_shape)); 195 ShapeHandle gradients_shape; 196 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &gradients_shape)); 197 TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), 198 c->Dim(gradients_shape, 0), &unused_dim)); 199 ShapeHandle hessians_shape; 200 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(4), 1, &hessians_shape)); 201 TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), 202 c->Dim(hessians_shape, 0), &unused_dim)); 203 c->set_output(0, c->Vector(c->UnknownDim())); 204 c->set_output(1, c->Vector(c->UnknownDim())); 205 c->set_output(2, c->Vector(c->UnknownDim())); 206 return Status::OK(); 207 }) 208 .Doc(R"doc( 209 Find the split that has the best gain for the accumulated stats. 210 211 num_minibatches: A scalar, the number of times per example gradients & hessians 212 were accumulated. The stats are divided by this to get per example stats. 213 partition_ids: A rank 1 tensor of partition IDs. 214 feature_ids: A rank 2 tensor of feature IDs and dimensions. 215 gradients: A rank 1 tensor of gradients. 216 hessians: A rank 1 tensor of hessians. 217 class_id: A scalar, the class id for which we're building the splits. 218 feature_column_group_id: A scalar, the index of the feature we are spiltting on. 219 l1_regularization: A scalar, which specifies the l1 regularization term. 220 l2_regularization: A scalar, which specifies the l2 regularization term. 221 tree_complexity_regularization: A scalar, which specifies the tree complexity 222 regularization term. 223 min_node_weight: A scalar, minimum sum of example hessian needed in a child. 224 If a split results in a leaf node with a smaller value, the split will not 225 be considered. 226 multiclass_strategy: A scalar, specifying the multiclass handling strategy. 227 See LearnerConfig.MultiClassStrategy for valid values. 228 weak_learner_type: A scalar, specifying the weak learner type to use. 229 See LearnerConfig.WeakLearnerType for valid values. 230 output_partition_ids: A rank 1 tensor, the partition IDs that we created splits 231 for. 232 gains: A rank 1 tensor, for the computed gain for the created splits. 233 split_infos: A rank 1 tensor of serialized protos which contains the 234 `SplitInfo`s. 235 )doc"); 236 237 } // namespace tensorflow 238