1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // See docs in ../ops/sdca_ops.cc.
17
18 #define EIGEN_USE_THREADS
19
20 #include <stdint.h>
21 #include <atomic>
22 #include <limits>
23 #include <memory>
24 #include <new>
25 #include <string>
26 #include <vector>
27
28 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
29 #include "tensorflow/core/framework/device_base.h"
30 #include "tensorflow/core/framework/kernel_def_builder.h"
31 #include "tensorflow/core/framework/op.h"
32 #include "tensorflow/core/framework/op_def_builder.h"
33 #include "tensorflow/core/framework/op_kernel.h"
34 #include "tensorflow/core/framework/tensor.h"
35 #include "tensorflow/core/framework/tensor_shape.h"
36 #include "tensorflow/core/framework/tensor_types.h"
37 #include "tensorflow/core/framework/types.h"
38 #include "tensorflow/core/kernels/hinge-loss.h"
39 #include "tensorflow/core/kernels/logistic-loss.h"
40 #include "tensorflow/core/kernels/loss.h"
41 #include "tensorflow/core/kernels/poisson-loss.h"
42 #include "tensorflow/core/kernels/sdca_internal.h"
43 #include "tensorflow/core/kernels/smooth-hinge-loss.h"
44 #include "tensorflow/core/kernels/squared-loss.h"
45 #include "tensorflow/core/lib/core/coding.h"
46 #include "tensorflow/core/lib/core/errors.h"
47 #include "tensorflow/core/lib/core/status.h"
48 #include "tensorflow/core/lib/core/stringpiece.h"
49 #include "tensorflow/core/lib/gtl/inlined_vector.h"
50 #include "tensorflow/core/lib/strings/stringprintf.h"
51 #include "tensorflow/core/platform/fingerprint.h"
52 #include "tensorflow/core/platform/macros.h"
53 #include "tensorflow/core/platform/mutex.h"
54 #include "tensorflow/core/platform/types.h"
55 #include "tensorflow/core/util/work_sharder.h"
56
57 namespace tensorflow {
58
59 namespace {
60
61 using sdca::Example;
62 using sdca::Examples;
63 using sdca::ExampleStatistics;
64 using sdca::ModelWeights;
65 using sdca::Regularizations;
66
67 struct ComputeOptions {
ComputeOptionstensorflow::__anonc1cd44790111::ComputeOptions68 explicit ComputeOptions(OpKernelConstruction* const context) {
69 string loss_type;
70 OP_REQUIRES_OK(context, context->GetAttr("loss_type", &loss_type));
71 if (loss_type == "logistic_loss") {
72 loss_updater.reset(new LogisticLossUpdater);
73 } else if (loss_type == "squared_loss") {
74 loss_updater.reset(new SquaredLossUpdater);
75 } else if (loss_type == "hinge_loss") {
76 loss_updater.reset(new HingeLossUpdater);
77 } else if (loss_type == "smooth_hinge_loss") {
78 loss_updater.reset(new SmoothHingeLossUpdater);
79 } else if (loss_type == "poisson_loss") {
80 loss_updater.reset(new PoissonLossUpdater);
81 } else {
82 OP_REQUIRES(
83 context, false,
84 errors::InvalidArgument("Unsupported loss type: ", loss_type));
85 }
86 auto s = context->GetAttr("adaptative", &adaptive);
87 if (!s.ok()) {
88 s = context->GetAttr("adaptive", &adaptive);
89 }
90 OP_REQUIRES_OK(context, s);
91 OP_REQUIRES_OK(
92 context, context->GetAttr("num_sparse_features", &num_sparse_features));
93 OP_REQUIRES_OK(context, context->GetAttr("num_sparse_features_with_values",
94 &num_sparse_features_with_values));
95 OP_REQUIRES_OK(context,
96 context->GetAttr("num_dense_features", &num_dense_features));
97 OP_REQUIRES(
98 context, num_sparse_features + num_dense_features > 0,
99 errors::InvalidArgument("Requires at least one feature to train."));
100
101 OP_REQUIRES(context,
102 static_cast<int64>(num_sparse_features) +
103 static_cast<int64>(num_dense_features) <=
104 std::numeric_limits<int>::max(),
105 errors::InvalidArgument(
106 strings::Printf("Too many feature groups: %lld > %d",
107 static_cast<int64>(num_sparse_features) +
108 static_cast<int64>(num_dense_features),
109 std::numeric_limits<int>::max())));
110 OP_REQUIRES_OK(
111 context, context->GetAttr("num_loss_partitions", &num_loss_partitions));
112 OP_REQUIRES_OK(context, context->GetAttr("num_inner_iterations",
113 &num_inner_iterations));
114 OP_REQUIRES_OK(context, regularizations.Initialize(context));
115 }
116
117 std::unique_ptr<DualLossUpdater> loss_updater;
118 int num_sparse_features = 0;
119 int num_sparse_features_with_values = 0;
120 int num_dense_features = 0;
121 int num_inner_iterations = 0;
122 int num_loss_partitions = 0;
123 bool adaptive = true;
124 Regularizations regularizations;
125 };
126
127 // TODO(shengx): The helper classes/methods are changed to support multiclass
128 // SDCA, which lead to changes within this function. Need to revisit the
129 // convergence once the multiclass SDCA is in.
DoCompute(const ComputeOptions & options,OpKernelContext * const context)130 void DoCompute(const ComputeOptions& options, OpKernelContext* const context) {
131 ModelWeights model_weights;
132 OP_REQUIRES_OK(context, model_weights.Initialize(context));
133
134 Examples examples;
135 OP_REQUIRES_OK(
136 context,
137 examples.Initialize(context, model_weights, options.num_sparse_features,
138 options.num_sparse_features_with_values,
139 options.num_dense_features));
140
141 const Tensor* example_state_data_t;
142 OP_REQUIRES_OK(context,
143 context->input("example_state_data", &example_state_data_t));
144 TensorShape expected_example_state_shape({examples.num_examples(), 4});
145 OP_REQUIRES(context,
146 example_state_data_t->shape() == expected_example_state_shape,
147 errors::InvalidArgument(
148 "Expected shape ", expected_example_state_shape.DebugString(),
149 " for example_state_data, got ",
150 example_state_data_t->shape().DebugString()));
151
152 Tensor mutable_example_state_data_t(*example_state_data_t);
153 auto example_state_data = mutable_example_state_data_t.matrix<float>();
154 OP_REQUIRES_OK(context, context->set_output("out_example_state_data",
155 mutable_example_state_data_t));
156
157 if (options.adaptive) {
158 OP_REQUIRES_OK(context,
159 examples.SampleAdaptiveProbabilities(
160 options.num_loss_partitions, options.regularizations,
161 model_weights, example_state_data, options.loss_updater,
162 /*num_weight_vectors =*/1));
163 } else {
164 examples.RandomShuffle();
165 }
166 struct {
167 mutex mu;
168 Status value GUARDED_BY(mu);
169 } train_step_status;
170 std::atomic<std::int64_t> atomic_index(-1);
171 auto train_step = [&](const int64 begin, const int64 end) {
172 // The static_cast here is safe since begin and end can be at most
173 // num_examples which is an int.
174 for (int id = static_cast<int>(begin); id < end; ++id) {
175 const int64 example_index = examples.sampled_index(++atomic_index);
176 const Example& example = examples.example(example_index);
177 const float dual = example_state_data(example_index, 0);
178 const float example_weight = example.example_weight();
179 float example_label = example.example_label();
180 const Status conversion_status =
181 options.loss_updater->ConvertLabel(&example_label);
182 if (!conversion_status.ok()) {
183 mutex_lock l(train_step_status.mu);
184 train_step_status.value = conversion_status;
185 // Return from this worker thread - the calling thread is
186 // responsible for checking context status and returning on error.
187 return;
188 }
189
190 // Compute wx, example norm weighted by regularization, dual loss,
191 // primal loss.
192 // For binary SDCA, num_weight_vectors should be one.
193 const ExampleStatistics example_statistics =
194 example.ComputeWxAndWeightedExampleNorm(
195 options.num_loss_partitions, model_weights,
196 options.regularizations, 1 /* num_weight_vectors */);
197
198 const double new_dual = options.loss_updater->ComputeUpdatedDual(
199 options.num_loss_partitions, example_label, example_weight, dual,
200 example_statistics.wx[0], example_statistics.normalized_squared_norm);
201
202 // Compute new weights.
203 const double normalized_bounded_dual_delta =
204 (new_dual - dual) * example_weight /
205 options.regularizations.symmetric_l2();
206 model_weights.UpdateDeltaWeights(
207 context->eigen_cpu_device(), example,
208 std::vector<double>{normalized_bounded_dual_delta});
209
210 // Update example data.
211 example_state_data(example_index, 0) = new_dual;
212 example_state_data(example_index, 1) =
213 options.loss_updater->ComputePrimalLoss(
214 example_statistics.prev_wx[0], example_label, example_weight);
215 example_state_data(example_index, 2) =
216 options.loss_updater->ComputeDualLoss(dual, example_label,
217 example_weight);
218 example_state_data(example_index, 3) = example_weight;
219 }
220 };
221 // TODO(sibyl-Aix6ihai): Tune this properly based on sparsity of the data,
222 // number of cpus, and cost per example.
223 const int64 kCostPerUnit = examples.num_features();
224 const DeviceBase::CpuWorkerThreads& worker_threads =
225 *context->device()->tensorflow_cpu_worker_threads();
226
227 Shard(worker_threads.num_threads, worker_threads.workers,
228 examples.num_examples(), kCostPerUnit, train_step);
229 mutex_lock l(train_step_status.mu);
230 OP_REQUIRES_OK(context, train_step_status.value);
231 }
232
233 } // namespace
234
235 class SdcaOptimizer : public OpKernel {
236 public:
SdcaOptimizer(OpKernelConstruction * const context)237 explicit SdcaOptimizer(OpKernelConstruction* const context)
238 : OpKernel(context), options_(context) {}
239
Compute(OpKernelContext * context)240 void Compute(OpKernelContext* context) override {
241 DoCompute(options_, context);
242 }
243
244 private:
245 // TODO(sibyl-Aix6ihai): We could use the type-constraint on loss_type, and
246 // template the entire class to avoid the virtual table lookup penalty in
247 // the inner loop.
248 ComputeOptions options_;
249 };
250 REGISTER_KERNEL_BUILDER(Name("SdcaOptimizer").Device(DEVICE_CPU),
251 SdcaOptimizer);
252 REGISTER_KERNEL_BUILDER(Name("SdcaOptimizerV2").Device(DEVICE_CPU),
253 SdcaOptimizer);
254
255 class SdcaShrinkL1 : public OpKernel {
256 public:
SdcaShrinkL1(OpKernelConstruction * const context)257 explicit SdcaShrinkL1(OpKernelConstruction* const context)
258 : OpKernel(context) {
259 OP_REQUIRES_OK(context, regularizations_.Initialize(context));
260 }
261
Compute(OpKernelContext * context)262 void Compute(OpKernelContext* context) override {
263 OpMutableInputList weights_inputs;
264 OP_REQUIRES_OK(context,
265 context->mutable_input_list("weights", &weights_inputs));
266
267 auto do_work = [&](const int64 begin, const int64 end) {
268 for (int i = begin; i < end; ++i) {
269 auto prox_w = weights_inputs.at(i, /*lock_held=*/true).flat<float>();
270 prox_w.device(context->eigen_cpu_device()) =
271 regularizations_.EigenShrinkVector(prox_w);
272 }
273 };
274
275 if (weights_inputs.size() > 0) {
276 int64 num_weights = 0;
277 for (int i = 0; i < weights_inputs.size(); ++i) {
278 num_weights += weights_inputs.at(i, /*lock_held=*/true).NumElements();
279 }
280 // TODO(sibyl-Aix6ihai): Tune this value.
281 const int64 kCostPerUnit = (num_weights * 50) / weights_inputs.size();
282 const DeviceBase::CpuWorkerThreads& worker_threads =
283 *context->device()->tensorflow_cpu_worker_threads();
284 Shard(worker_threads.num_threads, worker_threads.workers,
285 weights_inputs.size(), kCostPerUnit, do_work);
286 }
287 }
288
289 private:
290 Regularizations regularizations_;
291 };
292 REGISTER_KERNEL_BUILDER(Name("SdcaShrinkL1").Device(DEVICE_CPU), SdcaShrinkL1);
293
294 // Computes platform independent, compact and unique (with very high
295 // probability) representation of an example id. It shouldn't be put in
296 // persistent storage, as its implementation may change in the future.
297 //
298 // The current probability of at least one collision for 1B example_ids is
299 // approximately 10^-21 (ie 2^60 / 2^129).
300 class SdcaFprint : public OpKernel {
301 public:
SdcaFprint(OpKernelConstruction * const context)302 explicit SdcaFprint(OpKernelConstruction* const context)
303 : OpKernel(context) {}
304
Compute(OpKernelContext * context)305 void Compute(OpKernelContext* context) override {
306 const Tensor& input = context->input(0);
307 OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
308 errors::InvalidArgument("Input must be a vector, got shape ",
309 input.shape().DebugString()));
310 Tensor* out;
311 const int64 num_elements = input.NumElements();
312 OP_REQUIRES_OK(context, context->allocate_output(
313 0, TensorShape({num_elements, 2}), &out));
314
315 const auto in_values = input.flat<string>();
316 auto out_values = out->matrix<int64>();
317
318 for (int64 i = 0; i < num_elements; ++i) {
319 const Fprint128 fprint = Fingerprint128(in_values(i));
320 // Never return 0 or 1 as the first value of the hash to allow these to
321 // safely be used as sentinel values (e.g. dense hash table empty key).
322 out_values(i, 0) = TF_PREDICT_TRUE(fprint.low64 >= 2)
323 ? fprint.low64
324 : fprint.low64 + ~static_cast<uint64>(1);
325 out_values(i, 1) = fprint.high64;
326 }
327 }
328 };
329 REGISTER_KERNEL_BUILDER(Name("SdcaFprint").Device(DEVICE_CPU), SdcaFprint);
330
331 } // namespace tensorflow
332