• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <vector>
17 
18 #include "tensorflow/core/framework/common_shape_fns.h"
19 #include "tensorflow/core/framework/op.h"
20 #include "tensorflow/core/framework/resource_mgr.h"
21 #include "tensorflow/core/framework/shape_inference.h"
22 #include "tensorflow/core/framework/tensor_shape.h"
23 
24 namespace tensorflow {
25 
26 using shape_inference::DimensionHandle;
27 using shape_inference::InferenceContext;
28 using shape_inference::ShapeHandle;
29 
30 REGISTER_RESOURCE_HANDLE_OP(BoostedTreesEnsembleResource);
31 
32 REGISTER_OP("IsBoostedTreesEnsembleInitialized")
33     .Input("tree_ensemble_handle: resource")
34     .Output("is_initialized: bool")
__anone3ce78ae0102(shape_inference::InferenceContext* c) 35     .SetShapeFn([](shape_inference::InferenceContext* c) {
36       shape_inference::ShapeHandle unused_input;
37       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
38       c->set_output(0, c->Scalar());
39       return Status::OK();
40     });
41 
42 REGISTER_OP("BoostedTreesCalculateBestGainsPerFeature")
43     .Input("node_id_range: int32")
44     .Input("stats_summary_list: num_features * float32")
45     .Input("l1: float")
46     .Input("l2: float")
47     .Input("tree_complexity: float")
48     .Input("min_node_weight: float")
49     .Attr("max_splits: int >= 1")
50     .Attr("num_features: int >= 1")  // not passed but populated automatically.
51     .Output("node_ids_list: num_features * int32")
52     .Output("gains_list: num_features * float32")
53     .Output("thresholds_list: num_features * int32")
54     .Output("left_node_contribs_list: num_features * float32")
55     .Output("right_node_contribs_list: num_features * float32")
__anone3ce78ae0202(shape_inference::InferenceContext* c) 56     .SetShapeFn([](shape_inference::InferenceContext* c) {
57       // Confirms the rank of the inputs and sets the shape of the outputs.
58       int max_splits;
59       int num_features;
60       TF_RETURN_IF_ERROR(c->GetAttr("max_splits", &max_splits));
61       TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
62       shape_inference::ShapeHandle node_id_range_shape;
63       shape_inference::ShapeHandle unused_shape;
64       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_id_range_shape));
65       TF_RETURN_IF_ERROR(
66           c->Merge(node_id_range_shape, c->MakeShape({2}), &unused_shape));
67       // Checks that all stats summary entries are of the same shape.
68       shape_inference::ShapeHandle summary_shape_base;
69       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &summary_shape_base));
70       TF_RETURN_IF_ERROR(c->Merge(summary_shape_base,
71                                   c->MakeShape({max_splits, -1, 2}),
72                                   &unused_shape));
73       for (int i = 1; i < num_features; ++i) {
74         shape_inference::ShapeHandle summary_shape;
75         TF_RETURN_IF_ERROR(c->WithRank(c->input(1 + i), 3, &summary_shape));
76         TF_RETURN_IF_ERROR(
77             c->Merge(summary_shape_base, summary_shape, &unused_shape));
78       }
79       TF_RETURN_IF_ERROR(
80           c->WithRank(c->input(num_features + 1), 0, &unused_shape));
81       TF_RETURN_IF_ERROR(
82           c->WithRank(c->input(num_features + 2), 0, &unused_shape));
83       TF_RETURN_IF_ERROR(
84           c->WithRank(c->input(num_features + 3), 0, &unused_shape));
85       // Sets the output lists.
86       std::vector<shape_inference::ShapeHandle> output_shapes_vec(
87           num_features, c->MakeShape({-1}));
88       TF_RETURN_IF_ERROR(c->set_output("node_ids_list", output_shapes_vec));
89       TF_RETURN_IF_ERROR(c->set_output("gains_list", output_shapes_vec));
90       TF_RETURN_IF_ERROR(c->set_output("thresholds_list", output_shapes_vec));
91       std::vector<shape_inference::ShapeHandle> output_shapes_contribs(
92           num_features, c->MakeShape({-1, 1}));
93       TF_RETURN_IF_ERROR(
94           c->set_output("left_node_contribs_list", output_shapes_contribs));
95       TF_RETURN_IF_ERROR(
96           c->set_output("right_node_contribs_list", output_shapes_contribs));
97       return Status::OK();
98     });
99 
100 REGISTER_OP("BoostedTreesCreateEnsemble")
101     .Input("tree_ensemble_handle: resource")
102     .Input("stamp_token: int64")
103     .Input("tree_ensemble_serialized: string")
__anone3ce78ae0302(shape_inference::InferenceContext* c) 104     .SetShapeFn([](shape_inference::InferenceContext* c) {
105       shape_inference::ShapeHandle unused_input;
106       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
107       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
108       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
109       return Status::OK();
110     });
111 
112 REGISTER_OP("BoostedTreesDeserializeEnsemble")
113     .Input("tree_ensemble_handle: resource")
114     .Input("stamp_token: int64")
115     .Input("tree_ensemble_serialized: string")
__anone3ce78ae0402(shape_inference::InferenceContext* c) 116     .SetShapeFn([](shape_inference::InferenceContext* c) {
117       shape_inference::ShapeHandle unused_input;
118       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
119       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
120       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
121       return Status::OK();
122     });
123 
124 REGISTER_OP("BoostedTreesGetEnsembleStates")
125     .Input("tree_ensemble_handle: resource")
126     .Output("stamp_token: int64")
127     .Output("num_trees: int32")
128     .Output("num_finalized_trees: int32")
129     .Output("num_attempted_layers: int32")
130     .Output("last_layer_nodes_range: int32")
__anone3ce78ae0502(shape_inference::InferenceContext* c) 131     .SetShapeFn([](shape_inference::InferenceContext* c) {
132       shape_inference::ShapeHandle unused_input;
133       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
134       c->set_output(0, c->Scalar());
135       c->set_output(1, c->Scalar());
136       c->set_output(2, c->Scalar());
137       c->set_output(3, c->Scalar());
138       c->set_output(4, c->Vector(2));
139       return Status::OK();
140     });
141 
142 REGISTER_OP("BoostedTreesMakeStatsSummary")
143     .Input("node_ids: int32")
144     .Input("gradients: float")
145     .Input("hessians: float")
146     .Input("bucketized_features_list: num_features * int32")
147     .Attr("max_splits: int >= 1")
148     .Attr("num_buckets: int >= 1")
149     .Attr("num_features: int >= 1")
150     .Output("stats_summary: float")
__anone3ce78ae0602(shape_inference::InferenceContext* c) 151     .SetShapeFn([](shape_inference::InferenceContext* c) {
152       // Sets the shape of the output as a Rank 4 Tensor.
153       int max_splits;
154       int num_buckets;
155       int num_features;
156       TF_RETURN_IF_ERROR(c->GetAttr("max_splits", &max_splits));
157       TF_RETURN_IF_ERROR(c->GetAttr("num_buckets", &num_buckets));
158       TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
159       shape_inference::ShapeHandle node_ids_shape;
160       shape_inference::ShapeHandle gradients_shape;
161       shape_inference::ShapeHandle hessians_shape;
162       shape_inference::ShapeHandle bucketized_feature_shape;
163       shape_inference::ShapeHandle unused_shape;
164       shape_inference::DimensionHandle unused_dim;
165       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_ids_shape));
166       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &gradients_shape));
167       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &hessians_shape));
168       TF_RETURN_IF_ERROR(c->Merge(c->Dim(node_ids_shape, 0),
169                                   c->Dim(gradients_shape, 0), &unused_dim));
170       TF_RETURN_IF_ERROR(
171           c->Merge(gradients_shape, hessians_shape, &unused_shape));
172       for (int f = 0; f < num_features; ++f) {
173         TF_RETURN_IF_ERROR(
174             c->WithRank(c->input(3 + f), 1, &bucketized_feature_shape));
175         TF_RETURN_IF_ERROR(c->Merge(c->Dim(node_ids_shape, 0),
176                                     c->Dim(bucketized_feature_shape, 0),
177                                     &unused_dim));
178       }
179       c->set_output(0,
180                     c->MakeShape({num_features, max_splits, num_buckets, 2}));
181       return Status::OK();
182     });
183 
184 // TODO(nponomareva): when/if creating the new op for unbucketized data, rename
185 // bucketized_features to features.
186 REGISTER_OP("BoostedTreesPredict")
187     .Input("tree_ensemble_handle: resource")
188     .Input("bucketized_features: num_bucketized_features * int32")
189     .Attr("num_bucketized_features: int >= 1")  // Inferred.
190     .Attr("logits_dimension: int")
191     .Output("logits: float")
__anone3ce78ae0702(shape_inference::InferenceContext* c) 192     .SetShapeFn([](shape_inference::InferenceContext* c) {
193       shape_inference::ShapeHandle feature_shape;
194       int num_bucketized_features;
195       TF_RETURN_IF_ERROR(
196           c->GetAttr("num_bucketized_features", &num_bucketized_features));
197       shape_inference::ShapeHandle unused_input;
198       for (int i = 0; i < num_bucketized_features; ++i) {
199         TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 1), 1, &feature_shape));
200         // Check that the shapes of all bucketized features are the same.
201         TF_RETURN_IF_ERROR(c->Merge(c->input(1), feature_shape, &unused_input));
202       }
203 
204       int logits_dimension;
205       TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
206       auto logits_shape =
207           c->MakeShape({c->Dim(feature_shape, 0), logits_dimension});
208       // Logits.
209       c->set_output(0, logits_shape);
210       return Status::OK();
211     });
212 
213 REGISTER_OP("BoostedTreesExampleDebugOutputs")
214     .Input("tree_ensemble_handle: resource")
215     .Input("bucketized_features: num_bucketized_features * int32")
216     .Attr("num_bucketized_features: int >= 1")  // Inferred.
217     .Attr("logits_dimension: int")
218     .Output("examples_debug_outputs_serialized: string")
__anone3ce78ae0802(shape_inference::InferenceContext* c) 219     .SetShapeFn([](shape_inference::InferenceContext* c) {
220       shape_inference::ShapeHandle feature_shape;
221       int num_bucketized_features;
222       TF_RETURN_IF_ERROR(
223           c->GetAttr("num_bucketized_features", &num_bucketized_features));
224       shape_inference::ShapeHandle unused_input;
225       for (int i = 0; i < num_bucketized_features; ++i) {
226         TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 1), 1, &feature_shape));
227         // Check that the shapes of all bucketized features are the same.
228         TF_RETURN_IF_ERROR(c->Merge(c->input(1), feature_shape, &unused_input));
229       }
230 
231       // Multi-class will be supported by modifying the proto.
232       auto batch_size = c->MakeShape({c->Dim(feature_shape, 0)});
233       c->set_output(0, batch_size);
234       return Status::OK();
235     });
236 
237 REGISTER_OP("BoostedTreesSerializeEnsemble")
238     .Input("tree_ensemble_handle: resource")
239     .Output("stamp_token: int64")
240     .Output("tree_ensemble_serialized: string")
__anone3ce78ae0902(shape_inference::InferenceContext* c) 241     .SetShapeFn([](shape_inference::InferenceContext* c) {
242       shape_inference::ShapeHandle unused_input;
243       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
244       c->set_output(0, c->Scalar());
245       c->set_output(1, c->Scalar());
246       return Status::OK();
247     });
248 
249 REGISTER_OP("BoostedTreesTrainingPredict")
250     .Input("tree_ensemble_handle: resource")
251     .Input("cached_tree_ids: int32")
252     .Input("cached_node_ids: int32")
253     .Input("bucketized_features: num_bucketized_features * int32")
254     .Attr("num_bucketized_features: int >= 1")
255     .Attr("logits_dimension: int")
256     .Output("partial_logits: float")
257     .Output("tree_ids: int32")
258     .Output("node_ids: int32")
__anone3ce78ae0a02(shape_inference::InferenceContext* c) 259     .SetShapeFn([](shape_inference::InferenceContext* c) {
260       shape_inference::ShapeHandle feature_shape;
261       int num_bucketized_features;
262       TF_RETURN_IF_ERROR(
263           c->GetAttr("num_bucketized_features", &num_bucketized_features));
264 
265       shape_inference::ShapeHandle unused_input;
266       for (int i = 0; i < num_bucketized_features; ++i) {
267         TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 3), 1, &feature_shape));
268         TF_RETURN_IF_ERROR(
269             c->Merge(c->input(i + 3), feature_shape, &unused_input));
270       }
271       // all inputs/outputs except logits should have same shape.
272       TF_RETURN_IF_ERROR(c->Merge(c->input(1), feature_shape, &unused_input));
273       TF_RETURN_IF_ERROR(c->Merge(c->input(2), feature_shape, &unused_input));
274 
275       int logits_dimension;
276       TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
277       auto logits_shape =
278           c->MakeShape({c->Dim(feature_shape, 0), logits_dimension});
279       // Partial logits.
280       c->set_output(0, logits_shape);
281       // Tree ids.
282       c->set_output(1, c->MakeShape({c->Dim(feature_shape, 0)}));
283       // Node ids.
284       c->set_output(2, c->MakeShape({c->Dim(feature_shape, 0)}));
285       return Status::OK();
286     });
287 
288 REGISTER_OP("BoostedTreesUpdateEnsemble")
289     .Input("tree_ensemble_handle: resource")
290     .Input("feature_ids: int32")
291     .Input("node_ids: num_features * int32")
292     .Input("gains: num_features * float")
293     .Input("thresholds: num_features * int32")
294     .Input("left_node_contribs: num_features * float")
295     .Input("right_node_contribs: num_features * float")
296     .Input("max_depth: int32")
297     .Input("learning_rate: float")
298     .Attr("pruning_mode: int >=0")
299     .Attr("num_features: int >= 0")  // Inferred.
__anone3ce78ae0b02(shape_inference::InferenceContext* c) 300     .SetShapeFn([](shape_inference::InferenceContext* c) {
301       shape_inference::ShapeHandle shape_handle;
302       int num_features;
303       TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
304 
305       // Feature_ids, should be one for each feature.
306       shape_inference::ShapeHandle feature_ids_shape;
307       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &feature_ids_shape));
308       TF_RETURN_IF_ERROR(
309           c->Merge(c->input(1), c->Vector(num_features), &shape_handle));
310 
311       for (int i = 0; i < num_features; ++i) {
312         // Node ids.
313         TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 2), 1, &shape_handle));
314         auto shape_rank_1 = c->MakeShape({c->Dim(shape_handle, 0)});
315         auto shape_rank_2 = c->MakeShape({c->Dim(shape_handle, 0), 1});
316 
317         // Gains.
318         TF_RETURN_IF_ERROR(
319             c->WithRank(c->input(i + num_features + 2), 1, &shape_handle));
320         // TODO(nponomareva): replace this with input("name",vector of shapes).
321         TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features + 2),
322                                     shape_rank_1, &shape_handle));
323         // Thresholds.
324         TF_RETURN_IF_ERROR(
325             c->WithRank(c->input(i + num_features * 2 + 2), 1, &shape_handle));
326         TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 2 + 2),
327                                     shape_rank_1, &shape_handle));
328         // Left and right node contribs.
329         TF_RETURN_IF_ERROR(
330             c->WithRank(c->input(i + num_features * 3 + 2), 2, &shape_handle));
331         TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 3 + 2),
332                                     shape_rank_2, &shape_handle));
333         TF_RETURN_IF_ERROR(
334             c->WithRank(c->input(i + num_features * 4 + 2), 2, &shape_handle));
335         TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 4 + 2),
336                                     shape_rank_2, &shape_handle));
337       }
338       return Status::OK();
339     });
340 
341 REGISTER_OP("BoostedTreesCenterBias")
342     .Input("tree_ensemble_handle: resource")
343     .Input("mean_gradients: float")
344     .Input("mean_hessians: float")
345     // Regularization-related.
346     .Input("l1: float")
347     .Input("l2: float")
348     .Output("continue_centering: bool")
__anone3ce78ae0c02(shape_inference::InferenceContext* c) 349     .SetShapeFn([](shape_inference::InferenceContext* c) {
350       shape_inference::ShapeHandle gradients_shape;
351       shape_inference::ShapeHandle hessians_shape;
352       shape_inference::ShapeHandle unused_shape;
353       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &gradients_shape));
354       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &hessians_shape));
355       TF_RETURN_IF_ERROR(
356           c->Merge(gradients_shape, hessians_shape, &unused_shape));
357       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
358       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
359 
360       c->set_output(0, c->Scalar());
361       return Status::OK();
362     });
363 
364 REGISTER_RESOURCE_HANDLE_OP(BoostedTreesQuantileStreamResource);
365 
366 REGISTER_OP("IsBoostedTreesQuantileStreamResourceInitialized")
367     .Input("quantile_stream_resource_handle: resource")
368     .Output("is_initialized: bool")
__anone3ce78ae0d02(shape_inference::InferenceContext* c) 369     .SetShapeFn([](shape_inference::InferenceContext* c) {
370       shape_inference::ShapeHandle unused_input;
371       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
372       c->set_output(0, c->Scalar());
373       return Status::OK();
374     });
375 
376 REGISTER_OP("BoostedTreesCreateQuantileStreamResource")
377     .Attr("max_elements: int = 1099511627776")  // 1 << 40
378     .Input("quantile_stream_resource_handle: resource")
379     .Input("epsilon: float")
380     .Input("num_streams: int64")
__anone3ce78ae0e02(shape_inference::InferenceContext* c) 381     .SetShapeFn([](shape_inference::InferenceContext* c) {
382       shape_inference::ShapeHandle unused_input;
383       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
384       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
385       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
386       return Status::OK();
387     });
388 
389 REGISTER_OP("BoostedTreesMakeQuantileSummaries")
390     .Attr("num_features: int >= 0")
391     .Input("float_values: num_features * float")
392     .Input("example_weights: float")
393     .Input("epsilon: float")
394     .Output("summaries: num_features * float")
__anone3ce78ae0f02(InferenceContext* c) 395     .SetShapeFn([](InferenceContext* c) {
396       int num_features;
397       TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
398       ShapeHandle example_weights_shape;
399       TF_RETURN_IF_ERROR(
400           c->WithRank(c->input(num_features), 1, &example_weights_shape));
401       for (int i = 0; i < num_features; ++i) {
402         ShapeHandle feature_shape;
403         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &feature_shape));
404         // the columns are value, weight, min_rank, max_rank.
405         c->set_output(i, c->MakeShape({c->UnknownDim(), 4}));
406       }
407       // epsilon must be a scalar.
408       ShapeHandle unused_input;
409       TF_RETURN_IF_ERROR(
410           c->WithRank(c->input(num_features + 1), 0, &unused_input));
411       return Status::OK();
412     });
413 
414 REGISTER_OP("BoostedTreesQuantileStreamResourceAddSummaries")
415     .Attr("num_features: int >= 0")
416     .Input("quantile_stream_resource_handle: resource")
417     .Input("summaries: num_features * float")
__anone3ce78ae1002(InferenceContext* c) 418     .SetShapeFn([](InferenceContext* c) {
419       int num_features;
420       TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
421       // resource handle must be a scalar.
422       shape_inference::ShapeHandle unused_input;
423       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
424       // each summary must be rank 2.
425       for (int i = 1; i < num_features + 1; i++) {
426         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &unused_input));
427       }
428       return Status::OK();
429     });
430 
431 REGISTER_OP("BoostedTreesQuantileStreamResourceDeserialize")
432     .Attr("num_streams: int")
433     .Input("quantile_stream_resource_handle: resource")
434     .Input("bucket_boundaries: num_streams * float")
__anone3ce78ae1102(shape_inference::InferenceContext* c) 435     .SetShapeFn([](shape_inference::InferenceContext* c) {
436       shape_inference::ShapeHandle unused_input;
437       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
438       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
439       return Status::OK();
440     });
441 
442 REGISTER_OP("BoostedTreesQuantileStreamResourceFlush")
443     .Attr("generate_quantiles: bool = False")
444     .Input("quantile_stream_resource_handle: resource")
445     .Input("num_buckets: int64")
__anone3ce78ae1202(InferenceContext* c) 446     .SetShapeFn([](InferenceContext* c) {
447       // All the inputs are scalars.
448       shape_inference::ShapeHandle unused_input;
449       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
450       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
451       return Status::OK();
452     });
453 
454 REGISTER_OP("BoostedTreesQuantileStreamResourceGetBucketBoundaries")
455     .Attr("num_features: int >= 0")
456     .Input("quantile_stream_resource_handle: resource")
457     .Output("bucket_boundaries: num_features * float")
__anone3ce78ae1302(InferenceContext* c) 458     .SetShapeFn([](InferenceContext* c) {
459       int num_features;
460       TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
461       shape_inference::ShapeHandle unused_input;
462       // resource handle must be a scalar.
463       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
464       for (int i = 0; i < num_features; i++) {
465         c->set_output(i, c->Vector(c->UnknownDim()));
466       }
467       return Status::OK();
468     });
469 
470 REGISTER_OP("BoostedTreesBucketize")
471     .Attr("num_features: int >= 0")
472     .Input("float_values: num_features * float")
473     .Input("bucket_boundaries: num_features * float")
474     .Output("buckets: num_features * int32")
__anone3ce78ae1402(InferenceContext* c) 475     .SetShapeFn([](InferenceContext* c) {
476       int num_features;
477       TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
478       ShapeHandle feature_shape;
479       DimensionHandle unused_dim;
480       for (int i = 0; i < num_features; i++) {
481         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &feature_shape));
482         TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0),
483                                     c->Dim(c->input(0), 0), &unused_dim));
484       }
485       // Bucketized result should have same dimension as input.
486       for (int i = 0; i < num_features; i++) {
487         c->set_output(i, c->MakeShape({c->Dim(c->input(i), 0)}));
488       }
489       return Status::OK();
490     });
491 
492 }  // namespace tensorflow
493