• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include "tensorflow/core/framework/common_shape_fns.h"
16 #include "tensorflow/core/framework/op.h"
17 #include "tensorflow/core/framework/resource_mgr.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 
20 namespace tensorflow {
21 namespace boosted_trees {
22 using shape_inference::DimensionHandle;
23 using shape_inference::InferenceContext;
24 using shape_inference::ShapeHandle;
25 
26 REGISTER_RESOURCE_HANDLE_OP(StatsAccumulatorScalarResource);
27 
28 REGISTER_OP("StatsAccumulatorScalarIsInitialized")
29     .Input("stats_accumulator_handle: resource")
30     .Output("is_initialized: bool")
31     .SetShapeFn(tensorflow::shape_inference::ScalarShape)
32     .Doc(R"doc(
33 Checks whether a stats accumulator has been initialized.
34 )doc");
35 
36 REGISTER_OP("CreateStatsAccumulatorScalar")
37     .Input("stats_accumulator_handle: resource")
38     .Input("stamp_token: int64")
__anone13797730102(InferenceContext* c) 39     .SetShapeFn([](InferenceContext* c) {
40       ShapeHandle unused_input;
41       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
42       // stamp_token is a scalar.
43       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
44       return Status::OK();
45     })
46     .Doc(R"doc(
47 Creates a scalar stats accumulator.
48 
49 stats_accumulator_handle: handle to the stats accumulator.
50 stamp_token: Token to use as the initial value of the resource stamp.
51 )doc");
52 
53 REGISTER_OP("StatsAccumulatorScalarAdd")
54     .Attr("num_resource_handles: int >= 1")
55     .Input("stats_accumulator_handles: num_resource_handles * resource")
56     .Input("stamp_token: int64")
57     .Input("partition_ids: num_resource_handles * int32")
58     .Input("feature_ids: num_resource_handles * int64")
59     .Input("gradients: num_resource_handles * float")
60     .Input("hessians: num_resource_handles * float")
__anone13797730202(InferenceContext* c) 61     .SetShapeFn([](InferenceContext* c) {
62       int num_resource_handles;
63       TF_RETURN_IF_ERROR(
64           c->GetAttr("num_resource_handles", &num_resource_handles));
65       for (int i = 0; i < num_resource_handles; ++i) {
66         ShapeHandle unused_input;
67         DimensionHandle unused_dim;
68         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused_input));
69         TF_RETURN_IF_ERROR(
70             c->WithRank(c->input(num_resource_handles), 0, &unused_input));
71         ShapeHandle partition_ids_shape;
72         TF_RETURN_IF_ERROR(c->WithRank(c->input(num_resource_handles + i + 1),
73                                        1, &partition_ids_shape));
74         ShapeHandle feature_ids_shape;
75         TF_RETURN_IF_ERROR(c->WithRank(
76             c->input(num_resource_handles * 2 + i + 1), 2, &feature_ids_shape));
77         ShapeHandle gradients_shape;
78         TF_RETURN_IF_ERROR(c->WithRank(
79             c->input(num_resource_handles * 3 + i + 1), 1, &gradients_shape));
80         TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
81                                     c->Dim(gradients_shape, 0), &unused_dim));
82         ShapeHandle hessians_shape;
83         TF_RETURN_IF_ERROR(c->WithRank(
84             c->input(num_resource_handles * 4 + i + 1), 1, &hessians_shape));
85         TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
86                                     c->Dim(hessians_shape, 0), &unused_dim));
87       }
88       return Status::OK();
89     })
90     .Doc(R"doc(
91 Updates the scalar stats accumulator.
92 
93 stamp_token: Stamp token for Read/Write operations.
94              Any operation with a mismatching token will be dropped.
95 stats_accumulator_handles: A list of handles to the stats accumulator.
96 partition_ids: A list of vectors of partition_ids.
97 feature_ids: Rank 2 tensor of feature id and feature dimension ids.
98 gradients: A list of vectors of gradients for each slot in
99     <partition_id, feature_id, feature_dimension_id>.
100 hessians: A list of vectors of hessians for each slot in
101     <partition_id, feature_id, feature_dimension_id>.
102 )doc");
103 
104 REGISTER_OP("StatsAccumulatorScalarFlush")
105     .Input("stats_accumulator_handle: resource")
106     .Input("stamp_token: int64")
107     .Input("next_stamp_token: int64")
108     .Output("num_updates: int64")
109     .Output("output_partition_ids: int32")
110     .Output("output_feature_ids: int64")
111     .Output("output_gradients: float")
112     .Output("output_hessians: float")
__anone13797730302(InferenceContext* c) 113     .SetShapeFn([](InferenceContext* c) {
114       ShapeHandle unused_input;
115       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
116       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
117       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
118       c->set_output(0, c->Scalar());
119       c->set_output(1, c->Vector(c->UnknownDim()));
120       c->set_output(2, c->UnknownShape());
121       c->set_output(3, c->Vector(c->UnknownDim()));
122       c->set_output(4, c->Vector(c->UnknownDim()));
123       return Status::OK();
124     })
125     .Doc(R"doc(
126 Flushes the scalar stats accumulator to output and resets the internal state.
127 
128 stats_accumulator_handle: handle to the stats accumulator.
129 stamp_token: Stamp token for Read/Write operations.
130              Any operation with a mismatching token will be dropped.
131 next_stamp_token: Stamp token for the next iteration.
132 num_updates: Number of times stats were added to this accumulator since last
133     flush.
134 output_partition_ids A vector of partition_ids for the slots.
135 output_feature_ids: Rank 2 tensor of feature id and feature dimension ids.
136 output_gradients: A vector of gradients, with a value for each slot
137                   in <output_partition_id, output_feature_id>.
138 output_hessians: A vector of hessians, with a value for each slot
139                  in <output_partition_id, output_feature_id>.
140 )doc");
141 
142 REGISTER_OP("StatsAccumulatorScalarDeserialize")
143     .Input("stats_accumulator_handle: resource")
144     .Input("stamp_token: int64")
145     .Input("num_updates: int64")
146     .Input("partition_ids: int32")
147     .Input("feature_ids: int64")
148     .Input("gradients: float")
149     .Input("hessians: float")
__anone13797730402(InferenceContext* c) 150     .SetShapeFn([](InferenceContext* c) {
151       ShapeHandle unused_input;
152       DimensionHandle unused_dim;
153       // stats_accumulator_handle
154       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
155       // stamp_token
156       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
157       // num_updates
158       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
159       ShapeHandle partition_ids_shape;
160       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &partition_ids_shape));
161       ShapeHandle feature_ids_shape;
162       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 2, &feature_ids_shape));
163       ShapeHandle gradients_shape;
164       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &gradients_shape));
165       TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
166                                   c->Dim(gradients_shape, 0), &unused_dim));
167       ShapeHandle hessians_shape;
168       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 1, &hessians_shape));
169       TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
170                                   c->Dim(hessians_shape, 0), &unused_dim));
171       return Status::OK();
172     })
173     .Doc(R"doc(
174 Resets the scalar stats accumulator with the serialized state.
175 
176 stats_accumulator_handle: handle to the stats accumulator.
177 stamp_token: Stamp token for Read/Write operations.
178              Any operation with a mismatching token will be dropped.
179 num_updates: Number of times stats were added to this accumulator since last
180     flush.
181 partition_ids: A vector of partition_ids.
182 feature_ids: Rank 2 tensor of feature id and feature dimension ids.
183 gradients: A vector of gradients for each slot in <partition_id, feature_id,
184 feature_dimension_id>.
185 hessians: A vector of hessians for each slot in <partition_id, feature_id,
186 feature_dimension_id>
187 )doc");
188 
189 REGISTER_OP("StatsAccumulatorScalarSerialize")
190     .Input("stats_accumulator_handle: resource")
191     .Output("stamp_token: int64")
192     .Output("num_updates: int64")
193     .Output("output_partition_ids: int32")
194     .Output("output_feature_ids: int64")
195     .Output("output_gradients: float")
196     .Output("output_hessians: float")
__anone13797730502(InferenceContext* c) 197     .SetShapeFn([](InferenceContext* c) {
198       ShapeHandle unused_input;
199       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
200       // stamp_token
201       c->set_output(0, c->Scalar());
202       // num_updates
203       c->set_output(1, c->Scalar());
204       c->set_output(2, c->Vector(c->UnknownDim()));
205       c->set_output(3, c->UnknownShape());
206       c->set_output(4, c->Vector(c->UnknownDim()));
207       c->set_output(5, c->Vector(c->UnknownDim()));
208       return Status::OK();
209     })
210     .Doc(R"doc(
211 Serializes the scalar stats accumulator state.
212 
213 stats_accumulator_handle: handle to the stats accumulator.
214 stamp_token: The current stamp token for the resource.
215 num_updates: Number of times stats were added to this accumulator since last
216     flush.
217 output_partition_ids A vector of partition_ids for the slots.
218 output_feature_ids: Rank 2 tensor of feature id and feature dimension ids.
219 output_gradients: A vector of gradients, with a value for each slot
220                   in <output_partition_id, output_feature_id>.
221 output_hessians: A vector of hessians, with a value for each slot
222                  in <output_partition_id, output_feature_id>.
223 )doc");
224 
225 REGISTER_OP("StatsAccumulatorScalarMakeSummary")
226     .Input("partition_ids: int32")
227     .Input("feature_ids: int64")
228     .Input("gradients: float")
229     .Input("hessians: float")
230     .Output("output_partition_ids: int32")
231     .Output("output_feature_ids: int64")
232     .Output("output_gradients: float")
233     .Output("output_hessians: float")
234     .Doc(R"doc(
235 )doc");
236 
237 // Tensor version of the stats accumulator ops.
238 REGISTER_RESOURCE_HANDLE_OP(StatsAccumulatorTensorResource);
239 
240 REGISTER_OP("StatsAccumulatorTensorIsInitialized")
241     .Input("stats_accumulator_handle: resource")
242     .Output("is_initialized: bool")
243     .SetShapeFn(tensorflow::shape_inference::ScalarShape)
244     .Doc(R"doc(
245 Checks whether a tensor stats accumulator has been initialized.
246 )doc");
247 
248 REGISTER_OP("CreateStatsAccumulatorTensor")
249     .Input("stats_accumulator_handle: resource")
250     .Input("stamp_token: int64")
251     .Input("per_slot_gradient_shape: int64")
252     .Input("per_slot_hessian_shape: int64")
__anone13797730602(InferenceContext* c) 253     .SetShapeFn([](InferenceContext* c) {
254       ShapeHandle unused_input;
255       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
256       // stamp_token is a scalar.
257       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
258       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused_input));
259       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused_input));
260       return Status::OK();
261     })
262     .Doc(R"doc(
263 Creates a tensor stats accumulator.
264 
265 stats_accumulator_handle: handle to the tree ensemble resource to be created.
266 stamp_token: Token to use as the initial value of the resource stamp.
267 per_slot_gradient_shape: a vector that defines the shape of gradients.
268 per_slot_hessian_shape:  a vector that defines the shape of hessians.
269 )doc");
270 
271 REGISTER_OP("StatsAccumulatorTensorAdd")
272     .Attr("num_resource_handles: int >= 1")
273     .Input("stats_accumulator_handles: num_resource_handles * resource")
274     .Input("stamp_token: int64")
275     .Input("partition_ids: num_resource_handles * int32")
276     .Input("feature_ids: num_resource_handles * int64")
277     .Input("gradients: num_resource_handles * float")
278     .Input("hessians: num_resource_handles * float")
__anone13797730702(InferenceContext* c) 279     .SetShapeFn([](InferenceContext* c) {
280       int num_resource_handles;
281       TF_RETURN_IF_ERROR(
282           c->GetAttr("num_resource_handles", &num_resource_handles));
283       for (int i = 0; i < num_resource_handles; ++i) {
284         ShapeHandle unused_input;
285         DimensionHandle unused_dim;
286         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused_input));
287         TF_RETURN_IF_ERROR(
288             c->WithRank(c->input(num_resource_handles), 0, &unused_input));
289         ShapeHandle partition_ids_shape;
290         TF_RETURN_IF_ERROR(c->WithRank(c->input(num_resource_handles + i + 1),
291                                        1, &partition_ids_shape));
292         ShapeHandle feature_ids_shape;
293         TF_RETURN_IF_ERROR(c->WithRank(
294             c->input(num_resource_handles * 2 + i + 1), 2, &feature_ids_shape));
295         ShapeHandle gradients_shape;
296         TF_RETURN_IF_ERROR(c->WithRankAtLeast(
297             c->input(num_resource_handles * 3 + i + 1), 2, &gradients_shape));
298         TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
299                                     c->Dim(gradients_shape, 0), &unused_dim));
300         ShapeHandle hessians_shape;
301         TF_RETURN_IF_ERROR(c->WithRankAtLeast(
302             c->input(num_resource_handles * 4 + i + 1), 2, &hessians_shape));
303         TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
304                                     c->Dim(hessians_shape, 0), &unused_dim));
305       }
306       return Status::OK();
307     })
308     .Doc(R"doc(
309 Updates the tensor stats accumulator.
310 
311 stats_accumulator_handles: A list of handles to the stats accumulator.
312 stamp_token: Stamp token for Read/Write operations.
313              Any operation with a mismatching token will be dropped.
314 partition_ids: A list of vectors of partition_ids.
315 feature_ids: Rank 2 tensor of feature id and feature dimension ids.
316 gradients: A list of vectors of gradients for each slot in
317     <partition_id, feature_id, feature_dimension_id>.
318 hessians: A list of vectors of hessians for each slot in
319     <partition_id, feature_id, feature_dimension_id>.
320 )doc");
321 
322 REGISTER_OP("StatsAccumulatorTensorFlush")
323     .Input("stats_accumulator_handle: resource")
324     .Input("stamp_token: int64")
325     .Input("next_stamp_token: int64")
326     .Output("num_updates: int64")
327     .Output("output_partition_ids: int32")
328     .Output("output_feature_ids: int64")
329     .Output("output_gradients: float")
330     .Output("output_hessians: float")
__anone13797730802(InferenceContext* c) 331     .SetShapeFn([](InferenceContext* c) {
332       ShapeHandle unused_input;
333       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
334       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
335       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
336       // num_updates
337       c->set_output(0, c->Scalar());
338       c->set_output(1, c->Vector(c->UnknownDim()));
339       c->set_output(2, c->UnknownShape());
340       c->set_output(3, c->UnknownShape());
341       c->set_output(4, c->UnknownShape());
342       return Status::OK();
343     })
344     .Doc(R"doc(
345 Flushes the stats accumulator to output and resets the internal state.
346 
347 stats_accumulator_handle: handle to the tree ensemble resource to be created.
348 stamp_token: Stamp token for Read/Write operations.
349              Any operation with a mismatching token will be dropped.
350 next_stamp_token: Stamp token to be used for the next iteration.
351 num_updates: Number of times stats were added to this accumulator since last
352     flush.
353 output_partition_ids: A vector of partition_ids for the slots.
354 output_feature_ids: Rank 2 tensor of feature id and feature dimension ids.
355 output_gradients: A tensor of gradients, first dimension matches slots
356                   in <partition_id, feature_id, feature_dimension_id>.
357 output_hessians: A tensor of hessians, first dimension matches slots
358                  in <partition_id, feature_id, feature_dimension_id>>.
359 )doc");
360 
361 REGISTER_OP("StatsAccumulatorTensorDeserialize")
362     .Input("stats_accumulator_handle: resource")
363     .Input("stamp_token: int64")
364     .Input("num_updates: int64")
365     .Input("partition_ids: int32")
366     .Input("feature_ids: int64")
367     .Input("gradients: float")
368     .Input("hessians: float")
__anone13797730902(InferenceContext* c) 369     .SetShapeFn([](InferenceContext* c) {
370       ShapeHandle unused_input;
371       DimensionHandle unused_dim;
372       // stats_accumulator_handle
373       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
374       // stamp_token
375       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
376       // num_updates
377       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
378       ShapeHandle partition_ids_shape;
379       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &partition_ids_shape));
380       ShapeHandle feature_ids_shape;
381       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 2, &feature_ids_shape));
382       ShapeHandle gradients_shape;
383       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(5), 2, &gradients_shape));
384       TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
385                                   c->Dim(gradients_shape, 0), &unused_dim));
386       ShapeHandle hessians_shape;
387       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(6), 2, &hessians_shape));
388       TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
389                                   c->Dim(hessians_shape, 0), &unused_dim));
390 
391       return Status::OK();
392     })
393     .Doc(R"doc(
394 Resets the tensor stats accumulator with the serialized state.
395 
396 stats_accumulator_handle: handle to the tree ensemble resource to be created.
397 stamp_token: Stamp token for Read/Write operations.
398              Any operation with a mismatching token will be dropped.
399 num_updates: Number of times stats were added to this accumulator since last
400     flush.
401 partition_ids: A vector of partition_ids.
402 feature_ids: Rank 2 tensor of feature id and feature dimension ids.
403 gradients: A vector of gradients for each slot in <partition_id, feature_id,
404 feature_dimension_id>
405 hessians: A vector of hessians for each slot in <partition_id, feature_id,
406 feature_dimension_id>.
407 )doc");
408 
409 REGISTER_OP("StatsAccumulatorTensorSerialize")
410     .Input("stats_accumulator_handle: resource")
411     .Output("stamp_token: int64")
412     .Output("num_updates: int64")
413     .Output("output_partition_ids: int32")
414     .Output("output_feature_ids: int64")
415     .Output("output_gradients: float")
416     .Output("output_hessians: float")
__anone13797730a02(InferenceContext* c) 417     .SetShapeFn([](InferenceContext* c) {
418       ShapeHandle unused_input;
419       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
420       // stamp_token
421       c->set_output(0, c->Scalar());
422       // num_updates
423       c->set_output(1, c->Scalar());
424       c->set_output(2, c->Vector(c->UnknownDim()));
425       c->set_output(3, c->UnknownShape());
426       c->set_output(4, c->UnknownShape());
427       c->set_output(5, c->UnknownShape());
428       return Status::OK();
429     })
430     .Doc(R"doc(
431 Serializes the scalar stats accumulator state.
432 
433 stats_accumulator_handle: handle to the tree ensemble resource to be created.
434 stamp_token: Stamp token for Read/Write operations.
435              Any operation with a mismatching token will be dropped.
436 num_updates: Number of times stats were added to this accumulator since last
437     flush.
438 output_partition_ids: A vector of partition_ids for the slots.
439 output_feature_ids: Rank 2 tensor of feature id and feature dimension ids.
440 output_gradients: A tensor of gradients, first dimension matches slots
441                   in <partition_id, feature_id, feature_dimension_id>.
442 output_hessians: A tensor of hessians, first dimension matches slots
443                  in <partition_id, feature_id, feature_dimension_id>.
444 )doc");
445 
446 REGISTER_OP("StatsAccumulatorTensorMakeSummary")
447     .Input("partition_ids: int32")
448     .Input("feature_ids: int64")
449     .Input("gradients: float")
450     .Input("hessians: float")
451     .Output("output_partition_ids: int32")
452     .Output("output_feature_ids: int64")
453     .Output("output_gradients: float")
454     .Output("output_hessians: float")
455     .Doc(R"doc(
456 Summarizes the stats by summing the <gradients, hessians> that are for the same
457 <partition_id, feature_id, feature_dimension_id>.
458 
459 partition_ids: A vector of partition_ids.
460 feature_ids: Rank 2 tensor of feature id and feature dimension ids.
461 gradients: A vector of gradients for each slot in <partition_id, feature_id,
462 feature_dimension_id>.
463 hessians: A vector of hessians for each slot in <partition_id, feature_id,
464 feature_dimension_id>.
465 output_partition_ids: A vector of partition_ids for the slots.
466 output_feature_ids: A rank2 tensor of feature_ids and dimensions for the slots.
467 output_gradients: A tensor of gradients, first dimension matches slots
468                   in <partition_id, feature_id, feature_dimension_id>.
469 output_hessians: A tensor of hessians, first dimension matches slots
470                  in <partition_id, feature_id, feature_dimension_id>.
471 )doc");
472 }  // namespace boosted_trees
473 }  // namespace tensorflow
474