// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <algorithm>
#include <iterator>
#include <map>
#include <string>
#include <vector>

#include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h"
#include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h"
#include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
namespace boosted_trees {

namespace {
const char* const kStampTokenName = "stamp_token";
const char* const kNextStampTokenName = "next_stamp_token";

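// Key identifying a single accumulator slot: the tree partition (leaf) the
// example falls into, the feature column id, and the dimension within that
// feature column.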
struct PartitionKey {
  PartitionKey() : partition_id(-1), feature_id(-1), dimension(-1) {}

  PartitionKey(int32 p, int64 f, int32 d)
      : partition_id(p), feature_id(f), dimension(d) {}

  bool operator==(const PartitionKey& other) const {
    return (partition_id == other.partition_id) &&
           (dimension == other.dimension) && (feature_id == other.feature_id);
  }

  // Strict weak ordering for PartitionKey, used as the map comparator.
  struct Less {
    bool operator()(const PartitionKey& a, const PartitionKey& b) const {
      if (a.partition_id < b.partition_id) {
        return true;
      }
      if ((a.partition_id == b.partition_id) && (a.dimension < b.dimension)) {
        return true;
      }
      if ((a.partition_id == b.partition_id) && (a.dimension == b.dimension) &&
          (a.feature_id < b.feature_id)) {
        return true;
      }
      return false;
    }
  };

  // Tree partition defined by traversing the tree to the leaf.
  int32 partition_id;

  // Feature column id.
  int64 feature_id;

  // Dimension within the feature column.
  int32 dimension;
};

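// Resource that accumulates per-slot gradient and hessian statistics.
// Slots are keyed by PartitionKey and guarded by a mutex; the stamp token
// inherited from StampedResource is used to reject stale updates.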
template <typename GradientType, typename HessianType>
class StatsAccumulatorResource : public boosted_trees::StampedResource {
  using StatsByPartition =
      std::map<PartitionKey, std::pair<GradientType, HessianType>,
               PartitionKey::Less>;

 public:
  StatsAccumulatorResource(const TensorShape& gradient_shape,
                           const TensorShape& hessian_shape)
      : gradient_shape_(gradient_shape),
        hessian_shape_(hessian_shape),
        num_updates_(0) {
    // If GradientType/HessianType is a scalar float then the shapes should be
    // scalar and vice versa.
    CHECK_EQ((std::is_same<GradientType, float>::value),
             TensorShapeUtils::IsScalar(gradient_shape));
    CHECK_EQ((std::is_same<HessianType, float>::value),
             TensorShapeUtils::IsScalar(hessian_shape));
  }

  string DebugString() const override {
    return strings::StrCat("StatsAccumulatorResource[size=", values_.size(),
                           "]");
  }

  void Clear() {
    values_.clear();
    num_updates_ = 0;
  }

  tensorflow::mutex* mutex() { return &mu_; }
  StatsByPartition* mutable_values() { return &values_; }
  const StatsByPartition& values() const { return values_; }
  const int64& num_updates() const { return num_updates_; }
  void set_num_updates(int64 val) { num_updates_ = val; }
  const TensorShape& gradient_shape() const { return gradient_shape_; }
  const TensorShape& hessian_shape() const { return hessian_shape_; }

 private:
  // Accumulated stats, keyed by (partition, feature column, dimension).
  StatsByPartition values_;
  const TensorShape gradient_shape_;
  const TensorShape hessian_shape_;
  int64 num_updates_;
  tensorflow::mutex mu_;
  TF_DISALLOW_COPY_AND_ASSIGN(StatsAccumulatorResource);
};

using StatsAccumulatorScalarResource = StatsAccumulatorResource<float, float>;
using StatsAccumulatorTensorResource =
    StatsAccumulatorResource<std::vector<float>, std::vector<float>>;

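// Writes the contents of a scalar accumulator to the op's outputs: one entry
// per slot with its partition id, (feature id, dimension) pair, gradient and
// hessian.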
void SerializeScalarAccumulatorToOutput(
    const StatsAccumulatorScalarResource& accumulator_resource,
    OpKernelContext* context) {
  int64 num_slots = accumulator_resource.values().size();
  Tensor* partition_ids_t = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids",
                                                   TensorShape({num_slots}),
                                                   &partition_ids_t));
  auto partition_ids = partition_ids_t->vec<int32>();

  // Feature ids tensor has ids of feature columns and their dimensions.
  Tensor* feature_ids_t = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output("output_feature_ids",
                                                   TensorShape({num_slots, 2}),
                                                   &feature_ids_t));
  auto feature_ids = feature_ids_t->matrix<int64>();

  Tensor* gradients_t = nullptr;
  OP_REQUIRES_OK(
      context, context->allocate_output(
                   "output_gradients", TensorShape({num_slots}), &gradients_t));
  auto gradients = gradients_t->vec<float>();

  Tensor* hessians_t = nullptr;
  OP_REQUIRES_OK(
      context, context->allocate_output("output_hessians",
                                        TensorShape({num_slots}), &hessians_t));
  auto hessians = hessians_t->vec<float>();

  int i = 0;
  for (const auto& iter : accumulator_resource.values()) {
    partition_ids(i) = iter.first.partition_id;
    feature_ids(i, 0) = iter.first.feature_id;
    feature_ids(i, 1) = iter.first.dimension;

    gradients(i) = iter.second.first;
    hessians(i) = iter.second.second;
    ++i;
  }
}

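// Same as above, but for tensor-valued gradients and hessians: each slot is
// written as one row of the higher-rank output tensors.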
void SerializeTensorAccumulatorToOutput(
    const StatsAccumulatorTensorResource& accumulator_resource,
    OpKernelContext* context) {
  int64 num_slots = accumulator_resource.values().size();
  Tensor* partition_ids_t = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids",
                                                   TensorShape({num_slots}),
                                                   &partition_ids_t));
  auto partition_ids = partition_ids_t->vec<int32>();

  Tensor* feature_ids_t = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output("output_feature_ids",
                                                   TensorShape({num_slots, 2}),
                                                   &feature_ids_t));
  auto feature_ids = feature_ids_t->matrix<int64>();

  TensorShape gradient_shape = accumulator_resource.gradient_shape();
  int64 num_gradient_elements = gradient_shape.num_elements();
  gradient_shape.InsertDim(0, num_slots);
  Tensor* gradients_t = nullptr;
  OP_REQUIRES_OK(context,
                 context->allocate_output("output_gradients", gradient_shape,
                                          &gradients_t));
  auto gradients = gradients_t->flat_outer_dims<float>();

  TensorShape hessian_shape = accumulator_resource.hessian_shape();
  int64 num_hessian_elements = hessian_shape.num_elements();
  hessian_shape.InsertDim(0, num_slots);
  Tensor* hessians_t = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output("output_hessians",
                                                   hessian_shape, &hessians_t));
  auto hessians = hessians_t->flat_outer_dims<float>();

  int i = 0;
  for (const auto& iter : accumulator_resource.values()) {
    partition_ids(i) = iter.first.partition_id;
    feature_ids(i, 0) = iter.first.feature_id;
    feature_ids(i, 1) = iter.first.dimension;

    for (int j = 0; j < num_gradient_elements; ++j) {
      gradients(i, j) = iter.second.first[j];
    }
    for (int j = 0; j < num_hessian_elements; ++j) {
      hessians(i, j) = iter.second.second[j];
    }
    ++i;
  }
}

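// Merges a batch of (partition_id, feature_id/dimension, gradient, hessian)
// updates into the scalar accumulator, summing into existing slots and
// creating new ones as needed.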
void AddToScalarAccumulator(
    StatsAccumulatorScalarResource* accumulator_resource,
    const Tensor& partition_ids_t, const Tensor& feature_ids_t,
    const Tensor& gradients_t, const Tensor& hessians_t) {
  accumulator_resource->set_num_updates(accumulator_resource->num_updates() +
                                        1);
  const TensorShape& partition_ids_shape = partition_ids_t.shape();
  const auto& partition_ids = partition_ids_t.vec<int32>();
  const auto& feature_ids_and_dimensions = feature_ids_t.matrix<int64>();
  const auto& gradients = gradients_t.vec<float>();
  const auto& hessians = hessians_t.vec<float>();

  int64 num_updates = partition_ids_shape.dim_size(0);
  auto stats_map = accumulator_resource->mutable_values();
  for (int64 i = 0; i < num_updates; ++i) {
    const auto key =
        PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0),
                     feature_ids_and_dimensions(i, 1));
    auto itr = stats_map->find(key);
    if (itr != stats_map->end()) {
      itr->second.first += gradients(i);
      itr->second.second += hessians(i);
    } else {
      (*stats_map)[key] = {gradients(i), hessians(i)};
    }
  }
}

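// Convenience overload that reads the update tensors from the kernel context.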
void AddToScalarAccumulator(
    StatsAccumulatorScalarResource* accumulator_resource,
    OpKernelContext* context) {
  const Tensor* partition_ids_t;
  OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t));
  const Tensor* feature_ids_t;
  OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t));
  const Tensor* gradients_t;
  OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
  const Tensor* hessians_t;
  OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
  AddToScalarAccumulator(accumulator_resource, *partition_ids_t, *feature_ids_t,
                         *gradients_t, *hessians_t);
}

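// Merges a batch of tensor-valued updates into the tensor accumulator after
// validating that the per-slot gradient and hessian shapes match the
// accumulator's shapes.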
void AddToTensorAccumulator(
    StatsAccumulatorTensorResource* accumulator_resource,
    const Tensor& partition_ids_t, const Tensor& feature_ids_t,
    const Tensor& gradients_t, const Tensor& hessians_t,
    OpKernelContext* context) {
  accumulator_resource->set_num_updates(accumulator_resource->num_updates() +
                                        1);

  const TensorShape& partition_ids_shape = partition_ids_t.shape();
  const auto& partition_ids = partition_ids_t.vec<int32>();
  const auto& feature_ids_and_dimensions = feature_ids_t.matrix<int64>();
  TensorShape gradients_shape = gradients_t.shape();
  const auto& gradients = gradients_t.flat_outer_dims<float>();
  TensorShape hessians_shape = hessians_t.shape();
  const auto& hessians = hessians_t.flat_outer_dims<float>();

  gradients_shape.RemoveDim(0);
  hessians_shape.RemoveDim(0);

  // TODO(soroush): Move gradient and hessian shape check to ShapeFn.
  OP_REQUIRES(
      context, gradients_shape == accumulator_resource->gradient_shape(),
      errors::InvalidArgument(strings::StrCat(
          "Gradients dimensions must match: ", gradients_shape.DebugString(),
          ", ", accumulator_resource->gradient_shape().DebugString())));

  OP_REQUIRES(
      context, hessians_shape == accumulator_resource->hessian_shape(),
      errors::InvalidArgument(strings::StrCat(
          "Hessian dimensions must match: ", hessians_shape.DebugString(), ", ",
          accumulator_resource->hessian_shape().DebugString())));

  int64 num_updates = partition_ids_shape.dim_size(0);
  auto stats_map = accumulator_resource->mutable_values();
  for (int64 i = 0; i < num_updates; ++i) {
    const auto key =
        PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0),
                     feature_ids_and_dimensions(i, 1));
    auto itr = stats_map->find(key);
    if (itr == stats_map->end()) {
      std::vector<float> new_gradients(gradients_shape.num_elements());
      for (int j = 0; j < gradients_shape.num_elements(); ++j) {
        new_gradients[j] = gradients(i, j);
      }
      std::vector<float> new_hessians(hessians_shape.num_elements());
      for (int j = 0; j < hessians_shape.num_elements(); ++j) {
        new_hessians[j] = hessians(i, j);
      }
      (*stats_map)[key] = {new_gradients, new_hessians};
    } else {
      auto& stored_gradients = itr->second.first;
      for (int j = 0; j < gradients_shape.num_elements(); ++j) {
        stored_gradients[j] += gradients(i, j);
      }
      auto& stored_hessians = itr->second.second;
      for (int j = 0; j < hessians_shape.num_elements(); ++j) {
        stored_hessians[j] += hessians(i, j);
      }
    }
  }
}

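// Convenience overload that reads the update tensors from the kernel context.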
void AddToTensorAccumulator(
    StatsAccumulatorTensorResource* accumulator_resource,
    OpKernelContext* context) {
  const Tensor* partition_ids_t;
  OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t));
  const Tensor* feature_ids_t;
  OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t));
  const Tensor* gradients_t;
  OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
  const Tensor* hessians_t;
  OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
  AddToTensorAccumulator(accumulator_resource, *partition_ids_t, *feature_ids_t,
                         *gradients_t, *hessians_t, context);
}

}  // namespace

REGISTER_RESOURCE_HANDLE_KERNEL(StatsAccumulatorScalarResource);
REGISTER_RESOURCE_HANDLE_KERNEL(StatsAccumulatorTensorResource);

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorScalarIsInitialized").Device(DEVICE_CPU),
    IsResourceInitialized<StatsAccumulatorScalarResource>);

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorTensorIsInitialized").Device(DEVICE_CPU),
    IsResourceInitialized<StatsAccumulatorTensorResource>);

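// Creates and stamps a scalar stats accumulator resource for the given handle.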
class CreateStatsAccumulatorScalarOp : public OpKernel {
 public:
  explicit CreateStatsAccumulatorScalarOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));

    TensorShape gradient_shape = TensorShape({});
    TensorShape hessian_shape = TensorShape({});

    auto* result =
        new StatsAccumulatorScalarResource(gradient_shape, hessian_shape);
    result->set_stamp(stamp_token_t->scalar<int64>()());
    // Only create the resource if one does not exist already; if it does,
    // the new resource is unreffed and the ALREADY_EXISTS error is ignored.
    // Any other error is reported.
    auto status = CreateResource(context, HandleFromInput(context, 0), result);
    if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
      OP_REQUIRES(context, false, status);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("CreateStatsAccumulatorScalar").Device(DEVICE_CPU),
                        CreateStatsAccumulatorScalarOp);

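// Creates and stamps a tensor stats accumulator resource with the given
// per-slot gradient and hessian shapes.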
class CreateStatsAccumulatorTensorOp : public OpKernel {
 public:
  explicit CreateStatsAccumulatorTensorOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));

    const Tensor* gradient_shape_t;
    OP_REQUIRES_OK(
        context, context->input("per_slot_gradient_shape", &gradient_shape_t));

    const Tensor* hessian_shape_t;
    OP_REQUIRES_OK(context,
                   context->input("per_slot_hessian_shape", &hessian_shape_t));
    TensorShape gradient_shape = TensorShape(gradient_shape_t->vec<int64>());
    TensorShape hessian_shape = TensorShape(hessian_shape_t->vec<int64>());
    auto* result =
        new StatsAccumulatorTensorResource(gradient_shape, hessian_shape);
    result->set_stamp(stamp_token_t->scalar<int64>()());

    // Only create the resource if one does not exist already; if it does,
    // the new resource is unreffed and the ALREADY_EXISTS error is ignored.
    // Any other error is reported.
    auto status = CreateResource(context, HandleFromInput(context, 0), result);
    if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
      OP_REQUIRES(context, false, status);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("CreateStatsAccumulatorTensor").Device(DEVICE_CPU),
                        CreateStatsAccumulatorTensorOp);

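// Adds a batch of gradient/hessian updates to one or more scalar accumulators,
// sharding the accumulators across the CPU worker threads. Updates carrying a
// stale stamp token are dropped.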
class StatsAccumulatorScalarAddOp : public OpKernel {
 public:
  explicit StatsAccumulatorScalarAddOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    OpInputList resource_handle_list;
    OP_REQUIRES_OK(context, context->input_list("stats_accumulator_handles",
                                                &resource_handle_list));
    OpInputList partition_ids_list;
    OP_REQUIRES_OK(context,
                   context->input_list("partition_ids", &partition_ids_list));

    OpInputList feature_ids_list;
    OP_REQUIRES_OK(context,
                   context->input_list("feature_ids", &feature_ids_list));
    OpInputList gradients_list;
    OP_REQUIRES_OK(context, context->input_list("gradients", &gradients_list));
    OpInputList hessians_list;
    OP_REQUIRES_OK(context, context->input_list("hessians", &hessians_list));

    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();

    thread::ThreadPool* const worker_threads =
        context->device()->tensorflow_cpu_worker_threads()->workers;
    boosted_trees::utils::ParallelFor(
        resource_handle_list.size(), worker_threads->NumThreads(),
        worker_threads,
        [&context, &resource_handle_list, &partition_ids_list,
         &feature_ids_list, &gradients_list, &hessians_list,
         stamp_token](int64 start, int64 end) {
          for (int resource_handle_idx = start; resource_handle_idx < end;
               ++resource_handle_idx) {
            const ResourceHandle& handle =
                resource_handle_list[resource_handle_idx]
                    .flat<ResourceHandle>()(0);

            StatsAccumulatorScalarResource* accumulator_resource;
            OP_REQUIRES_OK(context, LookupResource(context, handle,
                                                   &accumulator_resource));
            mutex_lock l(*accumulator_resource->mutex());
            core::ScopedUnref unref_me(accumulator_resource);

            // If the stamp is invalid we drop the update.
            if (!accumulator_resource->is_stamp_valid(stamp_token)) {
              VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. "
                      << "Passed stamp token: " << stamp_token << " "
                      << "Current token: " << accumulator_resource->stamp();
              return;
            }
            AddToScalarAccumulator(accumulator_resource,
                                   partition_ids_list[resource_handle_idx],
                                   feature_ids_list[resource_handle_idx],
                                   gradients_list[resource_handle_idx],
                                   hessians_list[resource_handle_idx]);
          }
        });
  }
};

REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorScalarAdd").Device(DEVICE_CPU),
                        StatsAccumulatorScalarAddOp);

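// Tensor-valued counterpart of StatsAccumulatorScalarAddOp.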
class StatsAccumulatorTensorAddOp : public OpKernel {
 public:
  explicit StatsAccumulatorTensorAddOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    OpInputList resource_handle_list;
    OP_REQUIRES_OK(context, context->input_list("stats_accumulator_handles",
                                                &resource_handle_list));
    OpInputList partition_ids_list;
    OP_REQUIRES_OK(context,
                   context->input_list("partition_ids", &partition_ids_list));

    OpInputList feature_ids_list;
    OP_REQUIRES_OK(context,
                   context->input_list("feature_ids", &feature_ids_list));
    OpInputList gradients_list;
    OP_REQUIRES_OK(context, context->input_list("gradients", &gradients_list));
    OpInputList hessians_list;
    OP_REQUIRES_OK(context, context->input_list("hessians", &hessians_list));

    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();

    thread::ThreadPool* const worker_threads =
        context->device()->tensorflow_cpu_worker_threads()->workers;
    boosted_trees::utils::ParallelFor(
        resource_handle_list.size(), worker_threads->NumThreads(),
        worker_threads,
        [&context, &resource_handle_list, &partition_ids_list,
         &feature_ids_list, &gradients_list, &hessians_list,
         stamp_token](int64 start, int64 end) {
          for (int resource_handle_idx = start; resource_handle_idx < end;
               ++resource_handle_idx) {
            const ResourceHandle& handle =
                resource_handle_list[resource_handle_idx]
                    .flat<ResourceHandle>()(0);

            StatsAccumulatorTensorResource* accumulator_resource;
            OP_REQUIRES_OK(context, LookupResource(context, handle,
                                                   &accumulator_resource));
            mutex_lock l(*accumulator_resource->mutex());
            core::ScopedUnref unref_me(accumulator_resource);

            // If the stamp is invalid we drop the update.
            if (!accumulator_resource->is_stamp_valid(stamp_token)) {
              VLOG(1) << "Invalid stamp token in StatsAccumulatorTensorAddOp. "
                      << "Passed stamp token: " << stamp_token << " "
                      << "Current token: " << accumulator_resource->stamp();
              return;
            }
            AddToTensorAccumulator(accumulator_resource,
                                   partition_ids_list[resource_handle_idx],
                                   feature_ids_list[resource_handle_idx],
                                   gradients_list[resource_handle_idx],
                                   hessians_list[resource_handle_idx], context);
          }
        });
  }
};

REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorTensorAdd").Device(DEVICE_CPU),
                        StatsAccumulatorTensorAddOp);

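// Serializes the scalar accumulator contents to the op's outputs, then clears
// the accumulator and advances its stamp to the next stamp token.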
class StatsAccumulatorScalarFlushOp : public OpKernel {
 public:
  explicit StatsAccumulatorScalarFlushOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    StatsAccumulatorScalarResource* accumulator_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &accumulator_resource));
    mutex_lock l(*accumulator_resource->mutex());
    core::ScopedUnref unref_me(accumulator_resource);

    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();

    // If the stamp is invalid we crash and restart the PS. This shouldn't
    // happen since only the chief calls this op, and the chief is guaranteed
    // to be in a consistent state.
    CHECK(accumulator_resource->is_stamp_valid(stamp_token));

    const Tensor* next_stamp_token_t;
    OP_REQUIRES_OK(context,
                   context->input(kNextStampTokenName, &next_stamp_token_t));
    int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
    CHECK(stamp_token != next_stamp_token);

    SerializeScalarAccumulatorToOutput(*accumulator_resource, context);
    Tensor* num_updates_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("num_updates", TensorShape({}),
                                            &num_updates_t));
    num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();

    accumulator_resource->Clear();
    accumulator_resource->set_stamp(next_stamp_token);
  }
};

REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorScalarFlush").Device(DEVICE_CPU),
                        StatsAccumulatorScalarFlushOp);

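// Tensor-valued counterpart of StatsAccumulatorScalarFlushOp.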
class StatsAccumulatorTensorFlushOp : public OpKernel {
 public:
  explicit StatsAccumulatorTensorFlushOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    StatsAccumulatorTensorResource* accumulator_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &accumulator_resource));
    mutex_lock l(*accumulator_resource->mutex());
    core::ScopedUnref unref_me(accumulator_resource);

    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();

    const Tensor* next_stamp_token_t;
    OP_REQUIRES_OK(context,
                   context->input(kNextStampTokenName, &next_stamp_token_t));
    int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();

    // If the stamp is invalid we crash and restart the PS. This shouldn't
    // happen since only the chief calls this op, and the chief is guaranteed
    // to be in a consistent state.
    CHECK(accumulator_resource->is_stamp_valid(stamp_token));
    CHECK(stamp_token != next_stamp_token);
    SerializeTensorAccumulatorToOutput(*accumulator_resource, context);
    Tensor* num_updates_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("num_updates", TensorShape({}),
                                            &num_updates_t));
    num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
    accumulator_resource->Clear();
    accumulator_resource->set_stamp(next_stamp_token);
  }
};

REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorTensorFlush").Device(DEVICE_CPU),
                        StatsAccumulatorTensorFlushOp);

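// Restores a scalar accumulator from serialized state: resets the stamp,
// replays the stored slots, and restores the update count.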
class StatsAccumulatorScalarDeserializeOp : public OpKernel {
 public:
  explicit StatsAccumulatorScalarDeserializeOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    StatsAccumulatorScalarResource* accumulator_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &accumulator_resource));
    mutex_lock l(*accumulator_resource->mutex());
    core::ScopedUnref unref_me(accumulator_resource);

    // Reset the accumulator to the incoming stamp token.
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();
    accumulator_resource->Clear();
    accumulator_resource->set_stamp(stamp_token);
    AddToScalarAccumulator(accumulator_resource, context);
    const Tensor* num_updates_t;
    OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t));
    accumulator_resource->set_num_updates(num_updates_t->scalar<int64>()());
  }
};

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorScalarDeserialize").Device(DEVICE_CPU),
    StatsAccumulatorScalarDeserializeOp);

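// Tensor-valued counterpart of StatsAccumulatorScalarDeserializeOp.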
class StatsAccumulatorTensorDeserializeOp : public OpKernel {
 public:
  explicit StatsAccumulatorTensorDeserializeOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    StatsAccumulatorTensorResource* accumulator_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &accumulator_resource));
    mutex_lock l(*accumulator_resource->mutex());
    core::ScopedUnref unref_me(accumulator_resource);

    // Reset the accumulator to the incoming stamp token.
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();
    accumulator_resource->Clear();
    accumulator_resource->set_stamp(stamp_token);
    AddToTensorAccumulator(accumulator_resource, context);
    const Tensor* num_updates_t;
    OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t));
    accumulator_resource->set_num_updates(num_updates_t->scalar<int64>()());
  }
};

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorTensorDeserialize").Device(DEVICE_CPU),
    StatsAccumulatorTensorDeserializeOp);

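// Serializes the scalar accumulator contents together with its stamp token and
// update count.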
class StatsAccumulatorScalarSerializeOp : public OpKernel {
 public:
  explicit StatsAccumulatorScalarSerializeOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    StatsAccumulatorScalarResource* accumulator_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &accumulator_resource));
    mutex_lock l(*accumulator_resource->mutex());
    core::ScopedUnref unref_me(accumulator_resource);
    SerializeScalarAccumulatorToOutput(*accumulator_resource, context);
    Tensor* stamp_token_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("stamp_token", TensorShape({}),
                                            &stamp_token_t));
    stamp_token_t->scalar<int64>()() = accumulator_resource->stamp();

    Tensor* num_updates_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("num_updates", TensorShape({}),
                                            &num_updates_t));
    num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
  }
};

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorScalarSerialize").Device(DEVICE_CPU),
    StatsAccumulatorScalarSerializeOp);

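// Tensor-valued counterpart of StatsAccumulatorScalarSerializeOp.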
class StatsAccumulatorTensorSerializeOp : public OpKernel {
 public:
  explicit StatsAccumulatorTensorSerializeOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    StatsAccumulatorTensorResource* accumulator_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &accumulator_resource));
    mutex_lock l(*accumulator_resource->mutex());
    core::ScopedUnref unref_me(accumulator_resource);
    SerializeTensorAccumulatorToOutput(*accumulator_resource, context);
    Tensor* stamp_token_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("stamp_token", TensorShape({}),
                                            &stamp_token_t));
    stamp_token_t->scalar<int64>()() = accumulator_resource->stamp();

    Tensor* num_updates_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("num_updates", TensorShape({}),
                                            &num_updates_t));
    num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
  }
};

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorTensorSerialize").Device(DEVICE_CPU),
    StatsAccumulatorTensorSerializeOp);

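// Aggregates a batch of scalar updates into a temporary accumulator and
// serializes the result, without touching any persistent resource.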
class StatsAccumulatorScalarMakeSummaryOp : public OpKernel {
 public:
  explicit StatsAccumulatorScalarMakeSummaryOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    TensorShape gradient_shape = TensorShape({});
    TensorShape hessian_shape = TensorShape({});
    StatsAccumulatorScalarResource* accumulator_resource =
        new StatsAccumulatorScalarResource(gradient_shape, hessian_shape);
    core::ScopedUnref unref_me(accumulator_resource);
    AddToScalarAccumulator(accumulator_resource, context);
    SerializeScalarAccumulatorToOutput(*accumulator_resource, context);
  }
};

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorScalarMakeSummary").Device(DEVICE_CPU),
    StatsAccumulatorScalarMakeSummaryOp);

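// Tensor-valued counterpart of StatsAccumulatorScalarMakeSummaryOp; the
// per-slot shapes are inferred from the incoming gradient and hessian tensors.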
class StatsAccumulatorTensorMakeSummaryOp : public OpKernel {
 public:
  explicit StatsAccumulatorTensorMakeSummaryOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* gradients_t;
    OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
    TensorShape gradients_shape = gradients_t->shape();
    gradients_shape.RemoveDim(0);

    const Tensor* hessians_t;
    OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
    TensorShape hessians_shape = hessians_t->shape();
    hessians_shape.RemoveDim(0);

    StatsAccumulatorTensorResource* accumulator_resource =
        new StatsAccumulatorTensorResource(gradients_shape, hessians_shape);
    core::ScopedUnref unref_me(accumulator_resource);
    AddToTensorAccumulator(accumulator_resource, context);
    SerializeTensorAccumulatorToOutput(*accumulator_resource, context);
  }
};

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorTensorMakeSummary").Device(DEVICE_CPU),
    StatsAccumulatorTensorMakeSummaryOp);

}  // namespace boosted_trees
}  // namespace tensorflow