• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/costs/robust_stats.h"
17 #include <algorithm>
18 #include <cmath>
19 #include <utility>
20 
21 namespace tensorflow {
22 namespace grappler {
23 
// Returns the median of `values`, which must already be sorted in ascending
// order (sortedness is NOT verified). Returns 0 for an empty vector.
// For an even number of elements the two middle values are averaged in a way
// that cannot overflow, matching the approach used by Median() below.
static double SortedMedian(const std::vector<double> &values) {
  const size_t n = values.size();
  if (n == 0) return 0.0;
  if (n % 2 == 1) return values[n / 2];

  const double lower = values[n / 2 - 1];
  const double upper = values[n / 2];
  // Sortedness gives lower <= upper. If they lie on opposite sides of zero
  // their sum cannot overflow; otherwise their difference cannot.
  if (lower <= 0 && upper >= 0) return (lower + upper) / 2.0;
  return lower + (upper - lower) / 2.0;
}
35 
// Returns the median of `values`, which need not be sorted. The vector is
// taken by rvalue reference and reordered in place, so no copy is made.
// Returns 0 for an empty vector. For an even count the two middle elements
// are averaged without risking overflow.
static double Median(std::vector<double> &&values) {
  if (values.empty()) return 0;

  const size_t half = values.size() / 2;
  const auto upper = values.begin() + half;
  // After nth_element, *upper is the element that would sit at index `half`
  // in sorted order, and nothing before it is greater than *upper.
  std::nth_element(values.begin(), upper, values.end());
  if (values.size() % 2 == 1) return *upper;

  // Even count: the other middle element is the largest one in the lower
  // partition [begin, upper).
  const double hi = *upper;
  const double lo = *std::max_element(values.begin(), upper);
  // lo <= hi here. Straddling zero means the sum cannot overflow; otherwise
  // the difference cannot.
  const bool straddles_zero = lo <= 0 && hi >= 0;
  return straddles_zero ? (lo + hi) / 2 : lo + (hi - lo) / 2;
}
58 
59 // Given a set of values, calculates the scaled Median Absolute Deviation (a
60 // robust approximation to the standard deviation).  This is calculated as the
61 // median of the absolute deviations from the median, scaled by 1.4826.  Its
62 // advantage over the standard deviation is that it is not (as) affected by
63 // outlier values.  Returns a pair<median, mad>.
ScaledMedianAbsoluteDeviation(const std::vector<double> & sorted_values)64 static std::pair<double, double> ScaledMedianAbsoluteDeviation(
65     const std::vector<double> &sorted_values) {
66   double median = SortedMedian(sorted_values);
67 
68   // Next, we calculate the absolute deviations from the median,
69   // find the median of the resulting data, and scale by 1.4826.
70   std::vector<double> deviations;
71   deviations.reserve(sorted_values.size());
72   for (double d : sorted_values) {
73     deviations.push_back(std::abs(d - median));
74   }
75   double mad = Median(std::move(deviations)) * 1.4826;
76   return std::pair<double, double>(median, mad);
77 }
78 
RobustStats(const std::vector<double> & values)79 RobustStats::RobustStats(const std::vector<double> &values)
80     : RobustStats(std::vector<double>(values)) {}
81 
RobustStats(std::vector<double> && values)82 RobustStats::RobustStats(std::vector<double> &&values) {
83   std::sort(values.begin(), values.end());
84   lo_ = values[0];
85   hi_ = values.back();
86   HuberMAD(values);
87 }
88 
// Performs one update of `mean` using Huber's weighting. The band is
// [mean - margin, mean + margin]:
// - A value inside the band contributes itself and counts toward the divisor.
// - A value below the band contributes -margin; a value above contributes
//   +margin. Neither counts toward the divisor, so its influence is clipped.
// Returns the accumulated total divided by the number of in-band values.
double UpdateHuberMean(const std::vector<double> &sorted_values, double mean,
                       double margin) {
  const double lower_bound = mean - margin;
  const double upper_bound = mean + margin;

  double total = 0.0;
  int inliers = 0;
  for (size_t i = 0; i < sorted_values.size(); ++i) {
    const double value = sorted_values[i];
    if (value < lower_bound) {
      total -= margin;
    } else if (value > upper_bound) {
      total += margin;
    } else {
      total += value;
      ++inliers;
    }
  }

  // When more than half of the values sit exactly at the median (zero
  // interquartile distance), the mean can drift just far enough that no value
  // remains inside the band. Returning the old mean unchanged lets the caller
  // detect convergence and stop.
  return inliers == 0 ? mean : total / inliers;
}
118 
119 // Given a list of values, this approximates the stddev using the MAD and then
120 // uses it to compute a Huber robust mean (sandwich mean).  A margin of
121 // c*stddev is defined around the current mean, and values are weighted by
122 // margin / abs(value - mean) if outside the margin, or 1 if inside.  This
123 // computes the mean iteratively, because each time it changes the margin
124 // shifts a bit.  It typically settles very quickly, but it's possible for it
125 // to be unstable.  We limit it to 10 iterations.
126 //
HuberMAD(const std::vector<double> & sorted_values)127 void RobustStats::HuberMAD(const std::vector<double> &sorted_values) {
128   const std::pair<double, double> median_mad =
129       ScaledMedianAbsoluteDeviation(sorted_values);
130   mean_ = median_mad.first;
131   stddev_ = median_mad.second;
132 
133   // c = 1.345 is the commonly used cutoff with 95% efficiency at the normal.
134   // We're using c = 1.5 to be a little more conservative, and because that's
135   // the default in S-plus.
136   // TODO(dehnert): Specialize Stats for integral types so we don't implement
137   // methods that don't make sense.
138   const double c = 1.5;
139   const double margin = c * stddev_;
140 
141   // Iterate 10 times, or until the Huber mean stabilizes.
142   // If the margin is zero, we don't want mean to drift from the median.
143   if (margin > 0.0) {
144     for (int k = 0; k < 10; ++k) {
145       double old_mean = mean_;
146       mean_ = UpdateHuberMean(sorted_values, mean_, margin);
147       if (mean_ == old_mean) break;
148     }
149   }
150 }
151 
152 }  // namespace grappler
153 }  // namespace tensorflow
154