1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc. All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 // * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 // * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 // * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31 #include <google/protobuf/compiler/cpp/cpp_map_field.h>
32 #include <google/protobuf/compiler/cpp/cpp_helpers.h>
33 #include <google/protobuf/io/printer.h>
34 #include <google/protobuf/wire_format.h>
35 #include <google/protobuf/stubs/strutil.h>
36
37
38 namespace google {
39 namespace protobuf {
40 namespace compiler {
41 namespace cpp {
42
IsProto3Field(const FieldDescriptor * field_descriptor)43 bool IsProto3Field(const FieldDescriptor* field_descriptor) {
44 const FileDescriptor* file_descriptor = field_descriptor->file();
45 return file_descriptor->syntax() == FileDescriptor::SYNTAX_PROTO3;
46 }
47
SetMessageVariables(const FieldDescriptor * descriptor,std::map<std::string,std::string> * variables,const Options & options)48 void SetMessageVariables(const FieldDescriptor* descriptor,
49 std::map<std::string, std::string>* variables,
50 const Options& options) {
51 SetCommonFieldVariables(descriptor, variables, options);
52 (*variables)["type"] = ClassName(descriptor->message_type(), false);
53 (*variables)["stream_writer"] =
54 (*variables)["declared_type"] +
55 (HasFastArraySerialization(descriptor->message_type()->file(), options)
56 ? "MaybeToArray"
57 : "");
58 (*variables)["full_name"] = descriptor->full_name();
59
60 const FieldDescriptor* key =
61 descriptor->message_type()->FindFieldByName("key");
62 const FieldDescriptor* val =
63 descriptor->message_type()->FindFieldByName("value");
64 (*variables)["key_cpp"] = PrimitiveTypeName(options, key->cpp_type());
65 switch (val->cpp_type()) {
66 case FieldDescriptor::CPPTYPE_MESSAGE:
67 (*variables)["val_cpp"] = FieldMessageTypeName(val, options);
68 break;
69 case FieldDescriptor::CPPTYPE_ENUM:
70 (*variables)["val_cpp"] = ClassName(val->enum_type(), true);
71 break;
72 default:
73 (*variables)["val_cpp"] = PrimitiveTypeName(options, val->cpp_type());
74 }
75 (*variables)["key_wire_type"] =
76 "TYPE_" + ToUpper(DeclaredTypeMethodName(key->type()));
77 (*variables)["val_wire_type"] =
78 "TYPE_" + ToUpper(DeclaredTypeMethodName(val->type()));
79 (*variables)["map_classname"] = ClassName(descriptor->message_type(), false);
80 (*variables)["number"] = StrCat(descriptor->number());
81 (*variables)["tag"] = StrCat(internal::WireFormat::MakeTag(descriptor));
82
83 if (HasDescriptorMethods(descriptor->file(), options)) {
84 (*variables)["lite"] = "";
85 } else {
86 (*variables)["lite"] = "Lite";
87 }
88
89 if (!IsProto3Field(descriptor) && val->type() == FieldDescriptor::TYPE_ENUM) {
90 const EnumValueDescriptor* default_value = val->default_value_enum();
91 (*variables)["default_enum_value"] = Int32ToString(default_value->number());
92 } else {
93 (*variables)["default_enum_value"] = "0";
94 }
95 }
96
MapFieldGenerator(const FieldDescriptor * descriptor,const Options & options)97 MapFieldGenerator::MapFieldGenerator(const FieldDescriptor* descriptor,
98 const Options& options)
99 : FieldGenerator(descriptor, options) {
100 SetMessageVariables(descriptor, &variables_, options);
101 }
102
~MapFieldGenerator()103 MapFieldGenerator::~MapFieldGenerator() {}
104
GeneratePrivateMembers(io::Printer * printer) const105 void MapFieldGenerator::GeneratePrivateMembers(io::Printer* printer) const {
106 Formatter format(printer, variables_);
107 format(
108 "::$proto_ns$::internal::MapField$lite$<\n"
109 " $map_classname$,\n"
110 " $key_cpp$, $val_cpp$,\n"
111 " ::$proto_ns$::internal::WireFormatLite::$key_wire_type$,\n"
112 " ::$proto_ns$::internal::WireFormatLite::$val_wire_type$,\n"
113 " $default_enum_value$ > $name$_;\n");
114 }
115
GenerateAccessorDeclarations(io::Printer * printer) const116 void MapFieldGenerator::GenerateAccessorDeclarations(
117 io::Printer* printer) const {
118 Formatter format(printer, variables_);
119 format(
120 "$deprecated_attr$const ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >&\n"
121 " ${1$$name$$}$() const;\n"
122 "$deprecated_attr$::$proto_ns$::Map< $key_cpp$, $val_cpp$ >*\n"
123 " ${1$mutable_$name$$}$();\n",
124 descriptor_);
125 }
126
GenerateInlineAccessorDefinitions(io::Printer * printer) const127 void MapFieldGenerator::GenerateInlineAccessorDefinitions(
128 io::Printer* printer) const {
129 Formatter format(printer, variables_);
130 format(
131 "inline const ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >&\n"
132 "$classname$::$name$() const {\n"
133 " // @@protoc_insertion_point(field_map:$full_name$)\n"
134 " return $name$_.GetMap();\n"
135 "}\n"
136 "inline ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >*\n"
137 "$classname$::mutable_$name$() {\n"
138 " // @@protoc_insertion_point(field_mutable_map:$full_name$)\n"
139 " return $name$_.MutableMap();\n"
140 "}\n");
141 }
142
GenerateClearingCode(io::Printer * printer) const143 void MapFieldGenerator::GenerateClearingCode(io::Printer* printer) const {
144 Formatter format(printer, variables_);
145 format("$name$_.Clear();\n");
146 }
147
GenerateMergingCode(io::Printer * printer) const148 void MapFieldGenerator::GenerateMergingCode(io::Printer* printer) const {
149 Formatter format(printer, variables_);
150 format("$name$_.MergeFrom(from.$name$_);\n");
151 }
152
GenerateSwappingCode(io::Printer * printer) const153 void MapFieldGenerator::GenerateSwappingCode(io::Printer* printer) const {
154 Formatter format(printer, variables_);
155 format("$name$_.Swap(&other->$name$_);\n");
156 }
157
GenerateCopyConstructorCode(io::Printer * printer) const158 void MapFieldGenerator::GenerateCopyConstructorCode(
159 io::Printer* printer) const {
160 GenerateConstructorCode(printer);
161 GenerateMergingCode(printer);
162 }
163
GenerateMergeFromCodedStream(io::Printer * printer) const164 void MapFieldGenerator::GenerateMergeFromCodedStream(
165 io::Printer* printer) const {
166 Formatter format(printer, variables_);
167 const FieldDescriptor* key_field =
168 descriptor_->message_type()->FindFieldByName("key");
169 const FieldDescriptor* value_field =
170 descriptor_->message_type()->FindFieldByName("value");
171 std::string key;
172 std::string value;
173 format(
174 "$map_classname$::Parser< ::$proto_ns$::internal::MapField$lite$<\n"
175 " $map_classname$,\n"
176 " $key_cpp$, $val_cpp$,\n"
177 " ::$proto_ns$::internal::WireFormatLite::$key_wire_type$,\n"
178 " ::$proto_ns$::internal::WireFormatLite::$val_wire_type$,\n"
179 " $default_enum_value$ >,\n"
180 " ::$proto_ns$::Map< $key_cpp$, $val_cpp$ > >"
181 " parser(&$name$_);\n");
182 if (IsProto3Field(descriptor_) ||
183 value_field->type() != FieldDescriptor::TYPE_ENUM) {
184 format(
185 "DO_(::$proto_ns$::internal::WireFormatLite::ReadMessageNoVirtual(\n"
186 " input, &parser));\n");
187 key = "parser.key()";
188 value = "parser.value()";
189 } else {
190 key = "entry->key()";
191 value = "entry->value()";
192 format("auto entry = parser.NewEntry();\n");
193 format(
194 "std::string data;\n"
195 "DO_(::$proto_ns$::internal::WireFormatLite::ReadString(input, "
196 "&data));\n"
197 "DO_(entry->ParseFromString(data));\n"
198 "if ($val_cpp$_IsValid(*entry->mutable_value())) {\n"
199 " (*mutable_$name$())[entry->key()] =\n"
200 " static_cast< $val_cpp$ >(*entry->mutable_value());\n"
201 "} else {\n");
202 if (HasDescriptorMethods(descriptor_->file(), options_)) {
203 format(
204 " mutable_unknown_fields()"
205 "->AddLengthDelimited($number$, data);\n");
206 } else {
207 format(
208 " unknown_fields_stream.WriteVarint32($tag$u);\n"
209 " unknown_fields_stream.WriteVarint32(\n"
210 " static_cast< ::google::protobuf::uint32>(data.size()));\n"
211 " unknown_fields_stream.WriteString(data);\n");
212 }
213 format("}\n");
214 }
215
216 if (key_field->type() == FieldDescriptor::TYPE_STRING) {
217 GenerateUtf8CheckCodeForString(
218 key_field, options_, true,
219 StrCat(key, ".data(), static_cast<int>(", key, ".length()),\n")
220 .data(),
221 format);
222 }
223 if (value_field->type() == FieldDescriptor::TYPE_STRING) {
224 GenerateUtf8CheckCodeForString(
225 value_field, options_, true,
226 StrCat(value, ".data(), static_cast<int>(", value,
227 ".length()),\n")
228 .data(),
229 format);
230 }
231 }
232
GenerateSerializationLoop(const Formatter & format,bool string_key,bool string_value,bool to_array,bool is_deterministic)233 static void GenerateSerializationLoop(const Formatter& format, bool string_key,
234 bool string_value, bool to_array,
235 bool is_deterministic) {
236 std::string ptr;
237 if (is_deterministic) {
238 format("for (size_type i = 0; i < n; i++) {\n");
239 ptr = string_key ? "items[static_cast<ptrdiff_t>(i)]"
240 : "items[static_cast<ptrdiff_t>(i)].second";
241 } else {
242 format(
243 "for (::$proto_ns$::Map< $key_cpp$, $val_cpp$ >::const_iterator\n"
244 " it = this->$name$().begin();\n"
245 " it != this->$name$().end(); ++it) {\n");
246 ptr = "it";
247 }
248 format.Indent();
249
250 if (to_array) {
251 format(
252 "target = $map_classname$::Funcs::SerializeToArray($number$, "
253 "$1$->first, $1$->second, target);\n",
254 ptr);
255 } else {
256 format(
257 "$map_classname$::Funcs::SerializeToCodedStream($number$, "
258 "$1$->first, $1$->second, output);\n",
259 ptr);
260 }
261
262 if (string_key || string_value) {
263 // ptr is either an actual pointer or an iterator, either way we can
264 // create a pointer by taking the address after de-referencing it.
265 format("Utf8Check::Check(&(*$1$));\n", ptr);
266 }
267
268 format.Outdent();
269 format("}\n");
270 }
271
GenerateSerializeWithCachedSizes(io::Printer * printer) const272 void MapFieldGenerator::GenerateSerializeWithCachedSizes(
273 io::Printer* printer) const {
274 GenerateSerializeWithCachedSizes(printer, false);
275 }
276
GenerateSerializeWithCachedSizesToArray(io::Printer * printer) const277 void MapFieldGenerator::GenerateSerializeWithCachedSizesToArray(
278 io::Printer* printer) const {
279 GenerateSerializeWithCachedSizes(printer, true);
280 }
281
GenerateSerializeWithCachedSizes(io::Printer * printer,bool to_array) const282 void MapFieldGenerator::GenerateSerializeWithCachedSizes(io::Printer* printer,
283 bool to_array) const {
284 Formatter format(printer, variables_);
285 format("if (!this->$name$().empty()) {\n");
286 format.Indent();
287 const FieldDescriptor* key_field =
288 descriptor_->message_type()->FindFieldByName("key");
289 const FieldDescriptor* value_field =
290 descriptor_->message_type()->FindFieldByName("value");
291 const bool string_key = key_field->type() == FieldDescriptor::TYPE_STRING;
292 const bool string_value = value_field->type() == FieldDescriptor::TYPE_STRING;
293
294 format(
295 "typedef ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >::const_pointer\n"
296 " ConstPtr;\n");
297 if (string_key) {
298 format(
299 "typedef ConstPtr SortItem;\n"
300 "typedef ::$proto_ns$::internal::"
301 "CompareByDerefFirst<SortItem> Less;\n");
302 } else {
303 format(
304 "typedef ::$proto_ns$::internal::SortItem< $key_cpp$, ConstPtr > "
305 "SortItem;\n"
306 "typedef ::$proto_ns$::internal::CompareByFirstField<SortItem> "
307 "Less;\n");
308 }
309 bool utf8_check = string_key || string_value;
310 if (utf8_check) {
311 format(
312 "struct Utf8Check {\n"
313 " static void Check(ConstPtr p) {\n");
314 format.Indent();
315 format.Indent();
316 if (string_key) {
317 GenerateUtf8CheckCodeForString(
318 key_field, options_, false,
319 "p->first.data(), static_cast<int>(p->first.length()),\n", format);
320 }
321 if (string_value) {
322 GenerateUtf8CheckCodeForString(
323 value_field, options_, false,
324 "p->second.data(), static_cast<int>(p->second.length()),\n", format);
325 }
326 format.Outdent();
327 format.Outdent();
328 format(
329 " }\n"
330 "};\n");
331 }
332
333 format(
334 "\n"
335 "if ($1$ &&\n"
336 " this->$name$().size() > 1) {\n"
337 " ::std::unique_ptr<SortItem[]> items(\n"
338 " new SortItem[this->$name$().size()]);\n"
339 " typedef ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >::size_type "
340 "size_type;\n"
341 " size_type n = 0;\n"
342 " for (::$proto_ns$::Map< $key_cpp$, $val_cpp$ >::const_iterator\n"
343 " it = this->$name$().begin();\n"
344 " it != this->$name$().end(); ++it, ++n) {\n"
345 " items[static_cast<ptrdiff_t>(n)] = SortItem(&*it);\n"
346 " }\n"
347 " ::std::sort(&items[0], &items[static_cast<ptrdiff_t>(n)], Less());\n",
348 to_array ? "false" : "output->IsSerializationDeterministic()");
349 format.Indent();
350 GenerateSerializationLoop(format, string_key, string_value, to_array, true);
351 format.Outdent();
352 format("} else {\n");
353 format.Indent();
354 GenerateSerializationLoop(format, string_key, string_value, to_array, false);
355 format.Outdent();
356 format("}\n");
357 format.Outdent();
358 format("}\n");
359 }
360
GenerateByteSize(io::Printer * printer) const361 void MapFieldGenerator::GenerateByteSize(io::Printer* printer) const {
362 Formatter format(printer, variables_);
363 format(
364 "total_size += $tag_size$ *\n"
365 " ::$proto_ns$::internal::FromIntSize(this->$name$_size());\n"
366 "for (::$proto_ns$::Map< $key_cpp$, $val_cpp$ >::const_iterator\n"
367 " it = this->$name$().begin();\n"
368 " it != this->$name$().end(); ++it) {\n"
369 " total_size += $map_classname$::Funcs::ByteSizeLong(it->first, "
370 "it->second);\n"
371 "}\n");
372 }
373
374 } // namespace cpp
375 } // namespace compiler
376 } // namespace protobuf
377 } // namespace google
378