• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_SMOOTH_HINGE_LOSS_H_
17 #define TENSORFLOW_CORE_KERNELS_SMOOTH_HINGE_LOSS_H_
18 
19 #include <limits>
20 
21 #include "tensorflow/core/kernels/loss.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/core/status.h"
24 
25 namespace tensorflow {
26 
27 class SmoothHingeLossUpdater : public DualLossUpdater {
28  public:
29   // Computes the updated dual variable (corresponding) to a single example. The
30   // updated dual value maximizes the objective function of the dual
31   // optimization problem associated with smooth hinge loss. The computations
32   // are detailed in readme.md.
ComputeUpdatedDual(const int num_partitions,const double label,const double example_weight,const double current_dual,const double wx,const double weighted_example_norm)33   double ComputeUpdatedDual(const int num_partitions, const double label,
34                             const double example_weight,
35                             const double current_dual, const double wx,
36                             const double weighted_example_norm) const final {
37     // Intuitively there are 3 cases:
38     // a. new optimal value of the dual variable falls within the admissible
39     // range [0, 1]. In this case we set new dual to this value.
40     // b. new optimal value is < 0. Then, because of convexity, the optimal
41     // valid value for new dual = 0
42     // c. new optimal value > 1.0. Then new optimal value should be set to 1.0.
43     const double candidate_optimal_dual =
44         current_dual +
45         (label - wx - gamma * current_dual) /
46             (num_partitions * example_weight * weighted_example_norm + gamma);
47     if (label * candidate_optimal_dual < 0) {
48       return 0.0;
49     }
50     if (label * candidate_optimal_dual > 1.0) {
51       return label;
52     }
53     return candidate_optimal_dual;
54   }
55 
ComputeDualLoss(const double current_dual,const double example_label,const double example_weight)56   double ComputeDualLoss(const double current_dual, const double example_label,
57                          const double example_weight) const final {
58     // For binary classification, there are 2 conjugate functions, one per
59     // label value (-1 and 1).
60     const double y_alpha = current_dual * example_label;  // y \alpha
61     if (y_alpha < 0 || y_alpha > 1.0) {
62       return std::numeric_limits<double>::max();
63     }
64     return (-y_alpha + 0.5 * gamma * current_dual * current_dual) *
65            example_weight;
66   }
67 
ComputePrimalLoss(const double wx,const double example_label,const double example_weight)68   double ComputePrimalLoss(const double wx, const double example_label,
69                            const double example_weight) const final {
70     const double y_wx = example_label * wx;
71     if (y_wx >= 1) return 0;
72     if (y_wx <= 1 - gamma) return (1 - y_wx - gamma / 2) * example_weight;
73     return (1 - y_wx) * (1 - y_wx) * example_weight * 0.5 / gamma;
74   }
75 
76   // Converts binary example labels from 0.0 or 1.0 to -1.0 or 1.0 respectively
77   // as expected by smooth hinge loss.
ConvertLabel(float * const example_label)78   Status ConvertLabel(float* const example_label) const final {
79     if (*example_label == 0.0) {
80       *example_label = -1;
81       return OkStatus();
82     }
83     if (*example_label == 1.0) {
84       return OkStatus();
85     }
86     return errors::InvalidArgument(
87         "Only labels of 0.0 or 1.0 are supported right now. "
88         "Found example with label: ",
89         *example_label);
90   }
91 
PrimalLossDerivative(const double wx,const double label,const double example_weight)92   double PrimalLossDerivative(const double wx, const double label,
93                               const double example_weight) const final {
94     if (label * wx >= 1) {
95       return 0;
96     }
97     if (label * wx <= 1 - gamma) {
98       return -label;
99     }
100     return (wx - label) / gamma;
101   }
102 
SmoothnessConstant()103   double SmoothnessConstant() const final { return gamma; }
104 
105  private:
106   // Smoothness constant of smooth hinge loss
107   // TODO(sibyl-Aix6ihai): expose this parameter
108   const double gamma = 1;
109 };
110 
111 }  // namespace tensorflow
112 
113 #endif  // TENSORFLOW_CORE_KERNELS_SMOOTH_HINGE_LOSS_H_
114 // TENSORFLOW_KERNELS_SMOOTH_HINGE_LOSS_H_
115