1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc. All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7
8 // Author: ksroka@google.com (Krzysztof Sroka)
9
10 #include "google/protobuf/util/field_comparator.h"
11
12 #include <algorithm>
13 #include <cfloat>
14 #include <cmath>
15 #include <limits>
16 #include <string>
17
18 #include "absl/log/absl_check.h"
19 #include "absl/log/absl_log.h"
20 #include "google/protobuf/descriptor.h"
21 #include "google/protobuf/message.h"
22 #include "google/protobuf/util/message_differencer.h"
23
24 // Must be included last.
25 #include "google/protobuf/port_def.inc"
26
27 namespace google {
28 namespace protobuf {
29 namespace util {
30 namespace {
// Fallback tolerance used by CompareDoubleOrFloat() when no fraction/margin
// has been configured: 32 times the machine epsilon of the value type.
// Only float and double have a specialization. The primary template has no
// `value` member, so instantiating it with any other type fails to compile.
template <typename T>
struct Epsilon {};

template <>
struct Epsilon<float> {
  static constexpr float value = FLT_EPSILON * 32;
};

template <>
struct Epsilon<double> {
  static constexpr double value = DBL_EPSILON * 32;
};
41
// Returns true when `x` and `y` are both finite and |x - y| does not exceed
// the larger of two bounds:
//   * `margin`, an absolute bound, and
//   * `fraction` times the larger of |x| and |y|, a relative bound.
// Any non-finite input (infinity or NaN) yields false.
//
// Preconditions (checked in debug builds only): 0 <= fraction < 1 and
// margin >= 0.
template <typename T>
bool WithinFractionOrMargin(const T x, const T y, const T fraction,
                            const T margin) {
  ABSL_DCHECK(fraction >= T(0) && fraction < T(1) && margin >= T(0));

  const bool both_finite = std::isfinite(x) && std::isfinite(y);
  if (!both_finite) return false;

  const T abs_x = std::fabs(x);
  const T abs_y = std::fabs(y);
  // Same selection rule as std::max: prefer the second operand only when
  // the first compares strictly less.
  const T larger_magnitude = abs_x < abs_y ? abs_y : abs_x;
  const T relative_bound = fraction * larger_magnitude;
  const T allowed_difference = margin < relative_bound ? relative_bound : margin;

  return std::fabs(x - y) <= allowed_difference;
}
53
54 } // namespace
55
FieldComparator()56 FieldComparator::FieldComparator() {}
~FieldComparator()57 FieldComparator::~FieldComparator() {}
58
SimpleFieldComparator()59 SimpleFieldComparator::SimpleFieldComparator()
60 : float_comparison_(EXACT),
61 treat_nan_as_equal_(false),
62 has_default_tolerance_(false) {}
63
~SimpleFieldComparator()64 SimpleFieldComparator::~SimpleFieldComparator() {}
65
SimpleCompare(const Message & message_1,const Message & message_2,const FieldDescriptor * field,int index_1,int index_2,const util::FieldContext *)66 FieldComparator::ComparisonResult SimpleFieldComparator::SimpleCompare(
67 const Message& message_1, const Message& message_2,
68 const FieldDescriptor* field, int index_1, int index_2,
69 const util::FieldContext* /*field_context*/) {
70 const Reflection* reflection_1 = message_1.GetReflection();
71 const Reflection* reflection_2 = message_2.GetReflection();
72
73 switch (field->cpp_type()) {
74 #define COMPARE_FIELD(METHOD) \
75 if (field->is_repeated()) { \
76 return ResultFromBoolean(Compare##METHOD( \
77 *field, reflection_1->GetRepeated##METHOD(message_1, field, index_1), \
78 reflection_2->GetRepeated##METHOD(message_2, field, index_2))); \
79 } else { \
80 return ResultFromBoolean( \
81 Compare##METHOD(*field, reflection_1->Get##METHOD(message_1, field), \
82 reflection_2->Get##METHOD(message_2, field))); \
83 } \
84 break; // Make sure no fall-through is introduced.
85
86 case FieldDescriptor::CPPTYPE_BOOL:
87 COMPARE_FIELD(Bool);
88 case FieldDescriptor::CPPTYPE_DOUBLE:
89 COMPARE_FIELD(Double);
90 case FieldDescriptor::CPPTYPE_ENUM:
91 COMPARE_FIELD(Enum);
92 case FieldDescriptor::CPPTYPE_FLOAT:
93 COMPARE_FIELD(Float);
94 case FieldDescriptor::CPPTYPE_INT32:
95 COMPARE_FIELD(Int32);
96 case FieldDescriptor::CPPTYPE_INT64:
97 COMPARE_FIELD(Int64);
98 case FieldDescriptor::CPPTYPE_STRING:
99 if (field->is_repeated()) {
100 // Allocate scratch strings to store the result if a conversion is
101 // needed.
102 std::string scratch1;
103 std::string scratch2;
104 return ResultFromBoolean(
105 CompareString(*field,
106 reflection_1->GetRepeatedStringReference(
107 message_1, field, index_1, &scratch1),
108 reflection_2->GetRepeatedStringReference(
109 message_2, field, index_2, &scratch2)));
110 } else {
111 // Allocate scratch strings to store the result if a conversion is
112 // needed.
113 std::string scratch1;
114 std::string scratch2;
115 return ResultFromBoolean(CompareString(
116 *field,
117 reflection_1->GetStringReference(message_1, field, &scratch1),
118 reflection_2->GetStringReference(message_2, field, &scratch2)));
119 }
120 break;
121 case FieldDescriptor::CPPTYPE_UINT32:
122 COMPARE_FIELD(UInt32);
123 case FieldDescriptor::CPPTYPE_UINT64:
124 COMPARE_FIELD(UInt64);
125
126 #undef COMPARE_FIELD
127
128 case FieldDescriptor::CPPTYPE_MESSAGE:
129 return RECURSE;
130
131 default:
132 ABSL_LOG(FATAL) << "No comparison code for field " << field->full_name()
133 << " of CppType = " << field->cpp_type();
134 return DIFFERENT;
135 }
136 }
137
CompareWithDifferencer(MessageDifferencer * differencer,const Message & message1,const Message & message2,const util::FieldContext * field_context)138 bool SimpleFieldComparator::CompareWithDifferencer(
139 MessageDifferencer* differencer, const Message& message1,
140 const Message& message2, const util::FieldContext* field_context) {
141 const Descriptor* descriptor1 = message1.GetDescriptor();
142 const Descriptor* descriptor2 = message2.GetDescriptor();
143 if (descriptor1 != descriptor2) {
144 ABSL_DLOG(FATAL) << "Comparison between two messages with different "
145 << "descriptors. " << descriptor1->full_name() << " vs "
146 << descriptor2->full_name();
147 return false;
148 }
149 return differencer->Compare(message1, message2, false,
150 field_context->parent_fields());
151 }
152
SetDefaultFractionAndMargin(double fraction,double margin)153 void SimpleFieldComparator::SetDefaultFractionAndMargin(double fraction,
154 double margin) {
155 default_tolerance_ = Tolerance(fraction, margin);
156 has_default_tolerance_ = true;
157 }
158
SetFractionAndMargin(const FieldDescriptor * field,double fraction,double margin)159 void SimpleFieldComparator::SetFractionAndMargin(const FieldDescriptor* field,
160 double fraction,
161 double margin) {
162 ABSL_CHECK(FieldDescriptor::CPPTYPE_FLOAT == field->cpp_type() ||
163 FieldDescriptor::CPPTYPE_DOUBLE == field->cpp_type())
164 << "Field has to be float or double type. Field name is: "
165 << field->full_name();
166 map_tolerance_[field] = Tolerance(fraction, margin);
167 }
168
CompareDouble(const FieldDescriptor & field,double value_1,double value_2)169 bool SimpleFieldComparator::CompareDouble(const FieldDescriptor& field,
170 double value_1, double value_2) {
171 return CompareDoubleOrFloat(field, value_1, value_2);
172 }
173
CompareEnum(const FieldDescriptor &,const EnumValueDescriptor * value_1,const EnumValueDescriptor * value_2)174 bool SimpleFieldComparator::CompareEnum(const FieldDescriptor& /*field*/,
175 const EnumValueDescriptor* value_1,
176 const EnumValueDescriptor* value_2) {
177 return value_1->number() == value_2->number();
178 }
179
CompareFloat(const FieldDescriptor & field,float value_1,float value_2)180 bool SimpleFieldComparator::CompareFloat(const FieldDescriptor& field,
181 float value_1, float value_2) {
182 return CompareDoubleOrFloat(field, value_1, value_2);
183 }
184
185 template <typename T>
CompareDoubleOrFloat(const FieldDescriptor & field,T value_1,T value_2)186 bool SimpleFieldComparator::CompareDoubleOrFloat(const FieldDescriptor& field,
187 T value_1, T value_2) {
188 if (value_1 == value_2) {
189 // Covers +inf and -inf (which are not within margin or fraction of
190 // themselves), and is a shortcut for finite values.
191 return true;
192 } else if (float_comparison_ == EXACT) {
193 if (treat_nan_as_equal_ && std::isnan(value_1) && std::isnan(value_2)) {
194 return true;
195 }
196 return false;
197 } else {
198 if (treat_nan_as_equal_ && std::isnan(value_1) && std::isnan(value_2)) {
199 return true;
200 }
201 // float_comparison_ == APPROXIMATE covers two use cases.
202 Tolerance* tolerance = nullptr;
203 if (has_default_tolerance_) tolerance = &default_tolerance_;
204
205 auto it = map_tolerance_.find(&field);
206 if (it != map_tolerance_.end()) {
207 tolerance = &it->second;
208 }
209
210 if (tolerance != nullptr) {
211 // Use user-provided fraction and margin. Since they are stored as
212 // doubles, we explicitly cast them to types of values provided. This
213 // is very likely to fail if provided values are not numeric.
214 return WithinFractionOrMargin(value_1, value_2,
215 static_cast<T>(tolerance->fraction),
216 static_cast<T>(tolerance->margin));
217 } else {
218 if (std::fabs(value_1) <= Epsilon<T>::value &&
219 std::fabs(value_2) <= Epsilon<T>::value) {
220 return true;
221 }
222 return WithinFractionOrMargin(value_1, value_2, Epsilon<T>::value,
223 Epsilon<T>::value);
224 }
225 }
226 }
227
ResultFromBoolean(bool boolean_result) const228 FieldComparator::ComparisonResult SimpleFieldComparator::ResultFromBoolean(
229 bool boolean_result) const {
230 return boolean_result ? FieldComparator::SAME : FieldComparator::DIFFERENT;
231 }
232
233 } // namespace util
234 } // namespace protobuf
235 } // namespace google
236
237 #include "google/protobuf/port_undef.inc"
238