• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/sdca_ops.cc.
17 
18 #define EIGEN_USE_THREADS
19 
20 #include <stdint.h>
21 
22 #include <atomic>
23 #include <limits>
24 #include <memory>
25 #include <new>
26 #include <string>
27 #include <vector>
28 
29 #include "absl/strings/str_format.h"
30 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
31 #include "tensorflow/core/framework/device_base.h"
32 #include "tensorflow/core/framework/kernel_def_builder.h"
33 #include "tensorflow/core/framework/op.h"
34 #include "tensorflow/core/framework/op_def_builder.h"
35 #include "tensorflow/core/framework/op_kernel.h"
36 #include "tensorflow/core/framework/tensor.h"
37 #include "tensorflow/core/framework/tensor_shape.h"
38 #include "tensorflow/core/framework/tensor_types.h"
39 #include "tensorflow/core/framework/types.h"
40 #include "tensorflow/core/kernels/hinge-loss.h"
41 #include "tensorflow/core/kernels/logistic-loss.h"
42 #include "tensorflow/core/kernels/loss.h"
43 #include "tensorflow/core/kernels/poisson-loss.h"
44 #include "tensorflow/core/kernels/sdca_internal.h"
45 #include "tensorflow/core/kernels/smooth-hinge-loss.h"
46 #include "tensorflow/core/kernels/squared-loss.h"
47 #include "tensorflow/core/lib/core/coding.h"
48 #include "tensorflow/core/lib/core/errors.h"
49 #include "tensorflow/core/lib/core/status.h"
50 #include "tensorflow/core/lib/core/stringpiece.h"
51 #include "tensorflow/core/lib/gtl/inlined_vector.h"
52 #include "tensorflow/core/platform/fingerprint.h"
53 #include "tensorflow/core/platform/macros.h"
54 #include "tensorflow/core/platform/mutex.h"
55 #include "tensorflow/core/platform/types.h"
56 #include "tensorflow/core/util/work_sharder.h"
57 
58 namespace tensorflow {
59 
60 namespace {
61 
62 using sdca::Example;
63 using sdca::Examples;
64 using sdca::ExampleStatistics;
65 using sdca::ModelWeights;
66 using sdca::Regularizations;
67 
68 struct ComputeOptions {
ComputeOptionstensorflow::__anon9fd918570111::ComputeOptions69   explicit ComputeOptions(OpKernelConstruction* const context) {
70     string loss_type;
71     OP_REQUIRES_OK(context, context->GetAttr("loss_type", &loss_type));
72     if (loss_type == "logistic_loss") {
73       loss_updater.reset(new LogisticLossUpdater);
74     } else if (loss_type == "squared_loss") {
75       loss_updater.reset(new SquaredLossUpdater);
76     } else if (loss_type == "hinge_loss") {
77       loss_updater.reset(new HingeLossUpdater);
78     } else if (loss_type == "smooth_hinge_loss") {
79       loss_updater.reset(new SmoothHingeLossUpdater);
80     } else if (loss_type == "poisson_loss") {
81       loss_updater.reset(new PoissonLossUpdater);
82     } else {
83       OP_REQUIRES(
84           context, false,
85           errors::InvalidArgument("Unsupported loss type: ", loss_type));
86     }
87     auto s = context->GetAttr("adaptative", &adaptive);
88     if (!s.ok()) {
89       s = context->GetAttr("adaptive", &adaptive);
90     }
91     OP_REQUIRES_OK(context, s);
92     OP_REQUIRES_OK(
93         context, context->GetAttr("num_sparse_features", &num_sparse_features));
94     OP_REQUIRES_OK(context, context->GetAttr("num_sparse_features_with_values",
95                                              &num_sparse_features_with_values));
96     OP_REQUIRES_OK(context,
97                    context->GetAttr("num_dense_features", &num_dense_features));
98     OP_REQUIRES(
99         context, num_sparse_features + num_dense_features > 0,
100         errors::InvalidArgument("Requires at least one feature to train."));
101 
102     OP_REQUIRES(context,
103                 static_cast<int64>(num_sparse_features) +
104                         static_cast<int64>(num_dense_features) <=
105                     std::numeric_limits<int>::max(),
106                 errors::InvalidArgument(
107                     absl::StrFormat("Too many feature groups: %d > %d",
108                                     static_cast<int64>(num_sparse_features) +
109                                         static_cast<int64>(num_dense_features),
110                                     std::numeric_limits<int>::max())));
111     OP_REQUIRES_OK(
112         context, context->GetAttr("num_loss_partitions", &num_loss_partitions));
113     OP_REQUIRES_OK(context, context->GetAttr("num_inner_iterations",
114                                              &num_inner_iterations));
115     OP_REQUIRES_OK(context, regularizations.Initialize(context));
116   }
117 
118   std::unique_ptr<DualLossUpdater> loss_updater;
119   int num_sparse_features = 0;
120   int num_sparse_features_with_values = 0;
121   int num_dense_features = 0;
122   int num_inner_iterations = 0;
123   int num_loss_partitions = 0;
124   bool adaptive = true;
125   Regularizations regularizations;
126 };
127 
128 // TODO(shengx): The helper classes/methods are changed to support multiclass
129 // SDCA, which lead to changes within this function. Need to revisit the
130 // convergence once the multiclass SDCA is in.
DoCompute(const ComputeOptions & options,OpKernelContext * const context)131 void DoCompute(const ComputeOptions& options, OpKernelContext* const context) {
132   ModelWeights model_weights;
133   OP_REQUIRES_OK(context, model_weights.Initialize(context));
134 
135   Examples examples;
136   OP_REQUIRES_OK(
137       context,
138       examples.Initialize(context, model_weights, options.num_sparse_features,
139                           options.num_sparse_features_with_values,
140                           options.num_dense_features));
141 
142   const Tensor* example_state_data_t;
143   OP_REQUIRES_OK(context,
144                  context->input("example_state_data", &example_state_data_t));
145   TensorShape expected_example_state_shape({examples.num_examples(), 4});
146   OP_REQUIRES(context,
147               example_state_data_t->shape() == expected_example_state_shape,
148               errors::InvalidArgument(
149                   "Expected shape ", expected_example_state_shape.DebugString(),
150                   " for example_state_data, got ",
151                   example_state_data_t->shape().DebugString()));
152 
153   Tensor mutable_example_state_data_t(*example_state_data_t);
154   auto example_state_data = mutable_example_state_data_t.matrix<float>();
155   OP_REQUIRES_OK(context, context->set_output("out_example_state_data",
156                                               mutable_example_state_data_t));
157 
158   if (options.adaptive) {
159     OP_REQUIRES_OK(context,
160                    examples.SampleAdaptiveProbabilities(
161                        options.num_loss_partitions, options.regularizations,
162                        model_weights, example_state_data, options.loss_updater,
163                        /*num_weight_vectors =*/1));
164   } else {
165     examples.RandomShuffle();
166   }
167   struct {
168     mutex mu;
169     Status value TF_GUARDED_BY(mu);
170   } train_step_status;
171   std::atomic<std::int64_t> atomic_index(-1);
172   auto train_step = [&](const int64 begin, const int64 end) {
173     // The static_cast here is safe since begin and end can be at most
174     // num_examples which is an int.
175     for (int id = static_cast<int>(begin); id < end; ++id) {
176       const int64 example_index = examples.sampled_index(++atomic_index);
177       const Example& example = examples.example(example_index);
178       const float dual = example_state_data(example_index, 0);
179       const float example_weight = example.example_weight();
180       float example_label = example.example_label();
181       const Status conversion_status =
182           options.loss_updater->ConvertLabel(&example_label);
183       if (!conversion_status.ok()) {
184         mutex_lock l(train_step_status.mu);
185         train_step_status.value = conversion_status;
186         // Return from this worker thread - the calling thread is
187         // responsible for checking context status and returning on error.
188         return;
189       }
190 
191       // Compute wx, example norm weighted by regularization, dual loss,
192       // primal loss.
193       // For binary SDCA, num_weight_vectors should be one.
194       const ExampleStatistics example_statistics =
195           example.ComputeWxAndWeightedExampleNorm(
196               options.num_loss_partitions, model_weights,
197               options.regularizations, 1 /* num_weight_vectors */);
198 
199       const double new_dual = options.loss_updater->ComputeUpdatedDual(
200           options.num_loss_partitions, example_label, example_weight, dual,
201           example_statistics.wx[0], example_statistics.normalized_squared_norm);
202 
203       // Compute new weights.
204       const double normalized_bounded_dual_delta =
205           (new_dual - dual) * example_weight /
206           options.regularizations.symmetric_l2();
207       model_weights.UpdateDeltaWeights(
208           context->eigen_cpu_device(), example,
209           std::vector<double>{normalized_bounded_dual_delta});
210 
211       // Update example data.
212       example_state_data(example_index, 0) = new_dual;
213       example_state_data(example_index, 1) =
214           options.loss_updater->ComputePrimalLoss(
215               example_statistics.prev_wx[0], example_label, example_weight);
216       example_state_data(example_index, 2) =
217           options.loss_updater->ComputeDualLoss(dual, example_label,
218                                                 example_weight);
219       example_state_data(example_index, 3) = example_weight;
220     }
221   };
222   // TODO(sibyl-Aix6ihai): Tune this properly based on sparsity of the data,
223   // number of cpus, and cost per example.
224   const int64 kCostPerUnit = examples.num_features();
225   const DeviceBase::CpuWorkerThreads& worker_threads =
226       *context->device()->tensorflow_cpu_worker_threads();
227 
228   Shard(worker_threads.num_threads, worker_threads.workers,
229         examples.num_examples(), kCostPerUnit, train_step);
230   mutex_lock l(train_step_status.mu);
231   OP_REQUIRES_OK(context, train_step_status.value);
232 }
233 
234 }  // namespace
235 
236 class SdcaOptimizer : public OpKernel {
237  public:
SdcaOptimizer(OpKernelConstruction * const context)238   explicit SdcaOptimizer(OpKernelConstruction* const context)
239       : OpKernel(context), options_(context) {}
240 
Compute(OpKernelContext * context)241   void Compute(OpKernelContext* context) override {
242     DoCompute(options_, context);
243   }
244 
245  private:
246   // TODO(sibyl-Aix6ihai): We could use the type-constraint on loss_type, and
247   // template the entire class to avoid the virtual table lookup penalty in
248   // the inner loop.
249   ComputeOptions options_;
250 };
251 REGISTER_KERNEL_BUILDER(Name("SdcaOptimizer").Device(DEVICE_CPU),
252                         SdcaOptimizer);
253 REGISTER_KERNEL_BUILDER(Name("SdcaOptimizerV2").Device(DEVICE_CPU),
254                         SdcaOptimizer);
255 
256 class SdcaShrinkL1 : public OpKernel {
257  public:
SdcaShrinkL1(OpKernelConstruction * const context)258   explicit SdcaShrinkL1(OpKernelConstruction* const context)
259       : OpKernel(context) {
260     OP_REQUIRES_OK(context, regularizations_.Initialize(context));
261   }
262 
Compute(OpKernelContext * context)263   void Compute(OpKernelContext* context) override {
264     OpMutableInputList weights_inputs;
265     OP_REQUIRES_OK(context,
266                    context->mutable_input_list("weights", &weights_inputs));
267 
268     auto do_work = [&](const int64 begin, const int64 end) {
269       for (int i = begin; i < end; ++i) {
270         auto prox_w = weights_inputs.at(i, /*lock_held=*/true).flat<float>();
271         prox_w.device(context->eigen_cpu_device()) =
272             regularizations_.EigenShrinkVector(prox_w);
273       }
274     };
275 
276     if (weights_inputs.size() > 0) {
277       int64 num_weights = 0;
278       for (int i = 0; i < weights_inputs.size(); ++i) {
279         num_weights += weights_inputs.at(i, /*lock_held=*/true).NumElements();
280       }
281       // TODO(sibyl-Aix6ihai): Tune this value.
282       const int64 kCostPerUnit = (num_weights * 50) / weights_inputs.size();
283       const DeviceBase::CpuWorkerThreads& worker_threads =
284           *context->device()->tensorflow_cpu_worker_threads();
285       Shard(worker_threads.num_threads, worker_threads.workers,
286             weights_inputs.size(), kCostPerUnit, do_work);
287     }
288   }
289 
290  private:
291   Regularizations regularizations_;
292 };
293 REGISTER_KERNEL_BUILDER(Name("SdcaShrinkL1").Device(DEVICE_CPU), SdcaShrinkL1);
294 
295 // Computes platform independent, compact and unique (with very high
296 // probability) representation of an example id. It shouldn't be put in
297 // persistent storage, as its implementation may change in the future.
298 //
299 // The current probability of at least one collision for 1B example_ids is
300 // approximately 10^-21 (ie 2^60 / 2^129).
301 class SdcaFprint : public OpKernel {
302  public:
SdcaFprint(OpKernelConstruction * const context)303   explicit SdcaFprint(OpKernelConstruction* const context)
304       : OpKernel(context) {}
305 
Compute(OpKernelContext * context)306   void Compute(OpKernelContext* context) override {
307     const Tensor& input = context->input(0);
308     OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
309                 errors::InvalidArgument("Input must be a vector, got shape ",
310                                         input.shape().DebugString()));
311     Tensor* out;
312     const int64 num_elements = input.NumElements();
313     OP_REQUIRES_OK(context, context->allocate_output(
314                                 0, TensorShape({num_elements, 2}), &out));
315 
316     const auto in_values = input.flat<tstring>();
317     auto out_values = out->matrix<int64>();
318 
319     for (int64 i = 0; i < num_elements; ++i) {
320       const Fprint128 fprint = Fingerprint128(in_values(i));
321       // Never return 0 or 1 as the first value of the hash to allow these to
322       // safely be used as sentinel values (e.g. dense hash table empty key).
323       out_values(i, 0) = TF_PREDICT_TRUE(fprint.low64 >= 2)
324                              ? fprint.low64
325                              : fprint.low64 + ~static_cast<uint64>(1);
326       out_values(i, 1) = fprint.high64;
327     }
328   }
329 };
330 REGISTER_KERNEL_BUILDER(Name("SdcaFprint").Device(DEVICE_CPU), SdcaFprint);
331 
332 }  // namespace tensorflow
333