/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

REGISTER_RESOURCE_HANDLE_OP(BoostedTreesEnsembleResource);

REGISTER_OP("IsBoostedTreesEnsembleInitialized")
    .Input("tree_ensemble_handle: resource")
    .Output("is_initialized: bool")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused_input;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
      c->set_output(0, c->Scalar());
      return Status::OK();
    });
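
// BoostedTreesCalculateBestGainsPerFeature takes one stats summary per
// feature, each of shape [max_splits, num_buckets, 2] (the middle, bucket
// dimension is left unknown by the shape function), together with the scalar
// l1, l2, tree_complexity and min_node_weight inputs, and returns per-feature
// split candidates: node ids, gains, thresholds, and left/right node
// contributions.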
REGISTER_OP("BoostedTreesCalculateBestGainsPerFeature")
    .Input("node_id_range: int32")
    .Input("stats_summary_list: num_features * float32")
    .Input("l1: float")
    .Input("l2: float")
    .Input("tree_complexity: float")
    .Input("min_node_weight: float")
    .Attr("max_splits: int >= 1")
    .Attr("num_features: int >= 1")  // not passed but populated automatically.
    .Output("node_ids_list: num_features * int32")
    .Output("gains_list: num_features * float32")
    .Output("thresholds_list: num_features * int32")
    .Output("left_node_contribs_list: num_features * float32")
    .Output("right_node_contribs_list: num_features * float32")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // Confirms the rank of the inputs and sets the shape of the outputs.
      int max_splits;
      int num_features;
      TF_RETURN_IF_ERROR(c->GetAttr("max_splits", &max_splits));
      TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
      shape_inference::ShapeHandle node_id_range_shape;
      shape_inference::ShapeHandle unused_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_id_range_shape));
      TF_RETURN_IF_ERROR(
          c->Merge(node_id_range_shape, c->MakeShape({2}), &unused_shape));
      // Checks that all stats summary entries are of the same shape.
      shape_inference::ShapeHandle summary_shape_base;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &summary_shape_base));
      TF_RETURN_IF_ERROR(c->Merge(summary_shape_base,
                                  c->MakeShape({max_splits, -1, 2}),
                                  &unused_shape));
      for (int i = 1; i < num_features; ++i) {
        shape_inference::ShapeHandle summary_shape;
        TF_RETURN_IF_ERROR(c->WithRank(c->input(1 + i), 3, &summary_shape));
        TF_RETURN_IF_ERROR(
            c->Merge(summary_shape_base, summary_shape, &unused_shape));
      }
      TF_RETURN_IF_ERROR(
          c->WithRank(c->input(num_features + 1), 0, &unused_shape));
      TF_RETURN_IF_ERROR(
          c->WithRank(c->input(num_features + 2), 0, &unused_shape));
      TF_RETURN_IF_ERROR(
          c->WithRank(c->input(num_features + 3), 0, &unused_shape));
      // Sets the output lists.
      std::vector<shape_inference::ShapeHandle> output_shapes_vec(
          num_features, c->MakeShape({-1}));
      TF_RETURN_IF_ERROR(c->set_output("node_ids_list", output_shapes_vec));
      TF_RETURN_IF_ERROR(c->set_output("gains_list", output_shapes_vec));
      TF_RETURN_IF_ERROR(c->set_output("thresholds_list", output_shapes_vec));
      std::vector<shape_inference::ShapeHandle> output_shapes_contribs(
          num_features, c->MakeShape({-1, 1}));
      TF_RETURN_IF_ERROR(
          c->set_output("left_node_contribs_list", output_shapes_contribs));
      TF_RETURN_IF_ERROR(
          c->set_output("right_node_contribs_list", output_shapes_contribs));
      return Status::OK();
    });

REGISTER_OP("BoostedTreesCreateEnsemble")
    .Input("tree_ensemble_handle: resource")
    .Input("stamp_token: int64")
    .Input("tree_ensemble_serialized: string")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused_input;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
      return Status::OK();
    });

REGISTER_OP("BoostedTreesDeserializeEnsemble")
    .Input("tree_ensemble_handle: resource")
    .Input("stamp_token: int64")
    .Input("tree_ensemble_serialized: string")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused_input;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
      return Status::OK();
    });

REGISTER_OP("BoostedTreesGetEnsembleStates")
    .Input("tree_ensemble_handle: resource")
    .Output("stamp_token: int64")
    .Output("num_trees: int32")
    .Output("num_finalized_trees: int32")
    .Output("num_attempted_layers: int32")
    .Output("last_layer_nodes_range: int32")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused_input;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
      c->set_output(0, c->Scalar());
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      c->set_output(3, c->Scalar());
      c->set_output(4, c->Vector(2));
      return Status::OK();
    });
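
// BoostedTreesMakeStatsSummary accumulates gradients and hessians per
// (feature, node, bucket). Its output shape is fixed below at
// [num_features, max_splits, num_buckets, 2]; the trailing dimension
// presumably holds the accumulated (gradient, hessian) pair.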
REGISTER_OP("BoostedTreesMakeStatsSummary")
    .Input("node_ids: int32")
    .Input("gradients: float")
    .Input("hessians: float")
    .Input("bucketized_features_list: num_features * int32")
    .Attr("max_splits: int >= 1")
    .Attr("num_buckets: int >= 1")
    .Attr("num_features: int >= 1")
    .Output("stats_summary: float")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // Sets the shape of the output as a Rank 4 Tensor.
      int max_splits;
      int num_buckets;
      int num_features;
      TF_RETURN_IF_ERROR(c->GetAttr("max_splits", &max_splits));
      TF_RETURN_IF_ERROR(c->GetAttr("num_buckets", &num_buckets));
      TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
      shape_inference::ShapeHandle node_ids_shape;
      shape_inference::ShapeHandle gradients_shape;
      shape_inference::ShapeHandle hessians_shape;
      shape_inference::ShapeHandle bucketized_feature_shape;
      shape_inference::ShapeHandle unused_shape;
      shape_inference::DimensionHandle unused_dim;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_ids_shape));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &gradients_shape));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &hessians_shape));
      TF_RETURN_IF_ERROR(c->Merge(c->Dim(node_ids_shape, 0),
                                  c->Dim(gradients_shape, 0), &unused_dim));
      TF_RETURN_IF_ERROR(
          c->Merge(gradients_shape, hessians_shape, &unused_shape));
      for (int f = 0; f < num_features; ++f) {
        TF_RETURN_IF_ERROR(
            c->WithRank(c->input(3 + f), 1, &bucketized_feature_shape));
        TF_RETURN_IF_ERROR(c->Merge(c->Dim(node_ids_shape, 0),
                                    c->Dim(bucketized_feature_shape, 0),
                                    &unused_dim));
      }
      c->set_output(0,
                    c->MakeShape({num_features, max_splits, num_buckets, 2}));
      return Status::OK();
    });

// TODO(nponomareva): when/if creating the new op for unbucketized data, rename
// bucketized_features to features.
REGISTER_OP("BoostedTreesPredict")
    .Input("tree_ensemble_handle: resource")
    .Input("bucketized_features: num_bucketized_features * int32")
    .Attr("num_bucketized_features: int >= 1")  // Inferred.
    .Attr("logits_dimension: int")
    .Output("logits: float")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle feature_shape;
      int num_bucketized_features;
      TF_RETURN_IF_ERROR(
          c->GetAttr("num_bucketized_features", &num_bucketized_features));
      shape_inference::ShapeHandle unused_input;
      for (int i = 0; i < num_bucketized_features; ++i) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 1), 1, &feature_shape));
        // Check that the shapes of all bucketized features are the same.
        TF_RETURN_IF_ERROR(c->Merge(c->input(1), feature_shape, &unused_input));
      }

      int logits_dimension;
      TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
      auto logits_shape =
          c->MakeShape({c->Dim(feature_shape, 0), logits_dimension});
      // Logits.
      c->set_output(0, logits_shape);
      return Status::OK();
    });
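
// BoostedTreesExampleDebugOutputs mirrors BoostedTreesPredict but returns one
// serialized debug output proto per example, so its output is a string vector
// of batch size rather than a [batch_size, logits_dimension] logits matrix.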
REGISTER_OP("BoostedTreesExampleDebugOutputs")
    .Input("tree_ensemble_handle: resource")
    .Input("bucketized_features: num_bucketized_features * int32")
    .Attr("num_bucketized_features: int >= 1")  // Inferred.
    .Attr("logits_dimension: int")
    .Output("examples_debug_outputs_serialized: string")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle feature_shape;
      int num_bucketized_features;
      TF_RETURN_IF_ERROR(
          c->GetAttr("num_bucketized_features", &num_bucketized_features));
      shape_inference::ShapeHandle unused_input;
      for (int i = 0; i < num_bucketized_features; ++i) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 1), 1, &feature_shape));
        // Check that the shapes of all bucketized features are the same.
        TF_RETURN_IF_ERROR(c->Merge(c->input(1), feature_shape, &unused_input));
      }

      // Multi-class will be supported by modifying the proto.
      auto batch_size = c->MakeShape({c->Dim(feature_shape, 0)});
      c->set_output(0, batch_size);
      return Status::OK();
    });

REGISTER_OP("BoostedTreesSerializeEnsemble")
    .Input("tree_ensemble_handle: resource")
    .Output("stamp_token: int64")
    .Output("tree_ensemble_serialized: string")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused_input;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
      c->set_output(0, c->Scalar());
      c->set_output(1, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("BoostedTreesTrainingPredict")
    .Input("tree_ensemble_handle: resource")
    .Input("cached_tree_ids: int32")
    .Input("cached_node_ids: int32")
    .Input("bucketized_features: num_bucketized_features * int32")
    .Attr("num_bucketized_features: int >= 1")
    .Attr("logits_dimension: int")
    .Output("partial_logits: float")
    .Output("tree_ids: int32")
    .Output("node_ids: int32")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle feature_shape;
      int num_bucketized_features;
      TF_RETURN_IF_ERROR(
          c->GetAttr("num_bucketized_features", &num_bucketized_features));

      shape_inference::ShapeHandle unused_input;
      for (int i = 0; i < num_bucketized_features; ++i) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 3), 1, &feature_shape));
        TF_RETURN_IF_ERROR(
            c->Merge(c->input(i + 3), feature_shape, &unused_input));
      }
      // All inputs/outputs except logits should have the same shape.
      TF_RETURN_IF_ERROR(c->Merge(c->input(1), feature_shape, &unused_input));
      TF_RETURN_IF_ERROR(c->Merge(c->input(2), feature_shape, &unused_input));

      int logits_dimension;
      TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
      auto logits_shape =
          c->MakeShape({c->Dim(feature_shape, 0), logits_dimension});
      // Partial logits.
      c->set_output(0, logits_shape);
      // Tree ids.
      c->set_output(1, c->MakeShape({c->Dim(feature_shape, 0)}));
      // Node ids.
      c->set_output(2, c->MakeShape({c->Dim(feature_shape, 0)}));
      return Status::OK();
    });
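
// BoostedTreesUpdateEnsemble updates the tree ensemble using precomputed
// per-feature split candidates. The variadic inputs are laid out in blocks of
// num_features: node_ids start at input index 2, gains at 2 + num_features,
// thresholds at 2 + 2 * num_features, and the left/right node contribution
// matrices at 2 + 3 * num_features and 2 + 4 * num_features, which is the
// index arithmetic the shape function below follows.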
REGISTER_OP("BoostedTreesUpdateEnsemble")
    .Input("tree_ensemble_handle: resource")
    .Input("feature_ids: int32")
    .Input("node_ids: num_features * int32")
    .Input("gains: num_features * float")
    .Input("thresholds: num_features * int32")
    .Input("left_node_contribs: num_features * float")
    .Input("right_node_contribs: num_features * float")
    .Input("max_depth: int32")
    .Input("learning_rate: float")
    .Attr("pruning_mode: int >=0")
    .Attr("num_features: int >= 0")  // Inferred.
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle shape_handle;
      int num_features;
      TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));

      // Feature ids; there should be one for each feature.
      shape_inference::ShapeHandle feature_ids_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &feature_ids_shape));
      TF_RETURN_IF_ERROR(
          c->Merge(c->input(1), c->Vector(num_features), &shape_handle));

      for (int i = 0; i < num_features; ++i) {
        // Node ids.
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 2), 1, &shape_handle));
        auto shape_rank_1 = c->MakeShape({c->Dim(shape_handle, 0)});
        auto shape_rank_2 = c->MakeShape({c->Dim(shape_handle, 0), 1});

        // Gains.
        TF_RETURN_IF_ERROR(
            c->WithRank(c->input(i + num_features + 2), 1, &shape_handle));
        // TODO(nponomareva): replace this with input("name", vector of shapes).
        TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features + 2),
                                    shape_rank_1, &shape_handle));
        // Thresholds.
        TF_RETURN_IF_ERROR(
            c->WithRank(c->input(i + num_features * 2 + 2), 1, &shape_handle));
        TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 2 + 2),
                                    shape_rank_1, &shape_handle));
        // Left and right node contribs.
        TF_RETURN_IF_ERROR(
            c->WithRank(c->input(i + num_features * 3 + 2), 2, &shape_handle));
        TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 3 + 2),
                                    shape_rank_2, &shape_handle));
        TF_RETURN_IF_ERROR(
            c->WithRank(c->input(i + num_features * 4 + 2), 2, &shape_handle));
        TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 4 + 2),
                                    shape_rank_2, &shape_handle));
      }
      return Status::OK();
    });
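
// BoostedTreesCenterBias takes mean gradients and mean hessians (both rank 2
// and required to agree in shape) plus scalar l1 and l2 regularizers, and
// returns a scalar bool indicating whether bias centering should continue.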
REGISTER_OP("BoostedTreesCenterBias")
    .Input("tree_ensemble_handle: resource")
    .Input("mean_gradients: float")
    .Input("mean_hessians: float")
    // Regularization-related.
    .Input("l1: float")
    .Input("l2: float")
    .Output("continue_centering: bool")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle gradients_shape;
      shape_inference::ShapeHandle hessians_shape;
      shape_inference::ShapeHandle unused_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &gradients_shape));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &hessians_shape));
      TF_RETURN_IF_ERROR(
          c->Merge(gradients_shape, hessians_shape, &unused_shape));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));

      c->set_output(0, c->Scalar());
      return Status::OK();
    });

REGISTER_RESOURCE_HANDLE_OP(BoostedTreesQuantileStreamResource);

REGISTER_OP("IsBoostedTreesQuantileStreamResourceInitialized")
    .Input("quantile_stream_resource_handle: resource")
    .Output("is_initialized: bool")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused_input;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
      c->set_output(0, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("BoostedTreesCreateQuantileStreamResource")
    .Attr("max_elements: int = 1099511627776")  // 1 << 40
    .Input("quantile_stream_resource_handle: resource")
    .Input("epsilon: float")
    .Input("num_streams: int64")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused_input;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
      return Status::OK();
    });

REGISTER_OP("BoostedTreesMakeQuantileSummaries")
    .Attr("num_features: int >= 0")
    .Input("float_values: num_features * float")
    .Input("example_weights: float")
    .Input("epsilon: float")
    .Output("summaries: num_features * float")
    .SetShapeFn([](InferenceContext* c) {
      int num_features;
      TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
      ShapeHandle example_weights_shape;
      TF_RETURN_IF_ERROR(
          c->WithRank(c->input(num_features), 1, &example_weights_shape));
      for (int i = 0; i < num_features; ++i) {
        ShapeHandle feature_shape;
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &feature_shape));
        // The columns are value, weight, min_rank, max_rank.
        c->set_output(i, c->MakeShape({c->UnknownDim(), 4}));
      }
      // epsilon must be a scalar.
      ShapeHandle unused_input;
      TF_RETURN_IF_ERROR(
          c->WithRank(c->input(num_features + 1), 0, &unused_input));
      return Status::OK();
    });

REGISTER_OP("BoostedTreesQuantileStreamResourceAddSummaries")
    .Attr("num_features: int >= 0")
    .Input("quantile_stream_resource_handle: resource")
    .Input("summaries: num_features * float")
    .SetShapeFn([](InferenceContext* c) {
      int num_features;
      TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
      // The resource handle must be a scalar.
      shape_inference::ShapeHandle unused_input;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
      // Each summary must be rank 2.
      for (int i = 1; i < num_features + 1; i++) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &unused_input));
      }
      return Status::OK();
    });
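
// BoostedTreesQuantileStreamResourceDeserialize restores previously computed
// bucket boundaries (presumably one float vector per stream) into the
// quantile stream resource.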
REGISTER_OP("BoostedTreesQuantileStreamResourceDeserialize")
    .Attr("num_streams: int")
    .Input("quantile_stream_resource_handle: resource")
    .Input("bucket_boundaries: num_streams * float")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused_input;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
      return Status::OK();
    });

REGISTER_OP("BoostedTreesQuantileStreamResourceFlush")
    .Attr("generate_quantiles: bool = False")
    .Input("quantile_stream_resource_handle: resource")
    .Input("num_buckets: int64")
    .SetShapeFn([](InferenceContext* c) {
      // All the inputs are scalars.
      shape_inference::ShapeHandle unused_input;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
      return Status::OK();
    });

REGISTER_OP("BoostedTreesQuantileStreamResourceGetBucketBoundaries")
    .Attr("num_features: int >= 0")
    .Input("quantile_stream_resource_handle: resource")
    .Output("bucket_boundaries: num_features * float")
    .SetShapeFn([](InferenceContext* c) {
      int num_features;
      TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
      shape_inference::ShapeHandle unused_input;
      // The resource handle must be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
      for (int i = 0; i < num_features; i++) {
        c->set_output(i, c->Vector(c->UnknownDim()));
      }
      return Status::OK();
    });

REGISTER_OP("BoostedTreesBucketize")
    .Attr("num_features: int >= 0")
    .Input("float_values: num_features * float")
    .Input("bucket_boundaries: num_features * float")
    .Output("buckets: num_features * int32")
    .SetShapeFn([](InferenceContext* c) {
      int num_features;
      TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
      ShapeHandle feature_shape;
      DimensionHandle unused_dim;
      for (int i = 0; i < num_features; i++) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &feature_shape));
        TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0),
                                    c->Dim(c->input(0), 0), &unused_dim));
      }
      // The bucketized result should have the same dimension as the input.
      for (int i = 0; i < num_features; i++) {
        c->set_output(i, c->MakeShape({c->Dim(c->input(i), 0)}));
      }
      return Status::OK();
    });

}  // namespace tensorflow