// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <algorithm>
#include <iterator>
#include <map>
#include <string>
#include <vector>

#include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h"
#include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h"
#include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
namespace boosted_trees {

namespace {
const char* const kStampTokenName = "stamp_token";
const char* const kNextStampTokenName = "next_stamp_token";

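// Key identifying a single accumulator slot: the tree partition the example
// reached, the feature column, and the dimension within that column. Slots
// are ordered by (partition_id, dimension, feature_id), so under Less:
//   PartitionKey(0, 5, 0) < PartitionKey(0, 5, 1) < PartitionKey(1, 2, 0).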
struct PartitionKey {
  PartitionKey() : partition_id(-1), feature_id(-1), dimension(-1) {}

  PartitionKey(int32 p, int64 f, int32 d)
      : partition_id(p), feature_id(f), dimension(d) {}

  bool operator==(const PartitionKey& other) const {
    return (partition_id == other.partition_id) &&
           (dimension == other.dimension) && (feature_id == other.feature_id);
  }

  // Compare for PartitionKey.
  struct Less {
    bool operator()(const PartitionKey& a, const PartitionKey& b) const {
      if (a.partition_id < b.partition_id) {
        return true;
      }
      if ((a.partition_id == b.partition_id) && (a.dimension < b.dimension)) {
        return true;
      }
      if ((a.partition_id == b.partition_id) && (a.dimension == b.dimension) &&
          (a.feature_id < b.feature_id)) {
        return true;
      }
      return false;
    }
  };

  // Tree partition defined by traversing the tree to the leaf.
  int32 partition_id;

  // Feature column id.
  int64 feature_id;

  // Dimension within feature column.
  int32 dimension;
};

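// Stamped resource holding the accumulated statistics: a map from
// PartitionKey to a (gradient, hessian) pair. Instantiated with float for
// scalar stats and with std::vector<float> for tensor-valued stats, where
// each vector stores one flattened tensor of the configured per-slot shape.
// Callers are expected to hold mutex() while reading or mutating the map.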
template <typename GradientType, typename HessianType>
class StatsAccumulatorResource : public boosted_trees::StampedResource {
  using StatsByPartition =
      std::map<PartitionKey, std::pair<GradientType, HessianType>,
               PartitionKey::Less>;

 public:
  StatsAccumulatorResource(const TensorShape& gradient_shape,
                           const TensorShape& hessian_shape)
      : gradient_shape_(gradient_shape),
        hessian_shape_(hessian_shape),
        num_updates_(0) {
    // If GradientType/HessianType is scalar float then the shapes should be
    // scalar and vice versa.
    CHECK_EQ((std::is_same<GradientType, float>::value),
             TensorShapeUtils::IsScalar(gradient_shape));
    CHECK_EQ((std::is_same<HessianType, float>::value),
             TensorShapeUtils::IsScalar(hessian_shape));
  }

  string DebugString() const override {
    return strings::StrCat("StatsAccumulatorResource[size=", values_.size(),
                           "]");
  }

  void Clear() {
    values_.clear();
    num_updates_ = 0;
  }

  tensorflow::mutex* mutex() { return &mu_; }
  StatsByPartition* mutable_values() { return &values_; }
  const StatsByPartition& values() const { return values_; }
  const int64& num_updates() const { return num_updates_; }
  void set_num_updates(int64 val) { num_updates_ = val; }
  const TensorShape& gradient_shape() const { return gradient_shape_; }
  const TensorShape& hessian_shape() const { return hessian_shape_; }

 private:
  // Key into a specific partition to accumulate stats for the specified
  // feature id.
  StatsByPartition values_;
  const TensorShape gradient_shape_;
  const TensorShape hessian_shape_;
  int64 num_updates_;
  tensorflow::mutex mu_;
  TF_DISALLOW_COPY_AND_ASSIGN(StatsAccumulatorResource);
};

using StatsAccumulatorScalarResource = StatsAccumulatorResource<float, float>;
using StatsAccumulatorTensorResource =
    StatsAccumulatorResource<std::vector<float>, std::vector<float>>;

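// Writes the contents of a scalar accumulator to the op's outputs, one row
// per slot: "output_partition_ids", "output_feature_ids" (feature id and
// dimension), "output_gradients" and "output_hessians". Rows follow the
// map's iteration order, i.e. they are sorted by PartitionKey::Less.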
void SerializeScalarAccumulatorToOutput(
    const StatsAccumulatorScalarResource& accumulator_resource,
    OpKernelContext* context) {
  int64 num_slots = accumulator_resource.values().size();
  Tensor* partition_ids_t = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids",
                                                   TensorShape({num_slots}),
                                                   &partition_ids_t));
  auto partition_ids = partition_ids_t->vec<int32>();

  // Feature ids tensor has ids of feature columns and their dimensions.
  Tensor* feature_ids_t = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output("output_feature_ids",
                                                   TensorShape({num_slots, 2}),
                                                   &feature_ids_t));
  auto feature_ids = feature_ids_t->matrix<int64>();

  Tensor* gradients_t = nullptr;
  OP_REQUIRES_OK(
      context, context->allocate_output(
                   "output_gradients", TensorShape({num_slots}), &gradients_t));
  auto gradients = gradients_t->vec<float>();

  Tensor* hessians_t = nullptr;
  OP_REQUIRES_OK(
      context, context->allocate_output("output_hessians",
                                        TensorShape({num_slots}), &hessians_t));
  auto hessians = hessians_t->vec<float>();

  int i = 0;
  for (const auto& iter : accumulator_resource.values()) {
    partition_ids(i) = iter.first.partition_id;
    feature_ids(i, 0) = iter.first.feature_id;
    feature_ids(i, 1) = iter.first.dimension;

    gradients(i) = iter.second.first;
    hessians(i) = iter.second.second;
    ++i;
  }
}

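// Tensor-valued counterpart of SerializeScalarAccumulatorToOutput: gradients
// and hessians are emitted with shape [num_slots, per_slot_shape...] by
// prepending the slot dimension to the accumulator's per-slot shapes.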
void SerializeTensorAccumulatorToOutput(
    const StatsAccumulatorTensorResource& accumulator_resource,
    OpKernelContext* context) {
  int64 num_slots = accumulator_resource.values().size();
  Tensor* partition_ids_t = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids",
                                                   TensorShape({num_slots}),
                                                   &partition_ids_t));
  auto partition_ids = partition_ids_t->vec<int32>();

  Tensor* feature_ids_t = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output("output_feature_ids",
                                                   TensorShape({num_slots, 2}),
                                                   &feature_ids_t));
  auto feature_ids = feature_ids_t->matrix<int64>();

  TensorShape gradient_shape = accumulator_resource.gradient_shape();
  int64 num_gradient_elements = gradient_shape.num_elements();
  gradient_shape.InsertDim(0, num_slots);
  Tensor* gradients_t = nullptr;
  OP_REQUIRES_OK(context,
                 context->allocate_output("output_gradients", gradient_shape,
                                          &gradients_t));
  auto gradients = gradients_t->flat_outer_dims<float>();

  TensorShape hessian_shape = accumulator_resource.hessian_shape();
  int64 num_hessian_elements = hessian_shape.num_elements();
  hessian_shape.InsertDim(0, num_slots);
  Tensor* hessians_t = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output("output_hessians",
                                                   hessian_shape, &hessians_t));
  auto hessians = hessians_t->flat_outer_dims<float>();

  int i = 0;
  for (const auto& iter : accumulator_resource.values()) {
    partition_ids(i) = iter.first.partition_id;
    feature_ids(i, 0) = iter.first.feature_id;
    feature_ids(i, 1) = iter.first.dimension;

    for (int j = 0; j < num_gradient_elements; ++j) {
      gradients(i, j) = iter.second.first[j];
    }
    for (int j = 0; j < num_hessian_elements; ++j) {
      hessians(i, j) = iter.second.second[j];
    }
    ++i;
  }
}

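// Merges a batch of (partition, feature, dimension) -> (gradient, hessian)
// updates into the accumulator, adding into existing slots and inserting new
// ones as needed. Each call counts as one update towards num_updates().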
void AddToScalarAccumulator(
    StatsAccumulatorScalarResource* accumulator_resource,
    const Tensor& partition_ids_t, const Tensor& feature_ids_t,
    const Tensor& gradients_t, const Tensor& hessians_t) {
  accumulator_resource->set_num_updates(accumulator_resource->num_updates() +
                                        1);
  const TensorShape& partition_ids_shape = partition_ids_t.shape();
  const auto& partition_ids = partition_ids_t.vec<int32>();
  const auto& feature_ids_and_dimensions = feature_ids_t.matrix<int64>();
  const auto& gradients = gradients_t.vec<float>();
  const auto& hessians = hessians_t.vec<float>();

  int64 num_updates = partition_ids_shape.dim_size(0);
  auto stats_map = accumulator_resource->mutable_values();
  for (int64 i = 0; i < num_updates; ++i) {
    const auto key =
        PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0),
                     feature_ids_and_dimensions(i, 1));
    auto itr = stats_map->find(key);
    if (itr != stats_map->end()) {
      itr->second.first += gradients(i);
      itr->second.second += hessians(i);
    } else {
      (*stats_map)[key] = {gradients(i), hessians(i)};
    }
  }
}

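// Convenience overload that reads the "partition_ids", "feature_ids",
// "gradients" and "hessians" inputs directly from the kernel context.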
void AddToScalarAccumulator(
    StatsAccumulatorScalarResource* accumulator_resource,
    OpKernelContext* context) {
  const Tensor* partition_ids_t;
  OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t));
  const Tensor* feature_ids_t;
  OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t));
  const Tensor* gradients_t;
  OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
  const Tensor* hessians_t;
  OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
  AddToScalarAccumulator(accumulator_resource, *partition_ids_t, *feature_ids_t,
                         *gradients_t, *hessians_t);
}

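// Tensor-valued counterpart of AddToScalarAccumulator. Additionally checks
// that the per-slot shapes of the incoming gradients and hessians (their
// shapes minus the leading batch dimension) match the shapes the accumulator
// was created with.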
void AddToTensorAccumulator(
    StatsAccumulatorTensorResource* accumulator_resource,
    const Tensor& partition_ids_t, const Tensor& feature_ids_t,
    const Tensor& gradients_t, const Tensor& hessians_t,
    OpKernelContext* context) {
  accumulator_resource->set_num_updates(accumulator_resource->num_updates() +
                                        1);

  const TensorShape& partition_ids_shape = partition_ids_t.shape();
  const auto& partition_ids = partition_ids_t.vec<int32>();
  const auto& feature_ids_and_dimensions = feature_ids_t.matrix<int64>();
  TensorShape gradients_shape = gradients_t.shape();
  const auto& gradients = gradients_t.flat_outer_dims<float>();
  TensorShape hessians_shape = hessians_t.shape();
  const auto& hessians = hessians_t.flat_outer_dims<float>();

  gradients_shape.RemoveDim(0);
  hessians_shape.RemoveDim(0);

  // TODO(soroush): Move gradient and hessian shape check to ShapeFn.
  OP_REQUIRES(
      context, gradients_shape == accumulator_resource->gradient_shape(),
      errors::InvalidArgument(strings::StrCat(
          "Gradients dimensions must match: ", gradients_shape.DebugString(),
          ", ", accumulator_resource->gradient_shape().DebugString())));

  OP_REQUIRES(
      context, hessians_shape == accumulator_resource->hessian_shape(),
      errors::InvalidArgument(strings::StrCat(
          "Hessian dimensions must match: ", hessians_shape.DebugString(), ", ",
          accumulator_resource->hessian_shape().DebugString())));

  int64 num_updates = partition_ids_shape.dim_size(0);
  auto stats_map = accumulator_resource->mutable_values();
  for (int64 i = 0; i < num_updates; ++i) {
    const auto key =
        PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0),
                     feature_ids_and_dimensions(i, 1));
    auto itr = stats_map->find(key);
    if (itr == stats_map->end()) {
      std::vector<float> new_gradients(gradients_shape.num_elements());
      for (int j = 0; j < gradients_shape.num_elements(); ++j) {
        new_gradients[j] = gradients(i, j);
      }
      std::vector<float> new_hessians(hessians_shape.num_elements());
      for (int j = 0; j < hessians_shape.num_elements(); ++j) {
        new_hessians[j] = hessians(i, j);
      }
      (*stats_map)[key] = {new_gradients, new_hessians};
    } else {
      auto& stored_gradients = itr->second.first;
      for (int j = 0; j < gradients_shape.num_elements(); ++j) {
        stored_gradients[j] += gradients(i, j);
      }
      auto& stored_hessians = itr->second.second;
      for (int j = 0; j < hessians_shape.num_elements(); ++j) {
        stored_hessians[j] += hessians(i, j);
      }
    }
  }
}

void AddToTensorAccumulator(
    StatsAccumulatorTensorResource* accumulator_resource,
    OpKernelContext* context) {
  const Tensor* partition_ids_t;
  OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t));
  const Tensor* feature_ids_t;
  OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t));
  const Tensor* gradients_t;
  OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
  const Tensor* hessians_t;
  OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
  AddToTensorAccumulator(accumulator_resource, *partition_ids_t, *feature_ids_t,
                         *gradients_t, *hessians_t, context);
}

}  // namespace

REGISTER_RESOURCE_HANDLE_KERNEL(StatsAccumulatorScalarResource);
REGISTER_RESOURCE_HANDLE_KERNEL(StatsAccumulatorTensorResource);

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorScalarIsInitialized").Device(DEVICE_CPU),
    IsResourceInitialized<StatsAccumulatorScalarResource>);

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorTensorIsInitialized").Device(DEVICE_CPU),
    IsResourceInitialized<StatsAccumulatorTensorResource>);

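// Creates and stamps the accumulator resource behind the handle in input 0.
// The tensor variant below also reads the per-slot gradient and hessian
// shapes from its inputs.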
class CreateStatsAccumulatorScalarOp : public OpKernel {
 public:
  explicit CreateStatsAccumulatorScalarOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));

    TensorShape gradient_shape = TensorShape({});
    TensorShape hessian_shape = TensorShape({});

    auto* result =
        new StatsAccumulatorScalarResource(gradient_shape, hessian_shape);
    result->set_stamp(stamp_token_t->scalar<int64>()());
    // Only create the resource if one does not already exist; if one does,
    // CreateResource unrefs the new one and we ignore ALREADY_EXISTS. Any
    // other error status is reported.
    auto status = CreateResource(context, HandleFromInput(context, 0), result);
    if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
      OP_REQUIRES(context, false, status);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("CreateStatsAccumulatorScalar").Device(DEVICE_CPU),
                        CreateStatsAccumulatorScalarOp);

class CreateStatsAccumulatorTensorOp : public OpKernel {
 public:
  explicit CreateStatsAccumulatorTensorOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));

    const Tensor* gradient_shape_t;
    OP_REQUIRES_OK(
        context, context->input("per_slot_gradient_shape", &gradient_shape_t));

    const Tensor* hessian_shape_t;
    OP_REQUIRES_OK(context,
                   context->input("per_slot_hessian_shape", &hessian_shape_t));
    TensorShape gradient_shape = TensorShape(gradient_shape_t->vec<int64>());
    TensorShape hessian_shape = TensorShape(hessian_shape_t->vec<int64>());
    auto* result =
        new StatsAccumulatorTensorResource(gradient_shape, hessian_shape);
    result->set_stamp(stamp_token_t->scalar<int64>()());

    // Only create the resource if one does not already exist; if one does,
    // CreateResource unrefs the new one and we ignore ALREADY_EXISTS. Any
    // other error status is reported.
    auto status = CreateResource(context, HandleFromInput(context, 0), result);
    if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
      OP_REQUIRES(context, false, status);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("CreateStatsAccumulatorTensor").Device(DEVICE_CPU),
                        CreateStatsAccumulatorTensorOp);

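// Adds a batch of stats to each accumulator in a list of handles, sharding
// the accumulators across the CPU worker threads via ParallelFor. Updates
// that carry a stale stamp token are dropped.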
class StatsAccumulatorScalarAddOp : public OpKernel {
 public:
  explicit StatsAccumulatorScalarAddOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    OpInputList resource_handle_list;
    OP_REQUIRES_OK(context, context->input_list("stats_accumulator_handles",
                                                &resource_handle_list));
    OpInputList partition_ids_list;
    OP_REQUIRES_OK(context,
                   context->input_list("partition_ids", &partition_ids_list));

    OpInputList feature_ids_list;
    OP_REQUIRES_OK(context,
                   context->input_list("feature_ids", &feature_ids_list));
    OpInputList gradients_list;
    OP_REQUIRES_OK(context, context->input_list("gradients", &gradients_list));
    OpInputList hessians_list;
    OP_REQUIRES_OK(context, context->input_list("hessians", &hessians_list));

    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();

    thread::ThreadPool* const worker_threads =
        context->device()->tensorflow_cpu_worker_threads()->workers;
    boosted_trees::utils::ParallelFor(
        resource_handle_list.size(), worker_threads->NumThreads(),
        worker_threads,
        [&context, &resource_handle_list, &partition_ids_list,
         &feature_ids_list, &gradients_list, &hessians_list,
         stamp_token](int64 start, int64 end) {
          for (int resource_handle_idx = start; resource_handle_idx < end;
               ++resource_handle_idx) {
            const ResourceHandle& handle =
                resource_handle_list[resource_handle_idx]
                    .flat<ResourceHandle>()(0);

            StatsAccumulatorScalarResource* accumulator_resource;
            OP_REQUIRES_OK(context, LookupResource(context, handle,
                                                   &accumulator_resource));
            mutex_lock l(*accumulator_resource->mutex());
            core::ScopedUnref unref_me(accumulator_resource);

            // If the stamp is invalid we drop the update.
            if (!accumulator_resource->is_stamp_valid(stamp_token)) {
              VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. "
                      << "Passed stamp token: " << stamp_token << " "
                      << "Current token: " << accumulator_resource->stamp();
              return;
            }
            AddToScalarAccumulator(accumulator_resource,
                                   partition_ids_list[resource_handle_idx],
                                   feature_ids_list[resource_handle_idx],
                                   gradients_list[resource_handle_idx],
                                   hessians_list[resource_handle_idx]);
          }
        });
  }
};

REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorScalarAdd").Device(DEVICE_CPU),
                        StatsAccumulatorScalarAddOp);

class StatsAccumulatorTensorAddOp : public OpKernel {
 public:
  explicit StatsAccumulatorTensorAddOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    OpInputList resource_handle_list;
    OP_REQUIRES_OK(context, context->input_list("stats_accumulator_handles",
                                                &resource_handle_list));
    OpInputList partition_ids_list;
    OP_REQUIRES_OK(context,
                   context->input_list("partition_ids", &partition_ids_list));

    OpInputList feature_ids_list;
    OP_REQUIRES_OK(context,
                   context->input_list("feature_ids", &feature_ids_list));
    OpInputList gradients_list;
    OP_REQUIRES_OK(context, context->input_list("gradients", &gradients_list));
    OpInputList hessians_list;
    OP_REQUIRES_OK(context, context->input_list("hessians", &hessians_list));

    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();

    thread::ThreadPool* const worker_threads =
        context->device()->tensorflow_cpu_worker_threads()->workers;
    boosted_trees::utils::ParallelFor(
        resource_handle_list.size(), worker_threads->NumThreads(),
        worker_threads,
        [&context, &resource_handle_list, &partition_ids_list,
         &feature_ids_list, &gradients_list, &hessians_list,
         stamp_token](int64 start, int64 end) {
          for (int resource_handle_idx = start; resource_handle_idx < end;
               ++resource_handle_idx) {
            const ResourceHandle& handle =
                resource_handle_list[resource_handle_idx]
                    .flat<ResourceHandle>()(0);

            StatsAccumulatorTensorResource* accumulator_resource;
            OP_REQUIRES_OK(context, LookupResource(context, handle,
                                                   &accumulator_resource));
            mutex_lock l(*accumulator_resource->mutex());
            core::ScopedUnref unref_me(accumulator_resource);

            // If the stamp is invalid we drop the update.
            if (!accumulator_resource->is_stamp_valid(stamp_token)) {
              VLOG(1) << "Invalid stamp token in StatsAccumulatorTensorAddOp. "
                      << "Passed stamp token: " << stamp_token << " "
                      << "Current token: " << accumulator_resource->stamp();
              return;
            }
            AddToTensorAccumulator(accumulator_resource,
                                   partition_ids_list[resource_handle_idx],
                                   feature_ids_list[resource_handle_idx],
                                   gradients_list[resource_handle_idx],
                                   hessians_list[resource_handle_idx], context);
          }
        });
  }
};

REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorTensorAdd").Device(DEVICE_CPU),
                        StatsAccumulatorTensorAddOp);

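// Flush: emits the accumulated stats together with num_updates, then clears
// the accumulator and re-stamps it with the next stamp token so that
// in-flight updates still carrying the old token are dropped.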
class StatsAccumulatorScalarFlushOp : public OpKernel {
 public:
  explicit StatsAccumulatorScalarFlushOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    StatsAccumulatorScalarResource* accumulator_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &accumulator_resource));
    mutex_lock l(*accumulator_resource->mutex());
    core::ScopedUnref unref_me(accumulator_resource);

    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();

    // If the stamp is invalid we crash, restarting the PS. This shouldn't
    // happen, since only the chief calls this op and the chief is guaranteed
    // to be in a consistent state.
    CHECK(accumulator_resource->is_stamp_valid(stamp_token));

    const Tensor* next_stamp_token_t;
    OP_REQUIRES_OK(context,
                   context->input(kNextStampTokenName, &next_stamp_token_t));
    int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
    CHECK(stamp_token != next_stamp_token);

    SerializeScalarAccumulatorToOutput(*accumulator_resource, context);
    Tensor* num_updates_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("num_updates", TensorShape({}),
                                            &num_updates_t));
    num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();

    accumulator_resource->Clear();
    accumulator_resource->set_stamp(next_stamp_token);
  }
};

REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorScalarFlush").Device(DEVICE_CPU),
                        StatsAccumulatorScalarFlushOp);

class StatsAccumulatorTensorFlushOp : public OpKernel {
 public:
  explicit StatsAccumulatorTensorFlushOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    StatsAccumulatorTensorResource* accumulator_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &accumulator_resource));
    mutex_lock l(*accumulator_resource->mutex());
    core::ScopedUnref unref_me(accumulator_resource);

    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();

    const Tensor* next_stamp_token_t;
    OP_REQUIRES_OK(context,
                   context->input(kNextStampTokenName, &next_stamp_token_t));
    int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();

    // If the stamp is invalid we crash, restarting the PS. This shouldn't
    // happen, since only the chief calls this op and the chief is guaranteed
    // to be in a consistent state.
    CHECK(accumulator_resource->is_stamp_valid(stamp_token));
    CHECK(stamp_token != next_stamp_token);
    SerializeTensorAccumulatorToOutput(*accumulator_resource, context);
    Tensor* num_updates_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("num_updates", TensorShape({}),
                                            &num_updates_t));
    num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
    accumulator_resource->Clear();
    accumulator_resource->set_stamp(next_stamp_token);
  }
};

REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorTensorFlush").Device(DEVICE_CPU),
                        StatsAccumulatorTensorFlushOp);

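// Deserialize: replaces the accumulator's contents with the serialized stats
// from the inputs, re-stamps it with the given stamp token, and restores
// num_updates.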
class StatsAccumulatorScalarDeserializeOp : public OpKernel {
 public:
  explicit StatsAccumulatorScalarDeserializeOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    StatsAccumulatorScalarResource* accumulator_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &accumulator_resource));
    mutex_lock l(*accumulator_resource->mutex());
    core::ScopedUnref unref_me(accumulator_resource);

    // Re-stamp the resource with the incoming stamp token.
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();
    accumulator_resource->Clear();
    accumulator_resource->set_stamp(stamp_token);
    AddToScalarAccumulator(accumulator_resource, context);
    const Tensor* num_updates_t;
    OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t));
    accumulator_resource->set_num_updates(num_updates_t->scalar<int64>()());
  }
};

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorScalarDeserialize").Device(DEVICE_CPU),
    StatsAccumulatorScalarDeserializeOp);

class StatsAccumulatorTensorDeserializeOp : public OpKernel {
 public:
  explicit StatsAccumulatorTensorDeserializeOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    StatsAccumulatorTensorResource* accumulator_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &accumulator_resource));
    mutex_lock l(*accumulator_resource->mutex());
    core::ScopedUnref unref_me(accumulator_resource);

    // Re-stamp the resource with the incoming stamp token.
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();
    accumulator_resource->Clear();
    accumulator_resource->set_stamp(stamp_token);
    AddToTensorAccumulator(accumulator_resource, context);
    const Tensor* num_updates_t;
    OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t));
    accumulator_resource->set_num_updates(num_updates_t->scalar<int64>()());
  }
};

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorTensorDeserialize").Device(DEVICE_CPU),
    StatsAccumulatorTensorDeserializeOp);

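// Serialize: emits the accumulator's current contents along with its stamp
// token and num_updates, leaving the resource unchanged.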
class StatsAccumulatorScalarSerializeOp : public OpKernel {
 public:
  explicit StatsAccumulatorScalarSerializeOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    StatsAccumulatorScalarResource* accumulator_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &accumulator_resource));
    mutex_lock l(*accumulator_resource->mutex());
    core::ScopedUnref unref_me(accumulator_resource);
    SerializeScalarAccumulatorToOutput(*accumulator_resource, context);
    Tensor* stamp_token_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("stamp_token", TensorShape({}),
                                            &stamp_token_t));
    stamp_token_t->scalar<int64>()() = accumulator_resource->stamp();

    Tensor* num_updates_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("num_updates", TensorShape({}),
                                            &num_updates_t));
    num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
  }
};

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorScalarSerialize").Device(DEVICE_CPU),
    StatsAccumulatorScalarSerializeOp);

class StatsAccumulatorTensorSerializeOp : public OpKernel {
 public:
  explicit StatsAccumulatorTensorSerializeOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    StatsAccumulatorTensorResource* accumulator_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &accumulator_resource));
    mutex_lock l(*accumulator_resource->mutex());
    core::ScopedUnref unref_me(accumulator_resource);
    SerializeTensorAccumulatorToOutput(*accumulator_resource, context);
    Tensor* stamp_token_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("stamp_token", TensorShape({}),
                                            &stamp_token_t));
    stamp_token_t->scalar<int64>()() = accumulator_resource->stamp();

    Tensor* num_updates_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("num_updates", TensorShape({}),
                                            &num_updates_t));
    num_updates_t->scalar<int64>()() = accumulator_resource->num_updates();
  }
};

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorTensorSerialize").Device(DEVICE_CPU),
    StatsAccumulatorTensorSerializeOp);

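// MakeSummary: stateless helper that merges a single batch of stats through
// a temporary accumulator (summing entries with duplicate keys) and emits
// the merged result.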
class StatsAccumulatorScalarMakeSummaryOp : public OpKernel {
 public:
  explicit StatsAccumulatorScalarMakeSummaryOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    TensorShape gradient_shape = TensorShape({});
    TensorShape hessian_shape = TensorShape({});
    StatsAccumulatorScalarResource* accumulator_resource =
        new StatsAccumulatorScalarResource(gradient_shape, hessian_shape);
    core::ScopedUnref unref_me(accumulator_resource);
    // Accumulate the input stats into the temporary resource and emit them.
    AddToScalarAccumulator(accumulator_resource, context);
    SerializeScalarAccumulatorToOutput(*accumulator_resource, context);
  }
};

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorScalarMakeSummary").Device(DEVICE_CPU),
    StatsAccumulatorScalarMakeSummaryOp);

class StatsAccumulatorTensorMakeSummaryOp : public OpKernel {
 public:
  explicit StatsAccumulatorTensorMakeSummaryOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* gradients_t;
    OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
    TensorShape gradients_shape = gradients_t->shape();
    gradients_shape.RemoveDim(0);

    const Tensor* hessians_t;
    OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
    TensorShape hessians_shape = hessians_t->shape();
    hessians_shape.RemoveDim(0);

    StatsAccumulatorTensorResource* accumulator_resource =
        new StatsAccumulatorTensorResource(gradients_shape, hessians_shape);
    core::ScopedUnref unref_me(accumulator_resource);
    // Accumulate the input stats into the temporary resource and emit them.
    AddToTensorAccumulator(accumulator_resource, context);
    SerializeTensorAccumulatorToOutput(*accumulator_resource, context);
  }
};

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorTensorMakeSummary").Device(DEVICE_CPU),
    StatsAccumulatorTensorMakeSummaryOp);

}  // namespace boosted_trees
}  // namespace tensorflow