• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/sdca_ops.cc.
17 
18 #define EIGEN_USE_THREADS
19 
20 #include <stdint.h>
21 #include <atomic>
22 #include <limits>
23 #include <memory>
24 #include <new>
25 #include <string>
26 #include <vector>
27 
28 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
29 #include "tensorflow/core/framework/device_base.h"
30 #include "tensorflow/core/framework/kernel_def_builder.h"
31 #include "tensorflow/core/framework/op.h"
32 #include "tensorflow/core/framework/op_def_builder.h"
33 #include "tensorflow/core/framework/op_kernel.h"
34 #include "tensorflow/core/framework/tensor.h"
35 #include "tensorflow/core/framework/tensor_shape.h"
36 #include "tensorflow/core/framework/tensor_types.h"
37 #include "tensorflow/core/framework/types.h"
38 #include "tensorflow/core/kernels/hinge-loss.h"
39 #include "tensorflow/core/kernels/logistic-loss.h"
40 #include "tensorflow/core/kernels/loss.h"
41 #include "tensorflow/core/kernels/poisson-loss.h"
42 #include "tensorflow/core/kernels/sdca_internal.h"
43 #include "tensorflow/core/kernels/smooth-hinge-loss.h"
44 #include "tensorflow/core/kernels/squared-loss.h"
45 #include "tensorflow/core/lib/core/coding.h"
46 #include "tensorflow/core/lib/core/errors.h"
47 #include "tensorflow/core/lib/core/status.h"
48 #include "tensorflow/core/lib/core/stringpiece.h"
49 #include "tensorflow/core/lib/gtl/inlined_vector.h"
50 #include "tensorflow/core/lib/strings/stringprintf.h"
51 #include "tensorflow/core/platform/fingerprint.h"
52 #include "tensorflow/core/platform/macros.h"
53 #include "tensorflow/core/platform/mutex.h"
54 #include "tensorflow/core/platform/types.h"
55 #include "tensorflow/core/util/work_sharder.h"
56 
57 namespace tensorflow {
58 
59 namespace {
60 
61 using sdca::Example;
62 using sdca::Examples;
63 using sdca::ExampleStatistics;
64 using sdca::ModelWeights;
65 using sdca::Regularizations;
66 
67 struct ComputeOptions {
ComputeOptionstensorflow::__anonc1cd44790111::ComputeOptions68   explicit ComputeOptions(OpKernelConstruction* const context) {
69     string loss_type;
70     OP_REQUIRES_OK(context, context->GetAttr("loss_type", &loss_type));
71     if (loss_type == "logistic_loss") {
72       loss_updater.reset(new LogisticLossUpdater);
73     } else if (loss_type == "squared_loss") {
74       loss_updater.reset(new SquaredLossUpdater);
75     } else if (loss_type == "hinge_loss") {
76       loss_updater.reset(new HingeLossUpdater);
77     } else if (loss_type == "smooth_hinge_loss") {
78       loss_updater.reset(new SmoothHingeLossUpdater);
79     } else if (loss_type == "poisson_loss") {
80       loss_updater.reset(new PoissonLossUpdater);
81     } else {
82       OP_REQUIRES(
83           context, false,
84           errors::InvalidArgument("Unsupported loss type: ", loss_type));
85     }
86     auto s = context->GetAttr("adaptative", &adaptive);
87     if (!s.ok()) {
88       s = context->GetAttr("adaptive", &adaptive);
89     }
90     OP_REQUIRES_OK(context, s);
91     OP_REQUIRES_OK(
92         context, context->GetAttr("num_sparse_features", &num_sparse_features));
93     OP_REQUIRES_OK(context, context->GetAttr("num_sparse_features_with_values",
94                                              &num_sparse_features_with_values));
95     OP_REQUIRES_OK(context,
96                    context->GetAttr("num_dense_features", &num_dense_features));
97     OP_REQUIRES(
98         context, num_sparse_features + num_dense_features > 0,
99         errors::InvalidArgument("Requires at least one feature to train."));
100 
101     OP_REQUIRES(context,
102                 static_cast<int64>(num_sparse_features) +
103                         static_cast<int64>(num_dense_features) <=
104                     std::numeric_limits<int>::max(),
105                 errors::InvalidArgument(
106                     strings::Printf("Too many feature groups: %lld > %d",
107                                     static_cast<int64>(num_sparse_features) +
108                                         static_cast<int64>(num_dense_features),
109                                     std::numeric_limits<int>::max())));
110     OP_REQUIRES_OK(
111         context, context->GetAttr("num_loss_partitions", &num_loss_partitions));
112     OP_REQUIRES_OK(context, context->GetAttr("num_inner_iterations",
113                                              &num_inner_iterations));
114     OP_REQUIRES_OK(context, regularizations.Initialize(context));
115   }
116 
117   std::unique_ptr<DualLossUpdater> loss_updater;
118   int num_sparse_features = 0;
119   int num_sparse_features_with_values = 0;
120   int num_dense_features = 0;
121   int num_inner_iterations = 0;
122   int num_loss_partitions = 0;
123   bool adaptive = true;
124   Regularizations regularizations;
125 };
126 
127 // TODO(shengx): The helper classes/methods are changed to support multiclass
128 // SDCA, which lead to changes within this function. Need to revisit the
129 // convergence once the multiclass SDCA is in.
DoCompute(const ComputeOptions & options,OpKernelContext * const context)130 void DoCompute(const ComputeOptions& options, OpKernelContext* const context) {
131   ModelWeights model_weights;
132   OP_REQUIRES_OK(context, model_weights.Initialize(context));
133 
134   Examples examples;
135   OP_REQUIRES_OK(
136       context,
137       examples.Initialize(context, model_weights, options.num_sparse_features,
138                           options.num_sparse_features_with_values,
139                           options.num_dense_features));
140 
141   const Tensor* example_state_data_t;
142   OP_REQUIRES_OK(context,
143                  context->input("example_state_data", &example_state_data_t));
144   TensorShape expected_example_state_shape({examples.num_examples(), 4});
145   OP_REQUIRES(context,
146               example_state_data_t->shape() == expected_example_state_shape,
147               errors::InvalidArgument(
148                   "Expected shape ", expected_example_state_shape.DebugString(),
149                   " for example_state_data, got ",
150                   example_state_data_t->shape().DebugString()));
151 
152   Tensor mutable_example_state_data_t(*example_state_data_t);
153   auto example_state_data = mutable_example_state_data_t.matrix<float>();
154   OP_REQUIRES_OK(context, context->set_output("out_example_state_data",
155                                               mutable_example_state_data_t));
156 
157   if (options.adaptive) {
158     OP_REQUIRES_OK(context,
159                    examples.SampleAdaptiveProbabilities(
160                        options.num_loss_partitions, options.regularizations,
161                        model_weights, example_state_data, options.loss_updater,
162                        /*num_weight_vectors =*/1));
163   } else {
164     examples.RandomShuffle();
165   }
166   struct {
167     mutex mu;
168     Status value GUARDED_BY(mu);
169   } train_step_status;
170   std::atomic<std::int64_t> atomic_index(-1);
171   auto train_step = [&](const int64 begin, const int64 end) {
172     // The static_cast here is safe since begin and end can be at most
173     // num_examples which is an int.
174     for (int id = static_cast<int>(begin); id < end; ++id) {
175       const int64 example_index = examples.sampled_index(++atomic_index);
176       const Example& example = examples.example(example_index);
177       const float dual = example_state_data(example_index, 0);
178       const float example_weight = example.example_weight();
179       float example_label = example.example_label();
180       const Status conversion_status =
181           options.loss_updater->ConvertLabel(&example_label);
182       if (!conversion_status.ok()) {
183         mutex_lock l(train_step_status.mu);
184         train_step_status.value = conversion_status;
185         // Return from this worker thread - the calling thread is
186         // responsible for checking context status and returning on error.
187         return;
188       }
189 
190       // Compute wx, example norm weighted by regularization, dual loss,
191       // primal loss.
192       // For binary SDCA, num_weight_vectors should be one.
193       const ExampleStatistics example_statistics =
194           example.ComputeWxAndWeightedExampleNorm(
195               options.num_loss_partitions, model_weights,
196               options.regularizations, 1 /* num_weight_vectors */);
197 
198       const double new_dual = options.loss_updater->ComputeUpdatedDual(
199           options.num_loss_partitions, example_label, example_weight, dual,
200           example_statistics.wx[0], example_statistics.normalized_squared_norm);
201 
202       // Compute new weights.
203       const double normalized_bounded_dual_delta =
204           (new_dual - dual) * example_weight /
205           options.regularizations.symmetric_l2();
206       model_weights.UpdateDeltaWeights(
207           context->eigen_cpu_device(), example,
208           std::vector<double>{normalized_bounded_dual_delta});
209 
210       // Update example data.
211       example_state_data(example_index, 0) = new_dual;
212       example_state_data(example_index, 1) =
213           options.loss_updater->ComputePrimalLoss(
214               example_statistics.prev_wx[0], example_label, example_weight);
215       example_state_data(example_index, 2) =
216           options.loss_updater->ComputeDualLoss(dual, example_label,
217                                                 example_weight);
218       example_state_data(example_index, 3) = example_weight;
219     }
220   };
221   // TODO(sibyl-Aix6ihai): Tune this properly based on sparsity of the data,
222   // number of cpus, and cost per example.
223   const int64 kCostPerUnit = examples.num_features();
224   const DeviceBase::CpuWorkerThreads& worker_threads =
225       *context->device()->tensorflow_cpu_worker_threads();
226 
227   Shard(worker_threads.num_threads, worker_threads.workers,
228         examples.num_examples(), kCostPerUnit, train_step);
229   mutex_lock l(train_step_status.mu);
230   OP_REQUIRES_OK(context, train_step_status.value);
231 }
232 
233 }  // namespace
234 
235 class SdcaOptimizer : public OpKernel {
236  public:
SdcaOptimizer(OpKernelConstruction * const context)237   explicit SdcaOptimizer(OpKernelConstruction* const context)
238       : OpKernel(context), options_(context) {}
239 
Compute(OpKernelContext * context)240   void Compute(OpKernelContext* context) override {
241     DoCompute(options_, context);
242   }
243 
244  private:
245   // TODO(sibyl-Aix6ihai): We could use the type-constraint on loss_type, and
246   // template the entire class to avoid the virtual table lookup penalty in
247   // the inner loop.
248   ComputeOptions options_;
249 };
250 REGISTER_KERNEL_BUILDER(Name("SdcaOptimizer").Device(DEVICE_CPU),
251                         SdcaOptimizer);
252 REGISTER_KERNEL_BUILDER(Name("SdcaOptimizerV2").Device(DEVICE_CPU),
253                         SdcaOptimizer);
254 
255 class SdcaShrinkL1 : public OpKernel {
256  public:
SdcaShrinkL1(OpKernelConstruction * const context)257   explicit SdcaShrinkL1(OpKernelConstruction* const context)
258       : OpKernel(context) {
259     OP_REQUIRES_OK(context, regularizations_.Initialize(context));
260   }
261 
Compute(OpKernelContext * context)262   void Compute(OpKernelContext* context) override {
263     OpMutableInputList weights_inputs;
264     OP_REQUIRES_OK(context,
265                    context->mutable_input_list("weights", &weights_inputs));
266 
267     auto do_work = [&](const int64 begin, const int64 end) {
268       for (int i = begin; i < end; ++i) {
269         auto prox_w = weights_inputs.at(i, /*lock_held=*/true).flat<float>();
270         prox_w.device(context->eigen_cpu_device()) =
271             regularizations_.EigenShrinkVector(prox_w);
272       }
273     };
274 
275     if (weights_inputs.size() > 0) {
276       int64 num_weights = 0;
277       for (int i = 0; i < weights_inputs.size(); ++i) {
278         num_weights += weights_inputs.at(i, /*lock_held=*/true).NumElements();
279       }
280       // TODO(sibyl-Aix6ihai): Tune this value.
281       const int64 kCostPerUnit = (num_weights * 50) / weights_inputs.size();
282       const DeviceBase::CpuWorkerThreads& worker_threads =
283           *context->device()->tensorflow_cpu_worker_threads();
284       Shard(worker_threads.num_threads, worker_threads.workers,
285             weights_inputs.size(), kCostPerUnit, do_work);
286     }
287   }
288 
289  private:
290   Regularizations regularizations_;
291 };
292 REGISTER_KERNEL_BUILDER(Name("SdcaShrinkL1").Device(DEVICE_CPU), SdcaShrinkL1);
293 
294 // Computes platform independent, compact and unique (with very high
295 // probability) representation of an example id. It shouldn't be put in
296 // persistent storage, as its implementation may change in the future.
297 //
298 // The current probability of at least one collision for 1B example_ids is
299 // approximately 10^-21 (ie 2^60 / 2^129).
300 class SdcaFprint : public OpKernel {
301  public:
SdcaFprint(OpKernelConstruction * const context)302   explicit SdcaFprint(OpKernelConstruction* const context)
303       : OpKernel(context) {}
304 
Compute(OpKernelContext * context)305   void Compute(OpKernelContext* context) override {
306     const Tensor& input = context->input(0);
307     OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
308                 errors::InvalidArgument("Input must be a vector, got shape ",
309                                         input.shape().DebugString()));
310     Tensor* out;
311     const int64 num_elements = input.NumElements();
312     OP_REQUIRES_OK(context, context->allocate_output(
313                                 0, TensorShape({num_elements, 2}), &out));
314 
315     const auto in_values = input.flat<string>();
316     auto out_values = out->matrix<int64>();
317 
318     for (int64 i = 0; i < num_elements; ++i) {
319       const Fprint128 fprint = Fingerprint128(in_values(i));
320       // Never return 0 or 1 as the first value of the hash to allow these to
321       // safely be used as sentinel values (e.g. dense hash table empty key).
322       out_values(i, 0) = TF_PREDICT_TRUE(fprint.low64 >= 2)
323                              ? fprint.low64
324                              : fprint.low64 + ~static_cast<uint64>(1);
325       out_values(i, 1) = fprint.high64;
326     }
327   }
328 };
329 REGISTER_KERNEL_BUILDER(Name("SdcaFprint").Device(DEVICE_CPU), SdcaFprint);
330 
331 }  // namespace tensorflow
332