• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_HINGE_LOSS_H_
17 #define TENSORFLOW_CORE_KERNELS_HINGE_LOSS_H_
18 
19 #include <algorithm>
20 #include <limits>
21 
22 #include "tensorflow/core/kernels/loss.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/status.h"
25 
26 namespace tensorflow {
27 
28 class HingeLossUpdater : public DualLossUpdater {
29  public:
30   // Computes the updated dual variable (corresponding) to a single example. The
31   // updated dual value maximizes the objective function of the dual
32   // optimization problem associated with hinge loss (conditioned on keeping the
33   // rest of the dual variables intact). The method below finds an optimal delta
34   // (difference between updated and previous dual value) using the update rule
35   // within SDCA procedure (see http://arxiv.org/pdf/1209.1873v2.pdf, page 5)
36   // and the particular form of conjugate function for hinge loss.
37   //
38   // The CoCoA+ modification is detailed in readme.md.
39   //
40   // TODO(sibyl-vie3Poto): Write up a doc with concrete derivation and point to it from
41   // here.
ComputeUpdatedDual(const int num_loss_partitions,const double label,const double example_weight,const double current_dual,const double wx,const double weighted_example_norm)42   double ComputeUpdatedDual(const int num_loss_partitions, const double label,
43                             const double example_weight,
44                             const double current_dual, const double wx,
45                             const double weighted_example_norm) const final {
46     // Intutitvely there are 3 cases:
47     // a. new optimal value of the dual variable falls within the admissible
48     // range [0, 1]. In this case we set new dual to this value.
49     // b. new optimal value is < 0. Then, because of convexity, the optimal
50     // valid value for new dual = 0
51     // c. new optimal value > 1.0. Then new optimal value should be set to 1.0.
52     const double candidate_optimal_dual =
53         current_dual + (label - wx) / (num_loss_partitions * example_weight *
54                                        weighted_example_norm);
55     if (label * candidate_optimal_dual < 0) {
56       return 0.0;
57     }
58     if (label * candidate_optimal_dual > 1.0) {
59       return label;
60     }
61     return candidate_optimal_dual;
62   }
63 
64   // Conjugate of hinge loss. This is computed as:
65   // \phi*(z) = z if z \in [-1, 0] and +infinity everywhere else. See for
66   // instance http://www.eecs.berkeley.edu/~wainwrig/stat241b/lec10.pdf
67   // Here we want the weighted version of the conjugate loss. It turns out, that
68   // if w is the weight of an example, the conjugate of the weighted hinge loss
69   // is given by:
70   // \phi*(z) = z if z \in [-w, 0] and +infinity everywhere else. Here the
71   // conjugate function depends not only on the weight of the example but also
72   // on its label. In particular:
73   // \phi_y*(z) = y*z if y*z \in [-w, 0] and +infinity everywhere else where
74   // y \in {-1,1}. The following method implements \phi_y*(-\alpha/w).
ComputeDualLoss(const double current_dual,const double example_label,const double example_weight)75   double ComputeDualLoss(const double current_dual, const double example_label,
76                          const double example_weight) const final {
77     // For binary classification, there are 2 conjugate functions, one per
78     // label value (-1 and 1).
79     const double y_alpha = current_dual * example_label;  // y \alpha
80     if (y_alpha < 0 || y_alpha > 1.0) {
81       return std::numeric_limits<double>::max();
82     }
83     return -y_alpha * example_weight;
84   }
85 
86   // Hinge loss for binary classification for a single example. Hinge loss
87   // equals max(0, 1 - y * wx) (see https://en.wikipedia.org/wiki/Hinge_loss).
88   // For weighted instances loss should be multiplied by the instance weight.
ComputePrimalLoss(const double wx,const double example_label,const double example_weight)89   double ComputePrimalLoss(const double wx, const double example_label,
90                            const double example_weight) const final {
91     const double y_wx = example_label * wx;
92     return std::max(0.0, 1 - y_wx) * example_weight;
93   }
94 
PrimalLossDerivative(const double wx,const double label,const double example_weight)95   double PrimalLossDerivative(const double wx, const double label,
96                               const double example_weight) const final {
97     if (label * wx < 1) {
98       return -label * example_weight;
99     }
100     return 0;
101   }
102 
103   // The smoothness constant is 0 since the derivative of the loss is not
104   // Lipschitz
SmoothnessConstant()105   double SmoothnessConstant() const final { return 0; }
106 
107   // Converts binary example labels from 0.0 or 1.0 to -1.0 or 1.0 respectively
108   // as expected by hinge loss.
ConvertLabel(float * const example_label)109   Status ConvertLabel(float* const example_label) const final {
110     if (*example_label == 0.0) {
111       *example_label = -1;
112       return Status::OK();
113     }
114     if (*example_label == 1.0) {
115       return Status::OK();
116     }
117     return errors::InvalidArgument(
118         "Only labels of 0.0 or 1.0 are supported right now. "
119         "Found example with label: ",
120         *example_label);
121   }
122 };
123 
124 }  // namespace tensorflow
125 
126 #endif  // TENSORFLOW_CORE_KERNELS_HINGE_LOSS_H_
127