• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2008 Google Inc.  All rights reserved.
3 // https://developers.google.com/protocol-buffers/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are
7 // met:
8 //
9 //     * Redistributions of source code must retain the above copyright
10 // notice, this list of conditions and the following disclaimer.
11 //     * Redistributions in binary form must reproduce the above
12 // copyright notice, this list of conditions and the following disclaimer
13 // in the documentation and/or other materials provided with the
14 // distribution.
15 //     * Neither the name of Google Inc. nor the names of its
16 // contributors may be used to endorse or promote products derived from
17 // this software without specific prior written permission.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 
31 #include <google/protobuf/compiler/cpp/cpp_map_field.h>
32 #include <google/protobuf/compiler/cpp/cpp_helpers.h>
33 #include <google/protobuf/io/printer.h>
34 #include <google/protobuf/wire_format.h>
35 #include <google/protobuf/stubs/strutil.h>
36 
37 
38 namespace google {
39 namespace protobuf {
40 namespace compiler {
41 namespace cpp {
42 
IsProto3Field(const FieldDescriptor * field_descriptor)43 bool IsProto3Field(const FieldDescriptor* field_descriptor) {
44   const FileDescriptor* file_descriptor = field_descriptor->file();
45   return file_descriptor->syntax() == FileDescriptor::SYNTAX_PROTO3;
46 }
47 
SetMessageVariables(const FieldDescriptor * descriptor,std::map<std::string,std::string> * variables,const Options & options)48 void SetMessageVariables(const FieldDescriptor* descriptor,
49                          std::map<std::string, std::string>* variables,
50                          const Options& options) {
51   SetCommonFieldVariables(descriptor, variables, options);
52   (*variables)["type"] = ClassName(descriptor->message_type(), false);
53   (*variables)["stream_writer"] =
54       (*variables)["declared_type"] +
55       (HasFastArraySerialization(descriptor->message_type()->file(), options)
56            ? "MaybeToArray"
57            : "");
58   (*variables)["full_name"] = descriptor->full_name();
59 
60   const FieldDescriptor* key =
61       descriptor->message_type()->FindFieldByName("key");
62   const FieldDescriptor* val =
63       descriptor->message_type()->FindFieldByName("value");
64   (*variables)["key_cpp"] = PrimitiveTypeName(options, key->cpp_type());
65   switch (val->cpp_type()) {
66     case FieldDescriptor::CPPTYPE_MESSAGE:
67       (*variables)["val_cpp"] = FieldMessageTypeName(val, options);
68       break;
69     case FieldDescriptor::CPPTYPE_ENUM:
70       (*variables)["val_cpp"] = ClassName(val->enum_type(), true);
71       break;
72     default:
73       (*variables)["val_cpp"] = PrimitiveTypeName(options, val->cpp_type());
74   }
75   (*variables)["key_wire_type"] =
76       "TYPE_" + ToUpper(DeclaredTypeMethodName(key->type()));
77   (*variables)["val_wire_type"] =
78       "TYPE_" + ToUpper(DeclaredTypeMethodName(val->type()));
79   (*variables)["map_classname"] = ClassName(descriptor->message_type(), false);
80   (*variables)["number"] = StrCat(descriptor->number());
81   (*variables)["tag"] = StrCat(internal::WireFormat::MakeTag(descriptor));
82 
83   if (HasDescriptorMethods(descriptor->file(), options)) {
84     (*variables)["lite"] = "";
85   } else {
86     (*variables)["lite"] = "Lite";
87   }
88 
89   if (!IsProto3Field(descriptor) && val->type() == FieldDescriptor::TYPE_ENUM) {
90     const EnumValueDescriptor* default_value = val->default_value_enum();
91     (*variables)["default_enum_value"] = Int32ToString(default_value->number());
92   } else {
93     (*variables)["default_enum_value"] = "0";
94   }
95 }
96 
MapFieldGenerator(const FieldDescriptor * descriptor,const Options & options)97 MapFieldGenerator::MapFieldGenerator(const FieldDescriptor* descriptor,
98                                      const Options& options)
99     : FieldGenerator(descriptor, options) {
100   SetMessageVariables(descriptor, &variables_, options);
101 }
102 
~MapFieldGenerator()103 MapFieldGenerator::~MapFieldGenerator() {}
104 
GeneratePrivateMembers(io::Printer * printer) const105 void MapFieldGenerator::GeneratePrivateMembers(io::Printer* printer) const {
106   Formatter format(printer, variables_);
107   format(
108       "::$proto_ns$::internal::MapField$lite$<\n"
109       "    $map_classname$,\n"
110       "    $key_cpp$, $val_cpp$,\n"
111       "    ::$proto_ns$::internal::WireFormatLite::$key_wire_type$,\n"
112       "    ::$proto_ns$::internal::WireFormatLite::$val_wire_type$,\n"
113       "    $default_enum_value$ > $name$_;\n");
114 }
115 
GenerateAccessorDeclarations(io::Printer * printer) const116 void MapFieldGenerator::GenerateAccessorDeclarations(
117     io::Printer* printer) const {
118   Formatter format(printer, variables_);
119   format(
120       "$deprecated_attr$const ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >&\n"
121       "    ${1$$name$$}$() const;\n"
122       "$deprecated_attr$::$proto_ns$::Map< $key_cpp$, $val_cpp$ >*\n"
123       "    ${1$mutable_$name$$}$();\n",
124       descriptor_);
125 }
126 
GenerateInlineAccessorDefinitions(io::Printer * printer) const127 void MapFieldGenerator::GenerateInlineAccessorDefinitions(
128     io::Printer* printer) const {
129   Formatter format(printer, variables_);
130   format(
131       "inline const ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >&\n"
132       "$classname$::$name$() const {\n"
133       "  // @@protoc_insertion_point(field_map:$full_name$)\n"
134       "  return $name$_.GetMap();\n"
135       "}\n"
136       "inline ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >*\n"
137       "$classname$::mutable_$name$() {\n"
138       "  // @@protoc_insertion_point(field_mutable_map:$full_name$)\n"
139       "  return $name$_.MutableMap();\n"
140       "}\n");
141 }
142 
GenerateClearingCode(io::Printer * printer) const143 void MapFieldGenerator::GenerateClearingCode(io::Printer* printer) const {
144   Formatter format(printer, variables_);
145   format("$name$_.Clear();\n");
146 }
147 
GenerateMergingCode(io::Printer * printer) const148 void MapFieldGenerator::GenerateMergingCode(io::Printer* printer) const {
149   Formatter format(printer, variables_);
150   format("$name$_.MergeFrom(from.$name$_);\n");
151 }
152 
GenerateSwappingCode(io::Printer * printer) const153 void MapFieldGenerator::GenerateSwappingCode(io::Printer* printer) const {
154   Formatter format(printer, variables_);
155   format("$name$_.Swap(&other->$name$_);\n");
156 }
157 
GenerateCopyConstructorCode(io::Printer * printer) const158 void MapFieldGenerator::GenerateCopyConstructorCode(
159     io::Printer* printer) const {
160   GenerateConstructorCode(printer);
161   GenerateMergingCode(printer);
162 }
163 
GenerateMergeFromCodedStream(io::Printer * printer) const164 void MapFieldGenerator::GenerateMergeFromCodedStream(
165     io::Printer* printer) const {
166   Formatter format(printer, variables_);
167   const FieldDescriptor* key_field =
168       descriptor_->message_type()->FindFieldByName("key");
169   const FieldDescriptor* value_field =
170       descriptor_->message_type()->FindFieldByName("value");
171   std::string key;
172   std::string value;
173   format(
174       "$map_classname$::Parser< ::$proto_ns$::internal::MapField$lite$<\n"
175       "    $map_classname$,\n"
176       "    $key_cpp$, $val_cpp$,\n"
177       "    ::$proto_ns$::internal::WireFormatLite::$key_wire_type$,\n"
178       "    ::$proto_ns$::internal::WireFormatLite::$val_wire_type$,\n"
179       "    $default_enum_value$ >,\n"
180       "  ::$proto_ns$::Map< $key_cpp$, $val_cpp$ > >"
181       " parser(&$name$_);\n");
182   if (IsProto3Field(descriptor_) ||
183       value_field->type() != FieldDescriptor::TYPE_ENUM) {
184     format(
185         "DO_(::$proto_ns$::internal::WireFormatLite::ReadMessageNoVirtual(\n"
186         "    input, &parser));\n");
187     key = "parser.key()";
188     value = "parser.value()";
189   } else {
190     key = "entry->key()";
191     value = "entry->value()";
192     format("auto entry = parser.NewEntry();\n");
193     format(
194         "std::string data;\n"
195         "DO_(::$proto_ns$::internal::WireFormatLite::ReadString(input, "
196         "&data));\n"
197         "DO_(entry->ParseFromString(data));\n"
198         "if ($val_cpp$_IsValid(*entry->mutable_value())) {\n"
199         "  (*mutable_$name$())[entry->key()] =\n"
200         "      static_cast< $val_cpp$ >(*entry->mutable_value());\n"
201         "} else {\n");
202     if (HasDescriptorMethods(descriptor_->file(), options_)) {
203       format(
204           "  mutable_unknown_fields()"
205           "->AddLengthDelimited($number$, data);\n");
206     } else {
207       format(
208           "  unknown_fields_stream.WriteVarint32($tag$u);\n"
209           "  unknown_fields_stream.WriteVarint32(\n"
210           "      static_cast< ::google::protobuf::uint32>(data.size()));\n"
211           "  unknown_fields_stream.WriteString(data);\n");
212     }
213     format("}\n");
214   }
215 
216   if (key_field->type() == FieldDescriptor::TYPE_STRING) {
217     GenerateUtf8CheckCodeForString(
218         key_field, options_, true,
219         StrCat(key, ".data(), static_cast<int>(", key, ".length()),\n")
220             .data(),
221         format);
222   }
223   if (value_field->type() == FieldDescriptor::TYPE_STRING) {
224     GenerateUtf8CheckCodeForString(
225         value_field, options_, true,
226         StrCat(value, ".data(), static_cast<int>(", value,
227                      ".length()),\n")
228             .data(),
229         format);
230   }
231 }
232 
GenerateSerializationLoop(const Formatter & format,bool string_key,bool string_value,bool to_array,bool is_deterministic)233 static void GenerateSerializationLoop(const Formatter& format, bool string_key,
234                                       bool string_value, bool to_array,
235                                       bool is_deterministic) {
236   std::string ptr;
237   if (is_deterministic) {
238     format("for (size_type i = 0; i < n; i++) {\n");
239     ptr = string_key ? "items[static_cast<ptrdiff_t>(i)]"
240                      : "items[static_cast<ptrdiff_t>(i)].second";
241   } else {
242     format(
243         "for (::$proto_ns$::Map< $key_cpp$, $val_cpp$ >::const_iterator\n"
244         "    it = this->$name$().begin();\n"
245         "    it != this->$name$().end(); ++it) {\n");
246     ptr = "it";
247   }
248   format.Indent();
249 
250   if (to_array) {
251     format(
252         "target = $map_classname$::Funcs::SerializeToArray($number$, "
253         "$1$->first, $1$->second, target);\n",
254         ptr);
255   } else {
256     format(
257         "$map_classname$::Funcs::SerializeToCodedStream($number$, "
258         "$1$->first, $1$->second, output);\n",
259         ptr);
260   }
261 
262   if (string_key || string_value) {
263     // ptr is either an actual pointer or an iterator, either way we can
264     // create a pointer by taking the address after de-referencing it.
265     format("Utf8Check::Check(&(*$1$));\n", ptr);
266   }
267 
268   format.Outdent();
269   format("}\n");
270 }
271 
GenerateSerializeWithCachedSizes(io::Printer * printer) const272 void MapFieldGenerator::GenerateSerializeWithCachedSizes(
273     io::Printer* printer) const {
274   GenerateSerializeWithCachedSizes(printer, false);
275 }
276 
GenerateSerializeWithCachedSizesToArray(io::Printer * printer) const277 void MapFieldGenerator::GenerateSerializeWithCachedSizesToArray(
278     io::Printer* printer) const {
279   GenerateSerializeWithCachedSizes(printer, true);
280 }
281 
GenerateSerializeWithCachedSizes(io::Printer * printer,bool to_array) const282 void MapFieldGenerator::GenerateSerializeWithCachedSizes(io::Printer* printer,
283                                                          bool to_array) const {
284   Formatter format(printer, variables_);
285   format("if (!this->$name$().empty()) {\n");
286   format.Indent();
287   const FieldDescriptor* key_field =
288       descriptor_->message_type()->FindFieldByName("key");
289   const FieldDescriptor* value_field =
290       descriptor_->message_type()->FindFieldByName("value");
291   const bool string_key = key_field->type() == FieldDescriptor::TYPE_STRING;
292   const bool string_value = value_field->type() == FieldDescriptor::TYPE_STRING;
293 
294   format(
295       "typedef ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >::const_pointer\n"
296       "    ConstPtr;\n");
297   if (string_key) {
298     format(
299         "typedef ConstPtr SortItem;\n"
300         "typedef ::$proto_ns$::internal::"
301         "CompareByDerefFirst<SortItem> Less;\n");
302   } else {
303     format(
304         "typedef ::$proto_ns$::internal::SortItem< $key_cpp$, ConstPtr > "
305         "SortItem;\n"
306         "typedef ::$proto_ns$::internal::CompareByFirstField<SortItem> "
307         "Less;\n");
308   }
309   bool utf8_check = string_key || string_value;
310   if (utf8_check) {
311     format(
312         "struct Utf8Check {\n"
313         "  static void Check(ConstPtr p) {\n");
314     format.Indent();
315     format.Indent();
316     if (string_key) {
317       GenerateUtf8CheckCodeForString(
318           key_field, options_, false,
319           "p->first.data(), static_cast<int>(p->first.length()),\n", format);
320     }
321     if (string_value) {
322       GenerateUtf8CheckCodeForString(
323           value_field, options_, false,
324           "p->second.data(), static_cast<int>(p->second.length()),\n", format);
325     }
326     format.Outdent();
327     format.Outdent();
328     format(
329         "  }\n"
330         "};\n");
331   }
332 
333   format(
334       "\n"
335       "if ($1$ &&\n"
336       "    this->$name$().size() > 1) {\n"
337       "  ::std::unique_ptr<SortItem[]> items(\n"
338       "      new SortItem[this->$name$().size()]);\n"
339       "  typedef ::$proto_ns$::Map< $key_cpp$, $val_cpp$ >::size_type "
340       "size_type;\n"
341       "  size_type n = 0;\n"
342       "  for (::$proto_ns$::Map< $key_cpp$, $val_cpp$ >::const_iterator\n"
343       "      it = this->$name$().begin();\n"
344       "      it != this->$name$().end(); ++it, ++n) {\n"
345       "    items[static_cast<ptrdiff_t>(n)] = SortItem(&*it);\n"
346       "  }\n"
347       "  ::std::sort(&items[0], &items[static_cast<ptrdiff_t>(n)], Less());\n",
348       to_array ? "false" : "output->IsSerializationDeterministic()");
349   format.Indent();
350   GenerateSerializationLoop(format, string_key, string_value, to_array, true);
351   format.Outdent();
352   format("} else {\n");
353   format.Indent();
354   GenerateSerializationLoop(format, string_key, string_value, to_array, false);
355   format.Outdent();
356   format("}\n");
357   format.Outdent();
358   format("}\n");
359 }
360 
GenerateByteSize(io::Printer * printer) const361 void MapFieldGenerator::GenerateByteSize(io::Printer* printer) const {
362   Formatter format(printer, variables_);
363   format(
364       "total_size += $tag_size$ *\n"
365       "    ::$proto_ns$::internal::FromIntSize(this->$name$_size());\n"
366       "for (::$proto_ns$::Map< $key_cpp$, $val_cpp$ >::const_iterator\n"
367       "    it = this->$name$().begin();\n"
368       "    it != this->$name$().end(); ++it) {\n"
369       "  total_size += $map_classname$::Funcs::ByteSizeLong(it->first, "
370       "it->second);\n"
371       "}\n");
372 }
373 
374 }  // namespace cpp
375 }  // namespace compiler
376 }  // namespace protobuf
377 }  // namespace google
378