• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Protocol Buffers - Google's data interchange format
2 // Copyright 2023 Google LLC.  All rights reserved.
3 //
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file or at
6 // https://developers.google.com/open-source/licenses/bsd
7 
8 /*
9  * lupb_Message -- Message/Array/Map objects in Lua/C that wrap upb
10  */
11 
12 #include <float.h>
13 #include <math.h>
14 #include <stddef.h>
15 #include <stdlib.h>
16 #include <string.h>
17 
18 #include "lauxlib.h"
19 #include "lua/upb.h"
20 #include "upb/json/decode.h"
21 #include "upb/json/encode.h"
22 #include "upb/message/map.h"
23 #include "upb/message/message.h"
24 #include "upb/port/def.inc"
25 #include "upb/reflection/message.h"
26 #include "upb/text/encode.h"
27 
28 /*
29  * Message/Map/Array objects.  These objects form a directed graph: a message
30  * can contain submessages, arrays, and maps, which can then point to other
31  * messages.  This graph can technically be cyclic, though this is an error and
32  * a cyclic graph cannot be serialized.  So it's better to think of this as a
33  * tree of objects.
34  *
35  * The actual data exists at the upb level (upb_Message, upb_Map, upb_Array),
36  * independently of Lua.  The upb objects contain all the canonical data and
37  * edges between objects.  Lua wrapper objects expose the upb objects to Lua,
38  * but ultimately they are just wrappers.  They pass through all reads and
39  * writes to the underlying upb objects.
40  *
41  * Each upb object lives in a upb arena.  We have a Lua object to wrap the upb
42  * arena, but arenas are never exposed to the user.  The Lua arena object just
43  * serves to own the upb arena and free it at the proper time, once the Lua GC
44  * has determined that there are no more references to anything that lives in
45  * that arena.  All wrapper objects strongly reference the arena to which they
46  * belong.
47  *
48  * A global object cache stores a mapping of C pointer (upb_Message*,
49  * upb_Array*, upb_Map*) to a corresponding Lua wrapper.  These references are
50  * weak so that the wrappers can be collected if they are no longer needed.  A
51  * new wrapper object can always be recreated later.
52  *
53  *                          +-----+
54  *            lupb_Arena    |cache|-weak-+
55  *                 |  ^     +-----+      |
56  *                 |  |                  V
57  * Lua level       |  +------------lupb_Message
58  * ----------------|-----------------|------------------------------------------
59  * upb level       |                 |
60  *                 |            +----V----------------------------------+
61  *                 +->upb_Arena | upb_Message  ...(empty arena storage) |
62  *                              +---------------------------------------+
63  *
64  * If the user creates a reference between two objects that have different
65  * arenas, we need to fuse the two arenas together, so that the blocks will
66  * outlive both arenas.
67  *
68  *                 +-------------------------->(fused)<----------------+
69  *                 |                                                   |
70  *                 V                           +-----+                 V
71  *            lupb_Arena                +-weak-|cache|-weak-+     lupb_Arena
72  *                 |  ^                 |      +-----+      |        ^  |
73  *                 |  |                 V                   V        |  |
74  * Lua level       |  +------------lupb_Message        lupb_Message--+  |
75  * ----------------|-----------------|----------------------|-----------|------
76  * upb level       |                 |                      |           |
77  *                 |            +----V--------+        +----V--------+  V
78  *                 +->upb_Arena | upb_Message |        | upb_Message | upb_Arena
79  *                              +------|------+        +--^----------+
80  *                                     +------------------+
81  * Key invariants:
82  *   1. every wrapper references the arena that contains it.
83  *   2. every fused arena includes all arenas that own upb objects reachable
84  *      from that arena.  In other words, when a wrapper references an arena,
85  *      this is sufficient to ensure that any upb object reachable from that
86  *      wrapper will stay alive.
87  *
88  * Additionally, every message object contains a strong reference to the
89  * corresponding Descriptor object.  Likewise, array/map objects reference a
90  * Descriptor object if they are typed to store message values.
91  */
92 
/* Metatable names for each wrapper type, as passed to lupb_newuserdata() and
 * luaL_checkudata(). */
#define LUPB_ARENA "lupb.arena"
#define LUPB_ARRAY "lupb.array"
#define LUPB_MAP "lupb.map"
#define LUPB_MSG "lupb.msg"

/* User-value slots on wrapper userdata (read/written with
 * lua_getiuservalue()/lua_setiuservalue()). */
#define LUPB_ARENA_INDEX 1
#define LUPB_MSGDEF_INDEX 2 /* For msg, and map/array that store msg */

/* Forward declarations: defined further down, but needed by the
 * upb <-> Lua conversion helpers that come first. */
static void lupb_Message_Newmsgwrapper(lua_State* L, int narg,
                                       upb_MessageValue val);
static upb_Message* lupb_msg_check(lua_State* L, int narg);
104 
lupb_checkfieldtype(lua_State * L,int narg)105 static upb_CType lupb_checkfieldtype(lua_State* L, int narg) {
106   uint32_t n = lupb_checkuint32(L, narg);
107   bool ok = n >= kUpb_CType_Bool && n <= kUpb_CType_Bytes;
108   luaL_argcheck(L, ok, narg, "invalid field type");
109   return n;
110 }
111 
/* Only the address of this variable is used in this file: &cache_key is the
 * registry key (stored as a light userdata by lua_rawsetp/lua_rawgetp) under
 * which the wrapper cache table lives.  NOTE(review): it has external linkage;
 * it could be made static if no other translation unit refers to it. */
char cache_key;
113 
114 /* lupb_cacheinit()
115  *
116  * Creates the global cache used by lupb_cacheget() and lupb_cacheset().
117  */
lupb_cacheinit(lua_State * L)118 static void lupb_cacheinit(lua_State* L) {
119   /* Create our object cache. */
120   lua_newtable(L);
121 
122   /* Cache metatable gives the cache weak values */
123   lua_createtable(L, 0, 1);
124   lua_pushstring(L, "v");
125   lua_setfield(L, -2, "__mode");
126   lua_setmetatable(L, -2);
127 
128   /* Set cache in the registry. */
129   lua_rawsetp(L, LUA_REGISTRYINDEX, &cache_key);
130 }
131 
132 /* lupb_cacheget()
133  *
134  * Pushes cache[key] and returns true if this key is present in the cache.
135  * Otherwise returns false and leaves nothing on the stack.
136  */
lupb_cacheget(lua_State * L,const void * key)137 static bool lupb_cacheget(lua_State* L, const void* key) {
138   if (key == NULL) {
139     lua_pushnil(L);
140     return true;
141   }
142 
143   lua_rawgetp(L, LUA_REGISTRYINDEX, &cache_key);
144   lua_rawgetp(L, -1, key);
145   if (lua_isnil(L, -1)) {
146     lua_pop(L, 2); /* Pop table, nil. */
147     return false;
148   } else {
149     lua_replace(L, -2); /* Replace cache table. */
150     return true;
151   }
152 }
153 
154 /* lupb_cacheset()
155  *
156  * Sets cache[key] = val, where "val" is the value at the top of the stack.
157  * Does not pop the value.
158  */
lupb_cacheset(lua_State * L,const void * key)159 static void lupb_cacheset(lua_State* L, const void* key) {
160   lua_rawgetp(L, LUA_REGISTRYINDEX, &cache_key);
161   lua_pushvalue(L, -2);
162   lua_rawsetp(L, -2, key);
163   lua_pop(L, 1); /* Pop table. */
164 }
165 
166 /* lupb_Arena *****************************************************************/
167 
/* lupb_Arena only exists to wrap a upb_Arena.  It is never exposed to users; it
 * is an internal memory management detail.  Other wrapper objects refer to this
 * object from their userdata to keep the arena-owned data alive.
 */

typedef struct {
  upb_Arena* arena; /* Owned; released by the __gc metamethod lupb_Arena_gc. */
} lupb_Arena;
176 
lupb_Arena_check(lua_State * L,int narg)177 static upb_Arena* lupb_Arena_check(lua_State* L, int narg) {
178   lupb_Arena* a = luaL_checkudata(L, narg, LUPB_ARENA);
179   return a->arena;
180 }
181 
lupb_Arena_pushnew(lua_State * L)182 upb_Arena* lupb_Arena_pushnew(lua_State* L) {
183   lupb_Arena* a = lupb_newuserdata(L, sizeof(lupb_Arena), 1, LUPB_ARENA);
184   a->arena = upb_Arena_New();
185   return a->arena;
186 }
187 
188 /**
189  * lupb_Arena_Fuse()
190  *
191  * Merges |from| into |to| so that there is a single arena group that contains
192  * both, and both arenas will point at this new table. */
lupb_Arena_Fuse(lua_State * L,int to,int from)193 static void lupb_Arena_Fuse(lua_State* L, int to, int from) {
194   upb_Arena* to_arena = lupb_Arena_check(L, to);
195   upb_Arena* from_arena = lupb_Arena_check(L, from);
196   upb_Arena_Fuse(to_arena, from_arena);
197 }
198 
lupb_Arena_Fuseobjs(lua_State * L,int to,int from)199 static void lupb_Arena_Fuseobjs(lua_State* L, int to, int from) {
200   lua_getiuservalue(L, to, LUPB_ARENA_INDEX);
201   lua_getiuservalue(L, from, LUPB_ARENA_INDEX);
202   lupb_Arena_Fuse(L, lua_absindex(L, -2), lua_absindex(L, -1));
203   lua_pop(L, 2);
204 }
205 
lupb_Arena_gc(lua_State * L)206 static int lupb_Arena_gc(lua_State* L) {
207   upb_Arena* a = lupb_Arena_check(L, 1);
208   upb_Arena_Free(a);
209   return 0;
210 }
211 
212 static const struct luaL_Reg lupb_Arena_mm[] = {{"__gc", lupb_Arena_gc},
213                                                 {NULL, NULL}};
214 
215 /* lupb_Arenaget()
216  *
217  * Returns the arena from the given message, array, or map object.
218  */
lupb_Arenaget(lua_State * L,int narg)219 static upb_Arena* lupb_Arenaget(lua_State* L, int narg) {
220   upb_Arena* arena;
221   lua_getiuservalue(L, narg, LUPB_ARENA_INDEX);
222   arena = lupb_Arena_check(L, -1);
223   lua_pop(L, 1);
224   return arena;
225 }
226 
227 /* upb <-> Lua type conversion ************************************************/
228 
/* Whether string data should be copied into the containing arena.  We can
 * avoid a copy if the string data is only needed temporarily (like for a map
 * lookup).  Passed as the |copy| argument of lupb_tomsgval().
 */
typedef enum {
  LUPB_COPY, /* Copy string data into the arena. */
  LUPB_REF   /* Reference the Lua copy of the string data; only valid while
                that Lua string stays alive. */
} lupb_copy_t;
237 
238 /**
239  * lupb_tomsgval()
240  *
241  * Converts the given Lua value |narg| to a upb_MessageValue.
242  */
lupb_tomsgval(lua_State * L,upb_CType type,int narg,int container,lupb_copy_t copy)243 static upb_MessageValue lupb_tomsgval(lua_State* L, upb_CType type, int narg,
244                                       int container, lupb_copy_t copy) {
245   upb_MessageValue ret;
246   switch (type) {
247     case kUpb_CType_Int32:
248     case kUpb_CType_Enum:
249       ret.int32_val = lupb_checkint32(L, narg);
250       break;
251     case kUpb_CType_Int64:
252       ret.int64_val = lupb_checkint64(L, narg);
253       break;
254     case kUpb_CType_UInt32:
255       ret.uint32_val = lupb_checkuint32(L, narg);
256       break;
257     case kUpb_CType_UInt64:
258       ret.uint64_val = lupb_checkuint64(L, narg);
259       break;
260     case kUpb_CType_Double:
261       ret.double_val = lupb_checkdouble(L, narg);
262       break;
263     case kUpb_CType_Float:
264       ret.float_val = lupb_checkfloat(L, narg);
265       break;
266     case kUpb_CType_Bool:
267       ret.bool_val = lupb_checkbool(L, narg);
268       break;
269     case kUpb_CType_String:
270     case kUpb_CType_Bytes: {
271       size_t len;
272       const char* ptr = lupb_checkstring(L, narg, &len);
273       switch (copy) {
274         case LUPB_COPY: {
275           upb_Arena* arena = lupb_Arenaget(L, container);
276           char* data = upb_Arena_Malloc(arena, len);
277           memcpy(data, ptr, len);
278           ret.str_val = upb_StringView_FromDataAndSize(data, len);
279           break;
280         }
281         case LUPB_REF:
282           ret.str_val = upb_StringView_FromDataAndSize(ptr, len);
283           break;
284       }
285       break;
286     }
287     case kUpb_CType_Message:
288       ret.msg_val = lupb_msg_check(L, narg);
289       /* Typecheck message. */
290       lua_getiuservalue(L, container, LUPB_MSGDEF_INDEX);
291       lua_getiuservalue(L, narg, LUPB_MSGDEF_INDEX);
292       luaL_argcheck(L, lua_rawequal(L, -1, -2), narg, "message type mismatch");
293       lua_pop(L, 2);
294       break;
295   }
296   return ret;
297 }
298 
/* lupb_pushmsgval()
 *
 * Pushes the Lua representation of |val| interpreted as |type|.  Scalars
 * become Lua numbers, booleans, or strings.  For message values, a cached
 * wrapper is reused when present; a NULL message pushes nil.  Otherwise a new
 * wrapper is created that copies the arena and msgdef user values of the
 * wrapper at |container|, which must therefore be non-zero for messages.
 */
void lupb_pushmsgval(lua_State* L, int container, upb_CType type,
                     upb_MessageValue val) {
  switch (type) {
    case kUpb_CType_Bool:
      lua_pushboolean(L, val.bool_val);
      break;
    case kUpb_CType_Float:
      lua_pushnumber(L, val.float_val);
      break;
    case kUpb_CType_Double:
      lua_pushnumber(L, val.double_val);
      break;
    case kUpb_CType_Int32:
    case kUpb_CType_Enum:
      lupb_pushint32(L, val.int32_val);
      break;
    case kUpb_CType_UInt32:
      lupb_pushuint32(L, val.uint32_val);
      break;
    case kUpb_CType_Int64:
      lupb_pushint64(L, val.int64_val);
      break;
    case kUpb_CType_UInt64:
      lupb_pushuint64(L, val.uint64_val);
      break;
    case kUpb_CType_String:
    case kUpb_CType_Bytes:
      lua_pushlstring(L, val.str_val.data, val.str_val.size);
      break;
    case kUpb_CType_Message:
      assert(container);
      if (!lupb_cacheget(L, val.msg_val)) {
        lupb_Message_Newmsgwrapper(L, container, val);
      }
      break;
    default:
      LUPB_UNREACHABLE();
  }
}
337 
338 /* lupb_array *****************************************************************/
339 
/* Lua wrapper for a upb_Array.  User values on the userdata:
 * LUPB_ARENA_INDEX holds the lupb_Arena; for message element types,
 * LUPB_MSGDEF_INDEX holds the element msgdef. */
typedef struct {
  upb_Array* arr;  /* The wrapped upb array. */
  upb_CType type;  /* Element type of |arr|. */
} lupb_array;
344 
lupb_array_check(lua_State * L,int narg)345 static lupb_array* lupb_array_check(lua_State* L, int narg) {
346   return luaL_checkudata(L, narg, LUPB_ARRAY);
347 }
348 
349 /**
350  * lupb_array_checkindex()
351  *
352  * Checks the array index at Lua stack index |narg| to verify that it is an
353  * integer between 1 and |max|, inclusively.  Also corrects it to be zero-based
354  * for C.
355  */
lupb_array_checkindex(lua_State * L,int narg,uint32_t max)356 static int lupb_array_checkindex(lua_State* L, int narg, uint32_t max) {
357   uint32_t n = lupb_checkuint32(L, narg);
358   luaL_argcheck(L, n != 0 && n <= max, narg, "invalid array index");
359   return n - 1; /* Lua uses 1-based indexing. */
360 }
361 
362 /* lupb_array Public API */
363 
364 /* lupb_Array_New():
365  *
366  * Handles:
367  *   Array(upb.TYPE_INT32)
368  *   Array(message_type)
369  */
lupb_Array_New(lua_State * L)370 static int lupb_Array_New(lua_State* L) {
371   int arg_count = lua_gettop(L);
372   lupb_array* larray;
373   upb_Arena* arena;
374 
375   if (lua_type(L, 1) == LUA_TNUMBER) {
376     upb_CType type = lupb_checkfieldtype(L, 1);
377     larray = lupb_newuserdata(L, sizeof(*larray), 1, LUPB_ARRAY);
378     larray->type = type;
379   } else {
380     lupb_MessageDef_check(L, 1);
381     larray = lupb_newuserdata(L, sizeof(*larray), 2, LUPB_ARRAY);
382     larray->type = kUpb_CType_Message;
383     lua_pushvalue(L, 1);
384     lua_setiuservalue(L, -2, LUPB_MSGDEF_INDEX);
385   }
386 
387   arena = lupb_Arena_pushnew(L);
388   lua_setiuservalue(L, -2, LUPB_ARENA_INDEX);
389 
390   larray->arr = upb_Array_New(arena, larray->type);
391   lupb_cacheset(L, larray->arr);
392 
393   if (arg_count > 1) {
394     /* Set initial fields from table. */
395     int msg = arg_count + 1;
396     lua_pushnil(L);
397     while (lua_next(L, 2) != 0) {
398       lua_pushvalue(L, -2); /* now stack is key, val, key */
399       lua_insert(L, -3);    /* now stack is key, key, val */
400       lua_settable(L, msg);
401     }
402   }
403 
404   return 1;
405 }
406 
407 /* lupb_Array_Newindex():
408  *
409  * Handles:
410  *   array[idx] = val
411  *
412  * idx can be within the array or one past the end to extend.
413  */
lupb_Array_Newindex(lua_State * L)414 static int lupb_Array_Newindex(lua_State* L) {
415   lupb_array* larray = lupb_array_check(L, 1);
416   size_t size = upb_Array_Size(larray->arr);
417   uint32_t n = lupb_array_checkindex(L, 2, size + 1);
418   upb_MessageValue msgval = lupb_tomsgval(L, larray->type, 3, 1, LUPB_COPY);
419 
420   if (n == size) {
421     upb_Array_Append(larray->arr, msgval, lupb_Arenaget(L, 1));
422   } else {
423     upb_Array_Set(larray->arr, n, msgval);
424   }
425 
426   if (larray->type == kUpb_CType_Message) {
427     lupb_Arena_Fuseobjs(L, 1, 3);
428   }
429 
430   return 0; /* 1 for chained assignments? */
431 }
432 
433 /* lupb_array_index():
434  *
435  * Handles:
436  *   array[idx] -> val
437  *
438  * idx must be within the array.
439  */
lupb_array_index(lua_State * L)440 static int lupb_array_index(lua_State* L) {
441   lupb_array* larray = lupb_array_check(L, 1);
442   size_t size = upb_Array_Size(larray->arr);
443   uint32_t n = lupb_array_checkindex(L, 2, size);
444   upb_MessageValue val = upb_Array_Get(larray->arr, n);
445 
446   lupb_pushmsgval(L, 1, larray->type, val);
447 
448   return 1;
449 }
450 
451 /* lupb_array_len():
452  *
453  * Handles:
454  *   #array -> len
455  */
lupb_array_len(lua_State * L)456 static int lupb_array_len(lua_State* L) {
457   lupb_array* larray = lupb_array_check(L, 1);
458   lua_pushnumber(L, upb_Array_Size(larray->arr));
459   return 1;
460 }
461 
462 static const struct luaL_Reg lupb_array_mm[] = {
463     {"__index", lupb_array_index},
464     {"__len", lupb_array_len},
465     {"__newindex", lupb_Array_Newindex},
466     {NULL, NULL}};
467 
468 /* lupb_map *******************************************************************/
469 
/* Lua wrapper for a upb_Map.  User values on the userdata: LUPB_ARENA_INDEX
 * holds the lupb_Arena; for message-valued maps a second slot holds the value
 * msgdef. */
typedef struct {
  upb_Map* map;          /* The wrapped upb map. */
  upb_CType key_type;    /* Type of the map's keys. */
  upb_CType value_type;  /* Type of the map's values. */
} lupb_map;
475 
/* User-value slot holding a map's value msgdef.  It must be the slot every
 * reader uses (LUPB_MSGDEF_INDEX): slot LUPB_ARENA_INDEX (1) holds the arena,
 * so a msgdef stored there would be overwritten. */
#define MAP_MSGDEF_INDEX LUPB_MSGDEF_INDEX
477 
lupb_map_check(lua_State * L,int narg)478 static lupb_map* lupb_map_check(lua_State* L, int narg) {
479   return luaL_checkudata(L, narg, LUPB_MAP);
480 }
481 
482 /* lupb_map Public API */
483 
484 /**
485  * lupb_Map_New
486  *
487  * Handles:
488  *   new_map = upb.Map(key_type, value_type)
489  *   new_map = upb.Map(key_type, value_msgdef)
490  */
lupb_Map_New(lua_State * L)491 static int lupb_Map_New(lua_State* L) {
492   upb_Arena* arena;
493   lupb_map* lmap;
494 
495   if (lua_type(L, 2) == LUA_TNUMBER) {
496     lmap = lupb_newuserdata(L, sizeof(*lmap), 1, LUPB_MAP);
497     lmap->value_type = lupb_checkfieldtype(L, 2);
498   } else {
499     lupb_MessageDef_check(L, 2);
500     lmap = lupb_newuserdata(L, sizeof(*lmap), 2, LUPB_MAP);
501     lmap->value_type = kUpb_CType_Message;
502     lua_pushvalue(L, 2);
503     lua_setiuservalue(L, -2, MAP_MSGDEF_INDEX);
504   }
505 
506   arena = lupb_Arena_pushnew(L);
507   lua_setiuservalue(L, -2, LUPB_ARENA_INDEX);
508 
509   lmap->key_type = lupb_checkfieldtype(L, 1);
510   lmap->map = upb_Map_New(arena, lmap->key_type, lmap->value_type);
511   lupb_cacheset(L, lmap->map);
512 
513   return 1;
514 }
515 
516 /**
517  * lupb_map_index
518  *
519  * Handles:
520  *   map[key]
521  */
lupb_map_index(lua_State * L)522 static int lupb_map_index(lua_State* L) {
523   lupb_map* lmap = lupb_map_check(L, 1);
524   upb_MessageValue key = lupb_tomsgval(L, lmap->key_type, 2, 1, LUPB_REF);
525   upb_MessageValue val;
526 
527   if (upb_Map_Get(lmap->map, key, &val)) {
528     lupb_pushmsgval(L, 1, lmap->value_type, val);
529   } else {
530     lua_pushnil(L);
531   }
532 
533   return 1;
534 }
535 
536 /**
537  * lupb_map_len
538  *
539  * Handles:
540  *   map_len = #map
541  */
lupb_map_len(lua_State * L)542 static int lupb_map_len(lua_State* L) {
543   lupb_map* lmap = lupb_map_check(L, 1);
544   lua_pushnumber(L, upb_Map_Size(lmap->map));
545   return 1;
546 }
547 
548 /**
549  * lupb_Map_Newindex
550  *
551  * Handles:
552  *   map[key] = val
553  *   map[key] = nil  # to remove from map
554  */
lupb_Map_Newindex(lua_State * L)555 static int lupb_Map_Newindex(lua_State* L) {
556   lupb_map* lmap = lupb_map_check(L, 1);
557   upb_Map* map = lmap->map;
558   upb_MessageValue key = lupb_tomsgval(L, lmap->key_type, 2, 1, LUPB_REF);
559 
560   if (lua_isnil(L, 3)) {
561     upb_Map_Delete(map, key, NULL);
562   } else {
563     upb_MessageValue val = lupb_tomsgval(L, lmap->value_type, 3, 1, LUPB_COPY);
564     upb_Map_Set(map, key, val, lupb_Arenaget(L, 1));
565     if (lmap->value_type == kUpb_CType_Message) {
566       lupb_Arena_Fuseobjs(L, 1, 3);
567     }
568   }
569 
570   return 0;
571 }
572 
/* Iterator closure built by lupb_map_pairs().  Upvalue 1 is the size_t
 * iteration cursor and upvalue 2 the lupb_map.  Returns (key, value) for the
 * next entry, or no values once iteration is finished. */
static int lupb_MapIterator_Next(lua_State* L) {
  const int map_idx = lua_upvalueindex(2);
  size_t* cursor = lua_touserdata(L, lua_upvalueindex(1));
  lupb_map* lmap = lupb_map_check(L, map_idx);
  upb_MessageValue key, val;

  if (!upb_Map_Next(lmap->map, &key, &val, cursor)) {
    return 0;
  }

  lupb_pushmsgval(L, map_idx, lmap->key_type, key);
  lupb_pushmsgval(L, map_idx, lmap->value_type, val);
  return 2;
}
587 
588 /**
589  * lupb_map_pairs()
590  *
591  * Handles:
592  *   pairs(map)
593  */
lupb_map_pairs(lua_State * L)594 static int lupb_map_pairs(lua_State* L) {
595   size_t* iter = lua_newuserdata(L, sizeof(*iter));
596   lupb_map_check(L, 1);
597 
598   *iter = kUpb_Map_Begin;
599   lua_pushvalue(L, 1);
600 
601   /* Upvalues are [iter, lupb_map]. */
602   lua_pushcclosure(L, &lupb_MapIterator_Next, 2);
603 
604   return 1;
605 }
606 
607 /* upb_mapiter ]]] */
608 
609 static const struct luaL_Reg lupb_map_mm[] = {{"__index", lupb_map_index},
610                                               {"__len", lupb_map_len},
611                                               {"__newindex", lupb_Map_Newindex},
612                                               {"__pairs", lupb_map_pairs},
613                                               {NULL, NULL}};
614 
615 /* lupb_Message
616  * *******************************************************************/
617 
/* Lua wrapper for a upb_Message.  User values on the userdata:
 * LUPB_ARENA_INDEX holds the lupb_Arena, LUPB_MSGDEF_INDEX the msgdef. */
typedef struct {
  upb_Message* msg; /* The wrapped upb message. */
} lupb_Message;
621 
622 /* lupb_Message helpers */
623 
lupb_msg_check(lua_State * L,int narg)624 static upb_Message* lupb_msg_check(lua_State* L, int narg) {
625   lupb_Message* msg = luaL_checkudata(L, narg, LUPB_MSG);
626   return msg->msg;
627 }
628 
lupb_Message_Getmsgdef(lua_State * L,int msg)629 static const upb_MessageDef* lupb_Message_Getmsgdef(lua_State* L, int msg) {
630   lua_getiuservalue(L, msg, LUPB_MSGDEF_INDEX);
631   const upb_MessageDef* m = lupb_MessageDef_check(L, -1);
632   lua_pop(L, 1);
633   return m;
634 }
635 
lupb_msg_tofield(lua_State * L,int msg,int field)636 static const upb_FieldDef* lupb_msg_tofield(lua_State* L, int msg, int field) {
637   size_t len;
638   const char* fieldname = luaL_checklstring(L, field, &len);
639   const upb_MessageDef* m = lupb_Message_Getmsgdef(L, msg);
640   return upb_MessageDef_FindFieldByNameWithSize(m, fieldname, len);
641 }
642 
lupb_msg_checkfield(lua_State * L,int msg,int field)643 static const upb_FieldDef* lupb_msg_checkfield(lua_State* L, int msg,
644                                                int field) {
645   const upb_FieldDef* f = lupb_msg_tofield(L, msg, field);
646   if (f == NULL) {
647     luaL_error(L, "no such field '%s'", lua_tostring(L, field));
648   }
649   return f;
650 }
651 
lupb_msg_pushnew(lua_State * L,int narg)652 upb_Message* lupb_msg_pushnew(lua_State* L, int narg) {
653   const upb_MessageDef* m = lupb_MessageDef_check(L, narg);
654   lupb_Message* lmsg = lupb_newuserdata(L, sizeof(lupb_Message), 2, LUPB_MSG);
655   upb_Arena* arena = lupb_Arena_pushnew(L);
656 
657   lua_setiuservalue(L, -2, LUPB_ARENA_INDEX);
658   lua_pushvalue(L, 1);
659   lua_setiuservalue(L, -2, LUPB_MSGDEF_INDEX);
660 
661   lmsg->msg = upb_Message_New(upb_MessageDef_MiniTable(m), arena);
662   lupb_cacheset(L, lmsg->msg);
663   return lmsg->msg;
664 }
665 
666 /**
667  * lupb_Message_Newmsgwrapper()
668  *
669  * Creates a new wrapper for a message, copying the arena and msgdef references
670  * from |narg| (which should be an array or map).
671  */
lupb_Message_Newmsgwrapper(lua_State * L,int narg,upb_MessageValue val)672 static void lupb_Message_Newmsgwrapper(lua_State* L, int narg,
673                                        upb_MessageValue val) {
674   lupb_Message* lmsg = lupb_newuserdata(L, sizeof(*lmsg), 2, LUPB_MSG);
675   lmsg->msg = (upb_Message*)val.msg_val; /* XXX: cast isn't great. */
676   lupb_cacheset(L, lmsg->msg);
677 
678   /* Copy both arena and msgdef into the wrapper. */
679   lua_getiuservalue(L, narg, LUPB_ARENA_INDEX);
680   lua_setiuservalue(L, -2, LUPB_ARENA_INDEX);
681   lua_getiuservalue(L, narg, LUPB_MSGDEF_INDEX);
682   lua_setiuservalue(L, -2, LUPB_MSGDEF_INDEX);
683 }
684 
685 /**
686  * lupb_Message_Newud()
687  *
688  * Creates the Lua userdata for a new wrapper object, adding a reference to
689  * the msgdef if necessary.
690  */
lupb_Message_Newud(lua_State * L,int narg,size_t size,const char * type,const upb_FieldDef * f)691 static void* lupb_Message_Newud(lua_State* L, int narg, size_t size,
692                                 const char* type, const upb_FieldDef* f) {
693   if (upb_FieldDef_CType(f) == kUpb_CType_Message) {
694     /* Wrapper needs a reference to the msgdef. */
695     void* ud = lupb_newuserdata(L, size, 2, type);
696     lua_getiuservalue(L, narg, LUPB_MSGDEF_INDEX);
697     lupb_MessageDef_pushsubmsgdef(L, f);
698     lua_setiuservalue(L, -2, LUPB_MSGDEF_INDEX);
699     return ud;
700   } else {
701     return lupb_newuserdata(L, size, 1, type);
702   }
703 }
704 
705 /**
706  * lupb_Message_Newwrapper()
707  *
708  * Creates a new Lua wrapper object to wrap the given array, map, or message.
709  */
lupb_Message_Newwrapper(lua_State * L,int narg,const upb_FieldDef * f,upb_MutableMessageValue val)710 static void lupb_Message_Newwrapper(lua_State* L, int narg,
711                                     const upb_FieldDef* f,
712                                     upb_MutableMessageValue val) {
713   if (upb_FieldDef_IsMap(f)) {
714     const upb_MessageDef* entry = upb_FieldDef_MessageSubDef(f);
715     const upb_FieldDef* key_f =
716         upb_MessageDef_FindFieldByNumber(entry, kUpb_MapEntry_KeyFieldNumber);
717     const upb_FieldDef* val_f =
718         upb_MessageDef_FindFieldByNumber(entry, kUpb_MapEntry_ValueFieldNumber);
719     lupb_map* lmap =
720         lupb_Message_Newud(L, narg, sizeof(*lmap), LUPB_MAP, val_f);
721     lmap->key_type = upb_FieldDef_CType(key_f);
722     lmap->value_type = upb_FieldDef_CType(val_f);
723     lmap->map = val.map;
724   } else if (upb_FieldDef_IsRepeated(f)) {
725     lupb_array* larr =
726         lupb_Message_Newud(L, narg, sizeof(*larr), LUPB_ARRAY, f);
727     larr->type = upb_FieldDef_CType(f);
728     larr->arr = val.array;
729   } else {
730     lupb_Message* lmsg =
731         lupb_Message_Newud(L, narg, sizeof(*lmsg), LUPB_MSG, f);
732     lmsg->msg = val.msg;
733   }
734 
735   /* Copy arena ref to new wrapper.  This may be a different arena than the
736    * underlying data was originally constructed from, but if so both arenas
737    * must be in the same group. */
738   lua_getiuservalue(L, narg, LUPB_ARENA_INDEX);
739   lua_setiuservalue(L, -2, LUPB_ARENA_INDEX);
740 
741   lupb_cacheset(L, val.msg);
742 }
743 
744 /**
745  * lupb_msg_typechecksubmsg()
746  *
747  * Typechecks the given array, map, or msg against this upb_FieldDef.
748  */
lupb_msg_typechecksubmsg(lua_State * L,int narg,int msgarg,const upb_FieldDef * f)749 static void lupb_msg_typechecksubmsg(lua_State* L, int narg, int msgarg,
750                                      const upb_FieldDef* f) {
751   /* Typecheck this map's msgdef against this message field. */
752   lua_getiuservalue(L, narg, LUPB_MSGDEF_INDEX);
753   lua_getiuservalue(L, msgarg, LUPB_MSGDEF_INDEX);
754   lupb_MessageDef_pushsubmsgdef(L, f);
755   luaL_argcheck(L, lua_rawequal(L, -1, -2), narg, "message type mismatch");
756   lua_pop(L, 2);
757 }
758 
759 /* lupb_Message Public API */
760 
761 /**
762  * lupb_MessageDef_call
763  *
764  * Handles:
765  *   new_msg = MessageClass()
766  *   new_msg = MessageClass{foo = "bar", baz = 3, quux = {foo = 3}}
767  */
lupb_MessageDef_call(lua_State * L)768 int lupb_MessageDef_call(lua_State* L) {
769   int arg_count = lua_gettop(L);
770   lupb_msg_pushnew(L, 1);
771 
772   if (arg_count > 1) {
773     /* Set initial fields from table. */
774     int msg = arg_count + 1;
775     lua_pushnil(L);
776     while (lua_next(L, 2) != 0) {
777       lua_pushvalue(L, -2); /* now stack is key, val, key */
778       lua_insert(L, -3);    /* now stack is key, key, val */
779       lua_settable(L, msg);
780     }
781   }
782 
783   return 1;
784 }
785 
786 /**
787  * lupb_msg_index
788  *
789  * Handles:
790  *   msg.foo
791  *   msg["foo"]
792  *   msg[field_descriptor]  # (for extensions) (TODO)
793  */
lupb_msg_index(lua_State * L)794 static int lupb_msg_index(lua_State* L) {
795   upb_Message* msg = lupb_msg_check(L, 1);
796   const upb_FieldDef* f = lupb_msg_checkfield(L, 1, 2);
797 
798   if (upb_FieldDef_IsRepeated(f) || upb_FieldDef_IsSubMessage(f)) {
799     /* Wrapped type; get or create wrapper. */
800     upb_Arena* arena = upb_FieldDef_IsRepeated(f) ? lupb_Arenaget(L, 1) : NULL;
801     upb_MutableMessageValue val = upb_Message_Mutable(msg, f, arena);
802     if (!lupb_cacheget(L, val.msg)) {
803       lupb_Message_Newwrapper(L, 1, f, val);
804     }
805   } else {
806     /* Value type, just push value and return .*/
807     upb_MessageValue val = upb_Message_GetFieldByDef(msg, f);
808     lupb_pushmsgval(L, 0, upb_FieldDef_CType(f), val);
809   }
810 
811   return 1;
812 }
813 
814 /**
815  * lupb_Message_Newindex()
816  *
817  * Handles:
818  *   msg.foo = bar
819  *   msg["foo"] = bar
820  *   msg[field_descriptor] = bar  # (for extensions) (TODO)
821  */
lupb_Message_Newindex(lua_State * L)822 static int lupb_Message_Newindex(lua_State* L) {
823   upb_Message* msg = lupb_msg_check(L, 1);
824   const upb_FieldDef* f = lupb_msg_checkfield(L, 1, 2);
825   upb_MessageValue msgval;
826   bool merge_arenas = true;
827 
828   if (upb_FieldDef_IsMap(f)) {
829     lupb_map* lmap = lupb_map_check(L, 3);
830     const upb_MessageDef* entry = upb_FieldDef_MessageSubDef(f);
831     const upb_FieldDef* key_f =
832         upb_MessageDef_FindFieldByNumber(entry, kUpb_MapEntry_KeyFieldNumber);
833     const upb_FieldDef* val_f =
834         upb_MessageDef_FindFieldByNumber(entry, kUpb_MapEntry_ValueFieldNumber);
835     upb_CType key_type = upb_FieldDef_CType(key_f);
836     upb_CType value_type = upb_FieldDef_CType(val_f);
837     luaL_argcheck(L, lmap->key_type == key_type, 3, "key type mismatch");
838     luaL_argcheck(L, lmap->value_type == value_type, 3, "value type mismatch");
839     if (value_type == kUpb_CType_Message) {
840       lupb_msg_typechecksubmsg(L, 3, 1, val_f);
841     }
842     msgval.map_val = lmap->map;
843   } else if (upb_FieldDef_IsRepeated(f)) {
844     lupb_array* larr = lupb_array_check(L, 3);
845     upb_CType type = upb_FieldDef_CType(f);
846     luaL_argcheck(L, larr->type == type, 3, "array type mismatch");
847     if (type == kUpb_CType_Message) {
848       lupb_msg_typechecksubmsg(L, 3, 1, f);
849     }
850     msgval.array_val = larr->arr;
851   } else if (upb_FieldDef_IsSubMessage(f)) {
852     upb_Message* msg = lupb_msg_check(L, 3);
853     lupb_msg_typechecksubmsg(L, 3, 1, f);
854     msgval.msg_val = msg;
855   } else {
856     msgval = lupb_tomsgval(L, upb_FieldDef_CType(f), 3, 1, LUPB_COPY);
857     merge_arenas = false;
858   }
859 
860   if (merge_arenas) {
861     lupb_Arena_Fuseobjs(L, 1, 3);
862   }
863 
864   upb_Message_SetFieldByDef(msg, f, msgval, lupb_Arenaget(L, 1));
865 
866   /* Return the new value for chained assignments. */
867   lua_pushvalue(L, 3);
868   return 1;
869 }
870 
871 /**
872  * lupb_msg_tostring()
873  *
874  * Handles:
875  *   tostring(msg)
876  *   print(msg)
877  *   etc.
878  */
lupb_msg_tostring(lua_State * L)879 static int lupb_msg_tostring(lua_State* L) {
880   upb_Message* msg = lupb_msg_check(L, 1);
881   const upb_MessageDef* m;
882   char buf[1024];
883   size_t size;
884 
885   lua_getiuservalue(L, 1, LUPB_MSGDEF_INDEX);
886   m = lupb_MessageDef_check(L, -1);
887 
888   size = upb_TextEncode(msg, m, NULL, 0, buf, sizeof(buf));
889 
890   if (size < sizeof(buf)) {
891     lua_pushlstring(L, buf, size);
892   } else {
893     char* ptr = malloc(size + 1);
894     upb_TextEncode(msg, m, NULL, 0, ptr, size + 1);
895     lua_pushlstring(L, ptr, size);
896     free(ptr);
897   }
898 
899   return 1;
900 }
901 
902 static const struct luaL_Reg lupb_msg_mm[] = {
903     {"__index", lupb_msg_index},
904     {"__newindex", lupb_Message_Newindex},
905     {"__tostring", lupb_msg_tostring},
906     {NULL, NULL}};
907 
908 /* lupb_Message toplevel
909  * **********************************************************/
910 
/*
 * Reads an optional array of integer flags at stack index `narg` and returns
 * their bitwise OR.
 *
 * A missing or nil argument yields 0.  Any other value must be a table; a
 * non-table now raises a Lua argument error.  Previously it was passed
 * straight to lua_rawgeti(), which requires a table.  Each element is
 * validated with lupb_checkuint32().
 */
static int lupb_getoptions(lua_State* L, int narg) {
  int options = 0;

  if (lua_isnoneornil(L, narg)) {
    return options;
  }

  luaL_checktype(L, narg, LUA_TTABLE);

  size_t len = lua_rawlen(L, narg);
  for (size_t i = 1; i <= len; i++) {
    lua_rawgeti(L, narg, (lua_Integer)i);
    options |= lupb_checkuint32(L, -1);
    lua_pop(L, 1);
  }
  return options;
}
923 
924 /**
925  * lupb_decode()
926  *
927  * Handles:
928  *   msg = upb.decode(MessageClass, bin_string)
929  */
lupb_decode(lua_State * L)930 static int lupb_decode(lua_State* L) {
931   size_t len;
932   const upb_MessageDef* m = lupb_MessageDef_check(L, 1);
933   const char* pb = lua_tolstring(L, 2, &len);
934   const upb_MiniTable* layout = upb_MessageDef_MiniTable(m);
935   upb_Message* msg = lupb_msg_pushnew(L, 1);
936   upb_Arena* arena = lupb_Arenaget(L, -1);
937   char* buf;
938 
939   /* Copy input data to arena, message will reference it. */
940   buf = upb_Arena_Malloc(arena, len);
941   memcpy(buf, pb, len);
942 
943   upb_DecodeStatus status = upb_Decode(buf, len, msg, layout, NULL,
944                                        kUpb_DecodeOption_AliasString, arena);
945 
946   if (status != kUpb_DecodeStatus_Ok) {
947     lua_pushstring(L, "Error decoding protobuf.");
948     return lua_error(L);
949   }
950 
951   return 1;
952 }
953 
954 /**
955  * lupb_Encode()
956  *
957  * Handles:
958  *   bin_string = upb.encode(msg)
959  */
lupb_Encode(lua_State * L)960 static int lupb_Encode(lua_State* L) {
961   const upb_Message* msg = lupb_msg_check(L, 1);
962   const upb_MessageDef* m = lupb_Message_Getmsgdef(L, 1);
963   const upb_MiniTable* layout = upb_MessageDef_MiniTable(m);
964   int options = lupb_getoptions(L, 2);
965   upb_Arena* arena = lupb_Arena_pushnew(L);
966   char* buf;
967   size_t size;
968   upb_EncodeStatus status =
969       upb_Encode(msg, (const void*)layout, options, arena, &buf, &size);
970   if (status != kUpb_EncodeStatus_Ok) {
971     lua_pushstring(L, "Error encoding protobuf.");
972     return lua_error(L);
973   }
974 
975   lua_pushlstring(L, buf, size);
976 
977   return 1;
978 }
979 
980 /**
981  * lupb_jsondecode()
982  *
983  * Handles:
984  *   text_string = upb.json_decode(MessageClass, json_str,
985  * {upb.JSONDEC_IGNOREUNKNOWN})
986  */
lupb_jsondecode(lua_State * L)987 static int lupb_jsondecode(lua_State* L) {
988   size_t len;
989   const upb_MessageDef* m = lupb_MessageDef_check(L, 1);
990   const char* json = lua_tolstring(L, 2, &len);
991   int options = lupb_getoptions(L, 3);
992   upb_Message* msg;
993   upb_Arena* arena;
994   upb_Status status;
995 
996   msg = lupb_msg_pushnew(L, 1);
997   arena = lupb_Arenaget(L, -1);
998   upb_Status_Clear(&status);
999   upb_JsonDecode(json, len, msg, m, NULL, options, arena, &status);
1000   lupb_checkstatus(L, &status);
1001 
1002   return 1;
1003 }
1004 
1005 /**
1006  * lupb_jsonencode()
1007  *
1008  * Handles:
1009  *   text_string = upb.json_encode(msg, {upb.JSONENC_EMITDEFAULTS})
1010  */
lupb_jsonencode(lua_State * L)1011 static int lupb_jsonencode(lua_State* L) {
1012   upb_Message* msg = lupb_msg_check(L, 1);
1013   const upb_MessageDef* m = lupb_Message_Getmsgdef(L, 1);
1014   int options = lupb_getoptions(L, 2);
1015   char buf[1024];
1016   size_t size;
1017   upb_Status status;
1018 
1019   upb_Status_Clear(&status);
1020   size = upb_JsonEncode(msg, m, NULL, options, buf, sizeof(buf), &status);
1021   lupb_checkstatus(L, &status);
1022 
1023   if (size < sizeof(buf)) {
1024     lua_pushlstring(L, buf, size);
1025   } else {
1026     char* ptr = malloc(size + 1);
1027     upb_JsonEncode(msg, m, NULL, options, ptr, size + 1, &status);
1028     lupb_checkstatus(L, &status);
1029     lua_pushlstring(L, ptr, size);
1030     free(ptr);
1031   }
1032 
1033   return 1;
1034 }
1035 
1036 /**
1037  * lupb_textencode()
1038  *
1039  * Handles:
1040  *   text_string = upb.text_encode(msg, {upb.TXTENC_SINGLELINE})
1041  */
lupb_textencode(lua_State * L)1042 static int lupb_textencode(lua_State* L) {
1043   upb_Message* msg = lupb_msg_check(L, 1);
1044   const upb_MessageDef* m = lupb_Message_Getmsgdef(L, 1);
1045   int options = lupb_getoptions(L, 2);
1046   char buf[1024];
1047   size_t size;
1048 
1049   size = upb_TextEncode(msg, m, NULL, options, buf, sizeof(buf));
1050 
1051   if (size < sizeof(buf)) {
1052     lua_pushlstring(L, buf, size);
1053   } else {
1054     char* ptr = malloc(size + 1);
1055     upb_TextEncode(msg, m, NULL, options, ptr, size + 1);
1056     lua_pushlstring(L, ptr, size);
1057     free(ptr);
1058   }
1059 
1060   return 1;
1061 }
1062 
/* Sets t[field] = i, where t is the table on top of the Lua stack. */
static void lupb_setfieldi(lua_State* L, const char* field, int i) {
  lua_Integer value = i;
  lua_pushinteger(L, value);
  lua_setfield(L, -2, field); /* pops `value`; table is now at -2 */
}
1067 
1068 static const struct luaL_Reg lupb_msg_toplevel_m[] = {
1069     {"Array", lupb_Array_New},        {"Map", lupb_Map_New},
1070     {"decode", lupb_decode},          {"encode", lupb_Encode},
1071     {"json_decode", lupb_jsondecode}, {"json_encode", lupb_jsonencode},
1072     {"text_encode", lupb_textencode}, {NULL, NULL}};
1073 
/*
 * Registers the message API: the top-level functions (lupb_msg_toplevel_m),
 * the arena/array/map/message userdata types, and the integer option flags
 * used by the encode/decode functions.  The flags are stored as integer
 * fields of the table on top of the Lua stack (see lupb_setfieldi()).
 */
void lupb_msg_registertypes(lua_State* L) {
  /* Option flags in registration order. */
  static const struct {
    const char* name;
    int value;
  } kOptionFlags[] = {
      {"TXTENC_SINGLELINE", UPB_TXTENC_SINGLELINE},
      {"TXTENC_SKIPUNKNOWN", UPB_TXTENC_SKIPUNKNOWN},
      {"TXTENC_NOSORT", UPB_TXTENC_NOSORT},

      {"ENCODE_DETERMINISTIC", kUpb_EncodeOption_Deterministic},
      {"ENCODE_SKIPUNKNOWN", kUpb_EncodeOption_SkipUnknown},

      {"JSONENC_EMITDEFAULTS", upb_JsonEncode_EmitDefaults},
      {"JSONENC_PROTONAMES", upb_JsonEncode_UseProtoNames},

      {"JSONDEC_IGNOREUNKNOWN", upb_JsonDecode_IgnoreUnknown},
  };
  const size_t flag_count = sizeof(kOptionFlags) / sizeof(kOptionFlags[0]);

  lupb_setfuncs(L, lupb_msg_toplevel_m);

  lupb_register_type(L, LUPB_ARENA, NULL, lupb_Arena_mm);
  lupb_register_type(L, LUPB_ARRAY, NULL, lupb_array_mm);
  lupb_register_type(L, LUPB_MAP, NULL, lupb_map_mm);
  lupb_register_type(L, LUPB_MSG, NULL, lupb_msg_mm);

  for (size_t k = 0; k < flag_count; k++) {
    lupb_setfieldi(L, kOptionFlags[k].name, kOptionFlags[k].value);
  }

  lupb_cacheinit(L);
}
1096