• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7 
8 // Author: ksroka@google.com (Krzysztof Sroka)
9 
10 #include "google/protobuf/util/field_comparator.h"
11 
12 #include <algorithm>
13 #include <cfloat>
14 #include <cmath>
15 #include <limits>
16 #include <string>
17 
18 #include "absl/log/absl_check.h"
19 #include "absl/log/absl_log.h"
20 #include "google/protobuf/descriptor.h"
21 #include "google/protobuf/message.h"
22 #include "google/protobuf/util/message_differencer.h"
23 
24 // Must be included last.
25 #include "google/protobuf/port_def.inc"
26 
27 namespace google {
28 namespace protobuf {
29 namespace util {
30 namespace {
// Built-in tolerance for approximate floating-point comparison, used when no
// explicit fraction/margin has been configured: 32 machine epsilons of the
// value type. Only float and double are specialized; instantiating any other
// type leaves `value` undeclared, so misuse fails to compile.
template <typename T>
struct Epsilon {};

template <>
struct Epsilon<float> {
  static constexpr float value = 32.0f * FLT_EPSILON;
};

template <>
struct Epsilon<double> {
  static constexpr double value = 32.0 * DBL_EPSILON;
};
41 
// Returns true when `x` and `y` are both finite and their absolute difference
// does not exceed the larger of two tolerances:
//   * `margin`, an absolute tolerance, and
//   * `fraction` times the larger magnitude of `x` and `y`, a relative one.
// Any non-finite input (infinity or NaN) yields false.
//
// Preconditions (debug-checked): 0 <= fraction < 1 and margin >= 0.
template <typename T>
bool WithinFractionOrMargin(const T x, const T y, const T fraction,
                            const T margin) {
  ABSL_DCHECK(fraction >= T(0) && fraction < T(1) && margin >= T(0));

  const bool both_finite = std::isfinite(x) && std::isfinite(y);
  if (!both_finite) return false;

  const T abs_x = std::fabs(x);
  const T abs_y = std::fabs(y);
  const T largest_magnitude = (abs_x < abs_y) ? abs_y : abs_x;
  const T relative_tolerance = fraction * largest_magnitude;
  const T allowed = (margin < relative_tolerance) ? relative_tolerance : margin;
  return std::fabs(x - y) <= allowed;
}
53 
54 }  // namespace
55 
FieldComparator()56 FieldComparator::FieldComparator() {}
~FieldComparator()57 FieldComparator::~FieldComparator() {}
58 
SimpleFieldComparator()59 SimpleFieldComparator::SimpleFieldComparator()
60     : float_comparison_(EXACT),
61       treat_nan_as_equal_(false),
62       has_default_tolerance_(false) {}
63 
~SimpleFieldComparator()64 SimpleFieldComparator::~SimpleFieldComparator() {}
65 
SimpleCompare(const Message & message_1,const Message & message_2,const FieldDescriptor * field,int index_1,int index_2,const util::FieldContext *)66 FieldComparator::ComparisonResult SimpleFieldComparator::SimpleCompare(
67     const Message& message_1, const Message& message_2,
68     const FieldDescriptor* field, int index_1, int index_2,
69     const util::FieldContext* /*field_context*/) {
70   const Reflection* reflection_1 = message_1.GetReflection();
71   const Reflection* reflection_2 = message_2.GetReflection();
72 
73   switch (field->cpp_type()) {
74 #define COMPARE_FIELD(METHOD)                                                 \
75   if (field->is_repeated()) {                                                 \
76     return ResultFromBoolean(Compare##METHOD(                                 \
77         *field, reflection_1->GetRepeated##METHOD(message_1, field, index_1), \
78         reflection_2->GetRepeated##METHOD(message_2, field, index_2)));       \
79   } else {                                                                    \
80     return ResultFromBoolean(                                                 \
81         Compare##METHOD(*field, reflection_1->Get##METHOD(message_1, field),  \
82                         reflection_2->Get##METHOD(message_2, field)));        \
83   }                                                                           \
84   break;  // Make sure no fall-through is introduced.
85 
86     case FieldDescriptor::CPPTYPE_BOOL:
87       COMPARE_FIELD(Bool);
88     case FieldDescriptor::CPPTYPE_DOUBLE:
89       COMPARE_FIELD(Double);
90     case FieldDescriptor::CPPTYPE_ENUM:
91       COMPARE_FIELD(Enum);
92     case FieldDescriptor::CPPTYPE_FLOAT:
93       COMPARE_FIELD(Float);
94     case FieldDescriptor::CPPTYPE_INT32:
95       COMPARE_FIELD(Int32);
96     case FieldDescriptor::CPPTYPE_INT64:
97       COMPARE_FIELD(Int64);
98     case FieldDescriptor::CPPTYPE_STRING:
99       if (field->is_repeated()) {
100         // Allocate scratch strings to store the result if a conversion is
101         // needed.
102         std::string scratch1;
103         std::string scratch2;
104         return ResultFromBoolean(
105             CompareString(*field,
106                           reflection_1->GetRepeatedStringReference(
107                               message_1, field, index_1, &scratch1),
108                           reflection_2->GetRepeatedStringReference(
109                               message_2, field, index_2, &scratch2)));
110       } else {
111         // Allocate scratch strings to store the result if a conversion is
112         // needed.
113         std::string scratch1;
114         std::string scratch2;
115         return ResultFromBoolean(CompareString(
116             *field,
117             reflection_1->GetStringReference(message_1, field, &scratch1),
118             reflection_2->GetStringReference(message_2, field, &scratch2)));
119       }
120       break;
121     case FieldDescriptor::CPPTYPE_UINT32:
122       COMPARE_FIELD(UInt32);
123     case FieldDescriptor::CPPTYPE_UINT64:
124       COMPARE_FIELD(UInt64);
125 
126 #undef COMPARE_FIELD
127 
128     case FieldDescriptor::CPPTYPE_MESSAGE:
129       return RECURSE;
130 
131     default:
132       ABSL_LOG(FATAL) << "No comparison code for field " << field->full_name()
133                       << " of CppType = " << field->cpp_type();
134       return DIFFERENT;
135   }
136 }
137 
CompareWithDifferencer(MessageDifferencer * differencer,const Message & message1,const Message & message2,const util::FieldContext * field_context)138 bool SimpleFieldComparator::CompareWithDifferencer(
139     MessageDifferencer* differencer, const Message& message1,
140     const Message& message2, const util::FieldContext* field_context) {
141   const Descriptor* descriptor1 = message1.GetDescriptor();
142   const Descriptor* descriptor2 = message2.GetDescriptor();
143   if (descriptor1 != descriptor2) {
144     ABSL_DLOG(FATAL) << "Comparison between two messages with different "
145                      << "descriptors. " << descriptor1->full_name() << " vs "
146                      << descriptor2->full_name();
147     return false;
148   }
149   return differencer->Compare(message1, message2, false,
150                               field_context->parent_fields());
151 }
152 
SetDefaultFractionAndMargin(double fraction,double margin)153 void SimpleFieldComparator::SetDefaultFractionAndMargin(double fraction,
154                                                         double margin) {
155   default_tolerance_ = Tolerance(fraction, margin);
156   has_default_tolerance_ = true;
157 }
158 
SetFractionAndMargin(const FieldDescriptor * field,double fraction,double margin)159 void SimpleFieldComparator::SetFractionAndMargin(const FieldDescriptor* field,
160                                                  double fraction,
161                                                  double margin) {
162   ABSL_CHECK(FieldDescriptor::CPPTYPE_FLOAT == field->cpp_type() ||
163              FieldDescriptor::CPPTYPE_DOUBLE == field->cpp_type())
164       << "Field has to be float or double type. Field name is: "
165       << field->full_name();
166   map_tolerance_[field] = Tolerance(fraction, margin);
167 }
168 
CompareDouble(const FieldDescriptor & field,double value_1,double value_2)169 bool SimpleFieldComparator::CompareDouble(const FieldDescriptor& field,
170                                           double value_1, double value_2) {
171   return CompareDoubleOrFloat(field, value_1, value_2);
172 }
173 
CompareEnum(const FieldDescriptor &,const EnumValueDescriptor * value_1,const EnumValueDescriptor * value_2)174 bool SimpleFieldComparator::CompareEnum(const FieldDescriptor& /*field*/,
175                                         const EnumValueDescriptor* value_1,
176                                         const EnumValueDescriptor* value_2) {
177   return value_1->number() == value_2->number();
178 }
179 
CompareFloat(const FieldDescriptor & field,float value_1,float value_2)180 bool SimpleFieldComparator::CompareFloat(const FieldDescriptor& field,
181                                          float value_1, float value_2) {
182   return CompareDoubleOrFloat(field, value_1, value_2);
183 }
184 
185 template <typename T>
CompareDoubleOrFloat(const FieldDescriptor & field,T value_1,T value_2)186 bool SimpleFieldComparator::CompareDoubleOrFloat(const FieldDescriptor& field,
187                                                  T value_1, T value_2) {
188   if (value_1 == value_2) {
189     // Covers +inf and -inf (which are not within margin or fraction of
190     // themselves), and is a shortcut for finite values.
191     return true;
192   } else if (float_comparison_ == EXACT) {
193     if (treat_nan_as_equal_ && std::isnan(value_1) && std::isnan(value_2)) {
194       return true;
195     }
196     return false;
197   } else {
198     if (treat_nan_as_equal_ && std::isnan(value_1) && std::isnan(value_2)) {
199       return true;
200     }
201     // float_comparison_ == APPROXIMATE covers two use cases.
202     Tolerance* tolerance = nullptr;
203     if (has_default_tolerance_) tolerance = &default_tolerance_;
204 
205     auto it = map_tolerance_.find(&field);
206     if (it != map_tolerance_.end()) {
207       tolerance = &it->second;
208     }
209 
210     if (tolerance != nullptr) {
211       // Use user-provided fraction and margin. Since they are stored as
212       // doubles, we explicitly cast them to types of values provided. This
213       // is very likely to fail if provided values are not numeric.
214       return WithinFractionOrMargin(value_1, value_2,
215                                     static_cast<T>(tolerance->fraction),
216                                     static_cast<T>(tolerance->margin));
217     } else {
218       if (std::fabs(value_1) <= Epsilon<T>::value &&
219           std::fabs(value_2) <= Epsilon<T>::value) {
220         return true;
221       }
222       return WithinFractionOrMargin(value_1, value_2, Epsilon<T>::value,
223                                     Epsilon<T>::value);
224     }
225   }
226 }
227 
ResultFromBoolean(bool boolean_result) const228 FieldComparator::ComparisonResult SimpleFieldComparator::ResultFromBoolean(
229     bool boolean_result) const {
230   return boolean_result ? FieldComparator::SAME : FieldComparator::DIFFERENT;
231 }
232 
233 }  // namespace util
234 }  // namespace protobuf
235 }  // namespace google
236 
237 #include "google/protobuf/port_undef.inc"
238