• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *
3  * Copyright 2015 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18 
19 #include <grpc/support/port_platform.h>
20 
21 #include "src/core/tsi/fake_transport_security.h"
22 
23 #include <stdlib.h>
24 #include <string.h>
25 
26 #include <grpc/support/alloc.h>
27 #include <grpc/support/log.h>
28 
29 #include "src/core/lib/gpr/useful.h"
30 #include "src/core/lib/slice/slice_internal.h"
31 #include "src/core/tsi/transport_security_grpc.h"
32 
/* --- Constants. ---*/
/* Size in bytes of the little-endian size field that starts every fake
 * frame. */
#define TSI_FAKE_FRAME_HEADER_SIZE 4
/* Buffer size tsi_fake_frame_decode allocates for a frame that has no data
 * buffer yet. */
#define TSI_FAKE_FRAME_INITIAL_ALLOCATED_SIZE 64
/* Default frame size; presumably used by the protector factories when no
 * maximum frame size is supplied (not visible here -- TODO confirm). */
#define TSI_FAKE_DEFAULT_FRAME_SIZE 16384
/* Presumably the initial size of tsi_fake_handshaker::outgoing_bytes_buffer,
 * which fake_handshaker_next doubles whenever a message does not fit. */
#define TSI_FAKE_HANDSHAKER_OUTGOING_BUFFER_INITIAL_SIZE 256
38 
/* --- Structure definitions. ---*/

/* a frame is encoded like this:
   | size |     data    |
   where the size field value is the size of the size field plus the size of
   the data encoded in little endian on 4 bytes.  */
typedef struct {
  unsigned char* data;   /* Buffer holding the whole frame (header first). */
  size_t size;           /* Total frame size in bytes, header included. */
  size_t allocated_size; /* Capacity of data, in bytes. */
  size_t offset;         /* Cursor into data for the ongoing decode/encode. */
  int needs_draining;    /* Non-zero once the frame is complete: it must be
                            encoded (drained) before it can be decoded into
                            again. */
} tsi_fake_frame;
52 
/* Fake handshake messages in exchange order. Client messages have even values
 * and server messages odd ones; a sender advances its next_message_to_send by
 * 2 after preparing a message (see fake_handshaker_get_bytes_to_send_to_peer).
 */
typedef enum {
  TSI_FAKE_CLIENT_INIT = 0,
  TSI_FAKE_SERVER_INIT = 1,
  TSI_FAKE_CLIENT_FINISHED = 2,
  TSI_FAKE_SERVER_FINISHED = 3,
  TSI_FAKE_HANDSHAKE_MESSAGE_MAX = 4 /* Sentinel: one past the last message. */
} tsi_fake_handshake_message;
60 
/* State of a fake TSI handshaker. */
typedef struct {
  tsi_handshaker base; /* Must stay first: methods cast tsi_handshaker* back
                          to tsi_fake_handshaker*. */
  int is_client;
  tsi_fake_handshake_message next_message_to_send;
  int needs_incoming_message; /* Non-zero while waiting for the peer. */
  tsi_fake_frame incoming_frame; /* Frame being decoded from peer bytes. */
  tsi_fake_frame outgoing_frame; /* Frame being encoded for the peer. */
  unsigned char* outgoing_bytes_buffer; /* Handed out by fake_handshaker_next;
                                           grown on demand. */
  size_t outgoing_bytes_buffer_size;
  tsi_result result; /* Set to TSI_OK once the handshake is done; presumably
                        TSI_HANDSHAKE_IN_PROGRESS before that (initialized
                        outside this view). */
} tsi_fake_handshaker;
72 
/* State of a fake (non-encrypting) frame protector. */
typedef struct {
  tsi_frame_protector base;       /* Must stay first (cast from base). */
  tsi_fake_frame protect_frame;   /* Frame being filled with plaintext. */
  tsi_fake_frame unprotect_frame; /* Frame being decoded from the wire. */
  size_t max_frame_size; /* Size, header included, of every full frame that
                            fake_protector_protect emits. */
} tsi_fake_frame_protector;
79 
/* State of a fake zero-copy protector operating on slice buffers. */
typedef struct {
  tsi_zero_copy_grpc_protector base; /* Must stay first (cast from base). */
  grpc_slice_buffer header_sb;    /* Scratch buffer for stripped headers. */
  grpc_slice_buffer protected_sb; /* Received bytes not yet unprotected. */
  size_t max_frame_size;          /* Max frame size, header included. */
  size_t parsed_frame_size; /* Size of the frame at the front of protected_sb,
                               or 0 when its header has not been read yet. */
} tsi_fake_zero_copy_grpc_protector;
87 
/* --- Utils. ---*/

/* Wire representation of each tsi_fake_handshake_message, indexed by value. */
static const char* tsi_fake_handshake_message_strings[] = {
    "CLIENT_INIT", "SERVER_INIT", "CLIENT_FINISHED", "SERVER_FINISHED"};
92 
tsi_fake_handshake_message_to_string(int msg)93 static const char* tsi_fake_handshake_message_to_string(int msg) {
94   if (msg < 0 || msg >= TSI_FAKE_HANDSHAKE_MESSAGE_MAX) {
95     gpr_log(GPR_ERROR, "Invalid message %d", msg);
96     return "UNKNOWN";
97   }
98   return tsi_fake_handshake_message_strings[msg];
99 }
100 
tsi_fake_handshake_message_from_string(const char * msg_string,tsi_fake_handshake_message * msg)101 static tsi_result tsi_fake_handshake_message_from_string(
102     const char* msg_string, tsi_fake_handshake_message* msg) {
103   for (int i = 0; i < TSI_FAKE_HANDSHAKE_MESSAGE_MAX; i++) {
104     if (strncmp(msg_string, tsi_fake_handshake_message_strings[i],
105                 strlen(tsi_fake_handshake_message_strings[i])) == 0) {
106       *msg = static_cast<tsi_fake_handshake_message>(i);
107       return TSI_OK;
108     }
109   }
110   gpr_log(GPR_ERROR, "Invalid handshake message.");
111   return TSI_DATA_CORRUPTED;
112 }
113 
/* Decodes the 32-bit unsigned integer stored little endian in buf[0..3].
 * Each byte is widened to uint32_t BEFORE shifting: shifting the byte after
 * its implicit promotion to (signed) int overflows for buf[3] >= 0x80 when
 * shifted by 24, which is undefined behavior. */
static uint32_t load32_little_endian(const unsigned char* buf) {
  return static_cast<uint32_t>(buf[0]) |
         (static_cast<uint32_t>(buf[1]) << 8) |
         (static_cast<uint32_t>(buf[2]) << 16) |
         (static_cast<uint32_t>(buf[3]) << 24);
}
119 
/* Writes value into buf[0..3], least significant byte first. */
static void store32_little_endian(uint32_t value, unsigned char* buf) {
  for (int i = 0; i < 4; ++i) {
    buf[i] = static_cast<unsigned char>((value >> (8 * i)) & 0xFF);
  }
}
126 
read_frame_size(const grpc_slice_buffer * sb)127 static uint32_t read_frame_size(const grpc_slice_buffer* sb) {
128   GPR_ASSERT(sb != nullptr && sb->length >= TSI_FAKE_FRAME_HEADER_SIZE);
129   uint8_t frame_size_buffer[TSI_FAKE_FRAME_HEADER_SIZE];
130   uint8_t* buf = frame_size_buffer;
131   /* Copies the first 4 bytes to a temporary buffer.  */
132   size_t remaining = TSI_FAKE_FRAME_HEADER_SIZE;
133   for (size_t i = 0; i < sb->count; i++) {
134     size_t slice_length = GRPC_SLICE_LENGTH(sb->slices[i]);
135     if (remaining <= slice_length) {
136       memcpy(buf, GRPC_SLICE_START_PTR(sb->slices[i]), remaining);
137       remaining = 0;
138       break;
139     } else {
140       memcpy(buf, GRPC_SLICE_START_PTR(sb->slices[i]), slice_length);
141       buf += slice_length;
142       remaining -= slice_length;
143     }
144   }
145   GPR_ASSERT(remaining == 0);
146   return load32_little_endian(frame_size_buffer);
147 }
148 
tsi_fake_frame_reset(tsi_fake_frame * frame,int needs_draining)149 static void tsi_fake_frame_reset(tsi_fake_frame* frame, int needs_draining) {
150   frame->offset = 0;
151   frame->needs_draining = needs_draining;
152   if (!needs_draining) frame->size = 0;
153 }
154 
155 /* Checks if the frame's allocated size is at least frame->size, and reallocs
156  * more memory if necessary. */
tsi_fake_frame_ensure_size(tsi_fake_frame * frame)157 static void tsi_fake_frame_ensure_size(tsi_fake_frame* frame) {
158   if (frame->data == nullptr) {
159     frame->allocated_size = frame->size;
160     frame->data =
161         static_cast<unsigned char*>(gpr_malloc(frame->allocated_size));
162   } else if (frame->size > frame->allocated_size) {
163     unsigned char* new_data =
164         static_cast<unsigned char*>(gpr_realloc(frame->data, frame->size));
165     frame->data = new_data;
166     frame->allocated_size = frame->size;
167   }
168 }
169 
170 /* Decodes the serialized fake frame contained in incoming_bytes, and fills
171  * frame with the contents of the decoded frame.
172  * This method should not be called if frame->needs_framing is not 0.  */
tsi_fake_frame_decode(const unsigned char * incoming_bytes,size_t * incoming_bytes_size,tsi_fake_frame * frame)173 static tsi_result tsi_fake_frame_decode(const unsigned char* incoming_bytes,
174                                         size_t* incoming_bytes_size,
175                                         tsi_fake_frame* frame) {
176   size_t available_size = *incoming_bytes_size;
177   size_t to_read_size = 0;
178   const unsigned char* bytes_cursor = incoming_bytes;
179 
180   if (frame->needs_draining) return TSI_INTERNAL_ERROR;
181   if (frame->data == nullptr) {
182     frame->allocated_size = TSI_FAKE_FRAME_INITIAL_ALLOCATED_SIZE;
183     frame->data =
184         static_cast<unsigned char*>(gpr_malloc(frame->allocated_size));
185   }
186 
187   if (frame->offset < TSI_FAKE_FRAME_HEADER_SIZE) {
188     to_read_size = TSI_FAKE_FRAME_HEADER_SIZE - frame->offset;
189     if (to_read_size > available_size) {
190       /* Just fill what we can and exit. */
191       memcpy(frame->data + frame->offset, bytes_cursor, available_size);
192       bytes_cursor += available_size;
193       frame->offset += available_size;
194       *incoming_bytes_size = static_cast<size_t>(bytes_cursor - incoming_bytes);
195       return TSI_INCOMPLETE_DATA;
196     }
197     memcpy(frame->data + frame->offset, bytes_cursor, to_read_size);
198     bytes_cursor += to_read_size;
199     frame->offset += to_read_size;
200     available_size -= to_read_size;
201     frame->size = load32_little_endian(frame->data);
202     tsi_fake_frame_ensure_size(frame);
203   }
204 
205   to_read_size = frame->size - frame->offset;
206   if (to_read_size > available_size) {
207     memcpy(frame->data + frame->offset, bytes_cursor, available_size);
208     frame->offset += available_size;
209     bytes_cursor += available_size;
210     *incoming_bytes_size = static_cast<size_t>(bytes_cursor - incoming_bytes);
211     return TSI_INCOMPLETE_DATA;
212   }
213   memcpy(frame->data + frame->offset, bytes_cursor, to_read_size);
214   bytes_cursor += to_read_size;
215   *incoming_bytes_size = static_cast<size_t>(bytes_cursor - incoming_bytes);
216   tsi_fake_frame_reset(frame, 1 /* needs_draining */);
217   return TSI_OK;
218 }
219 
220 /* Encodes a fake frame into its wire format and places the result in
221  * outgoing_bytes. outgoing_bytes_size indicates the size of the encoded frame.
222  * This method should not be called if frame->needs_framing is 0.  */
tsi_fake_frame_encode(unsigned char * outgoing_bytes,size_t * outgoing_bytes_size,tsi_fake_frame * frame)223 static tsi_result tsi_fake_frame_encode(unsigned char* outgoing_bytes,
224                                         size_t* outgoing_bytes_size,
225                                         tsi_fake_frame* frame) {
226   size_t to_write_size = frame->size - frame->offset;
227   if (!frame->needs_draining) return TSI_INTERNAL_ERROR;
228   if (*outgoing_bytes_size < to_write_size) {
229     memcpy(outgoing_bytes, frame->data + frame->offset, *outgoing_bytes_size);
230     frame->offset += *outgoing_bytes_size;
231     return TSI_INCOMPLETE_DATA;
232   }
233   memcpy(outgoing_bytes, frame->data + frame->offset, to_write_size);
234   *outgoing_bytes_size = to_write_size;
235   tsi_fake_frame_reset(frame, 0 /* needs_draining */);
236   return TSI_OK;
237 }
238 
239 /* Sets the payload of a fake frame to contain the given data blob, where
240  * data_size indicates the size of data. */
tsi_fake_frame_set_data(unsigned char * data,size_t data_size,tsi_fake_frame * frame)241 static tsi_result tsi_fake_frame_set_data(unsigned char* data, size_t data_size,
242                                           tsi_fake_frame* frame) {
243   frame->offset = 0;
244   frame->size = data_size + TSI_FAKE_FRAME_HEADER_SIZE;
245   tsi_fake_frame_ensure_size(frame);
246   store32_little_endian(static_cast<uint32_t>(frame->size), frame->data);
247   memcpy(frame->data + TSI_FAKE_FRAME_HEADER_SIZE, data, data_size);
248   tsi_fake_frame_reset(frame, 1 /* needs draining */);
249   return TSI_OK;
250 }
251 
252 /* Destroys the contents of a fake frame. */
tsi_fake_frame_destruct(tsi_fake_frame * frame)253 static void tsi_fake_frame_destruct(tsi_fake_frame* frame) {
254   if (frame->data != nullptr) gpr_free(frame->data);
255 }
256 
/* --- tsi_frame_protector methods implementation. ---*/

/* "Protects" unprotected_bytes by wrapping them in fake frames: a 4-byte
 * little-endian size field followed by the plaintext. Every frame emitted here
 * is max_frame_size bytes long (header included); a frame is only written out
 * once full (fake_protector_protect_flush emits short frames).
 *
 * On input *unprotected_bytes_size and *protected_output_frames_size are the
 * sizes of the two buffers; on output they are the number of bytes consumed
 * and written. A full output buffer or a frame that is not full yet is not an
 * error: TSI_OK is returned and the state is kept in impl->protect_frame. */
static tsi_result fake_protector_protect(tsi_frame_protector* self,
                                         const unsigned char* unprotected_bytes,
                                         size_t* unprotected_bytes_size,
                                         unsigned char* protected_output_frames,
                                         size_t* protected_output_frames_size) {
  tsi_result result = TSI_OK;
  tsi_fake_frame_protector* impl =
      reinterpret_cast<tsi_fake_frame_protector*>(self);
  unsigned char frame_header[TSI_FAKE_FRAME_HEADER_SIZE];
  tsi_fake_frame* frame = &impl->protect_frame;
  size_t saved_output_size = *protected_output_frames_size;
  size_t drained_size = 0;
  /* Alias of the output-size parameter; from here on it counts bytes
   * written. */
  size_t* num_bytes_written = protected_output_frames_size;
  *num_bytes_written = 0;

  /* Try to drain first. */
  /* A frame completed by an earlier call must be written out before new
   * plaintext can be accepted. */
  if (frame->needs_draining) {
    drained_size = saved_output_size - *num_bytes_written;
    result =
        tsi_fake_frame_encode(protected_output_frames, &drained_size, frame);
    *num_bytes_written += drained_size;
    protected_output_frames += drained_size;
    if (result != TSI_OK) {
      if (result == TSI_INCOMPLETE_DATA) {
        /* Output is full: report that no input was consumed. */
        *unprotected_bytes_size = 0;
        result = TSI_OK;
      }
      return result;
    }
  }

  /* Now process the unprotected_bytes. */
  if (frame->needs_draining) return TSI_INTERNAL_ERROR;
  if (frame->size == 0) {
    /* New frame, create a header. */
    /* The header announces a full max_frame_size frame and is fed through
     * tsi_fake_frame_decode, so the frame then accumulates plaintext until it
     * reaches that size. Decoding the header alone must be incomplete. */
    size_t written_in_frame_size = 0;
    store32_little_endian(static_cast<uint32_t>(impl->max_frame_size),
                          frame_header);
    written_in_frame_size = TSI_FAKE_FRAME_HEADER_SIZE;
    result = tsi_fake_frame_decode(frame_header, &written_in_frame_size, frame);
    if (result != TSI_INCOMPLETE_DATA) {
      gpr_log(GPR_ERROR, "tsi_fake_frame_decode returned %s",
              tsi_result_to_string(result));
      return result;
    }
  }
  /* Append plaintext to the frame; decode sets *unprotected_bytes_size to the
   * number of bytes consumed. A frame that is not full yet is not an error. */
  result =
      tsi_fake_frame_decode(unprotected_bytes, unprotected_bytes_size, frame);
  if (result != TSI_OK) {
    if (result == TSI_INCOMPLETE_DATA) result = TSI_OK;
    return result;
  }

  /* Try to drain again. */
  /* The frame just became full: write as much of it as fits. */
  if (!frame->needs_draining) return TSI_INTERNAL_ERROR;
  if (frame->offset != 0) return TSI_INTERNAL_ERROR;
  drained_size = saved_output_size - *num_bytes_written;
  result = tsi_fake_frame_encode(protected_output_frames, &drained_size, frame);
  *num_bytes_written += drained_size;
  if (result == TSI_INCOMPLETE_DATA) result = TSI_OK;
  return result;
}
321 
fake_protector_protect_flush(tsi_frame_protector * self,unsigned char * protected_output_frames,size_t * protected_output_frames_size,size_t * still_pending_size)322 static tsi_result fake_protector_protect_flush(
323     tsi_frame_protector* self, unsigned char* protected_output_frames,
324     size_t* protected_output_frames_size, size_t* still_pending_size) {
325   tsi_result result = TSI_OK;
326   tsi_fake_frame_protector* impl =
327       reinterpret_cast<tsi_fake_frame_protector*>(self);
328   tsi_fake_frame* frame = &impl->protect_frame;
329   if (!frame->needs_draining) {
330     /* Create a short frame. */
331     frame->size = frame->offset;
332     frame->offset = 0;
333     frame->needs_draining = 1;
334     store32_little_endian(static_cast<uint32_t>(frame->size),
335                           frame->data); /* Overwrite header. */
336   }
337   result = tsi_fake_frame_encode(protected_output_frames,
338                                  protected_output_frames_size, frame);
339   if (result == TSI_INCOMPLETE_DATA) result = TSI_OK;
340   *still_pending_size = frame->size - frame->offset;
341   return result;
342 }
343 
/* Removes the fake framing from protected_frames_bytes: each frame's size
 * field is skipped and its payload is copied to unprotected_bytes.
 *
 * On input *protected_frames_bytes_size and *unprotected_bytes_size are the
 * sizes of the two buffers; on output they are the number of bytes consumed
 * and written. A partial frame or a full output buffer is not an error: TSI_OK
 * is returned and the state is kept in impl->unprotect_frame. Errors from
 * tsi_fake_frame_decode/encode are returned as is. */
static tsi_result fake_protector_unprotect(
    tsi_frame_protector* self, const unsigned char* protected_frames_bytes,
    size_t* protected_frames_bytes_size, unsigned char* unprotected_bytes,
    size_t* unprotected_bytes_size) {
  tsi_result result = TSI_OK;
  tsi_fake_frame_protector* impl =
      reinterpret_cast<tsi_fake_frame_protector*>(self);
  tsi_fake_frame* frame = &impl->unprotect_frame;
  size_t saved_output_size = *unprotected_bytes_size;
  size_t drained_size = 0;
  /* Alias of the output-size parameter; from here on it counts bytes
   * written. */
  size_t* num_bytes_written = unprotected_bytes_size;
  *num_bytes_written = 0;

  /* Try to drain first. */
  /* Payload left over from a frame decoded by an earlier call goes out
   * before any new input is read. */
  if (frame->needs_draining) {
    /* Go past the header if needed. */
    /* offset == 0 means draining has not started, so the header was not
     * skipped yet. */
    if (frame->offset == 0) frame->offset = TSI_FAKE_FRAME_HEADER_SIZE;
    drained_size = saved_output_size - *num_bytes_written;
    result = tsi_fake_frame_encode(unprotected_bytes, &drained_size, frame);
    unprotected_bytes += drained_size;
    *num_bytes_written += drained_size;
    if (result != TSI_OK) {
      if (result == TSI_INCOMPLETE_DATA) {
        /* Output is full: report that no input was consumed. */
        *protected_frames_bytes_size = 0;
        result = TSI_OK;
      }
      return result;
    }
  }

  /* Now process the protected_bytes. */
  if (frame->needs_draining) return TSI_INTERNAL_ERROR;
  result = tsi_fake_frame_decode(protected_frames_bytes,
                                 protected_frames_bytes_size, frame);
  if (result != TSI_OK) {
    if (result == TSI_INCOMPLETE_DATA) result = TSI_OK;
    return result;
  }

  /* Try to drain again. */
  if (!frame->needs_draining) return TSI_INTERNAL_ERROR;
  if (frame->offset != 0) return TSI_INTERNAL_ERROR;
  frame->offset = TSI_FAKE_FRAME_HEADER_SIZE; /* Go past the header. */
  drained_size = saved_output_size - *num_bytes_written;
  result = tsi_fake_frame_encode(unprotected_bytes, &drained_size, frame);
  *num_bytes_written += drained_size;
  if (result == TSI_INCOMPLETE_DATA) result = TSI_OK;
  return result;
}
393 
fake_protector_destroy(tsi_frame_protector * self)394 static void fake_protector_destroy(tsi_frame_protector* self) {
395   tsi_fake_frame_protector* impl =
396       reinterpret_cast<tsi_fake_frame_protector*>(self);
397   tsi_fake_frame_destruct(&impl->protect_frame);
398   tsi_fake_frame_destruct(&impl->unprotect_frame);
399   gpr_free(self);
400 }
401 
/* Method table of the fake frame protector. The initializer is positional, so
 * the entries must follow the member order of tsi_frame_protector_vtable. */
static const tsi_frame_protector_vtable frame_protector_vtable = {
    fake_protector_protect,
    fake_protector_protect_flush,
    fake_protector_unprotect,
    fake_protector_destroy,
};
408 
409 /* --- tsi_zero_copy_grpc_protector methods implementation. ---*/
410 
fake_zero_copy_grpc_protector_protect(tsi_zero_copy_grpc_protector * self,grpc_slice_buffer * unprotected_slices,grpc_slice_buffer * protected_slices)411 static tsi_result fake_zero_copy_grpc_protector_protect(
412     tsi_zero_copy_grpc_protector* self, grpc_slice_buffer* unprotected_slices,
413     grpc_slice_buffer* protected_slices) {
414   if (self == nullptr || unprotected_slices == nullptr ||
415       protected_slices == nullptr) {
416     return TSI_INVALID_ARGUMENT;
417   }
418   tsi_fake_zero_copy_grpc_protector* impl =
419       reinterpret_cast<tsi_fake_zero_copy_grpc_protector*>(self);
420   /* Protects each frame. */
421   while (unprotected_slices->length > 0) {
422     size_t frame_length =
423         GPR_MIN(impl->max_frame_size,
424                 unprotected_slices->length + TSI_FAKE_FRAME_HEADER_SIZE);
425     grpc_slice slice = GRPC_SLICE_MALLOC(TSI_FAKE_FRAME_HEADER_SIZE);
426     store32_little_endian(static_cast<uint32_t>(frame_length),
427                           GRPC_SLICE_START_PTR(slice));
428     grpc_slice_buffer_add(protected_slices, slice);
429     size_t data_length = frame_length - TSI_FAKE_FRAME_HEADER_SIZE;
430     grpc_slice_buffer_move_first(unprotected_slices, data_length,
431                                  protected_slices);
432   }
433   return TSI_OK;
434 }
435 
fake_zero_copy_grpc_protector_unprotect(tsi_zero_copy_grpc_protector * self,grpc_slice_buffer * protected_slices,grpc_slice_buffer * unprotected_slices)436 static tsi_result fake_zero_copy_grpc_protector_unprotect(
437     tsi_zero_copy_grpc_protector* self, grpc_slice_buffer* protected_slices,
438     grpc_slice_buffer* unprotected_slices) {
439   if (self == nullptr || unprotected_slices == nullptr ||
440       protected_slices == nullptr) {
441     return TSI_INVALID_ARGUMENT;
442   }
443   tsi_fake_zero_copy_grpc_protector* impl =
444       reinterpret_cast<tsi_fake_zero_copy_grpc_protector*>(self);
445   grpc_slice_buffer_move_into(protected_slices, &impl->protected_sb);
446   /* Unprotect each frame, if we get a full frame. */
447   while (impl->protected_sb.length >= TSI_FAKE_FRAME_HEADER_SIZE) {
448     if (impl->parsed_frame_size == 0) {
449       impl->parsed_frame_size = read_frame_size(&impl->protected_sb);
450       if (impl->parsed_frame_size <= 4) {
451         gpr_log(GPR_ERROR, "Invalid frame size.");
452         return TSI_DATA_CORRUPTED;
453       }
454     }
455     /* If we do not have a full frame, return with OK status. */
456     if (impl->protected_sb.length < impl->parsed_frame_size) break;
457     /* Strips header bytes. */
458     grpc_slice_buffer_move_first(&impl->protected_sb,
459                                  TSI_FAKE_FRAME_HEADER_SIZE, &impl->header_sb);
460     /* Moves data to unprotected slices. */
461     grpc_slice_buffer_move_first(
462         &impl->protected_sb,
463         impl->parsed_frame_size - TSI_FAKE_FRAME_HEADER_SIZE,
464         unprotected_slices);
465     impl->parsed_frame_size = 0;
466     grpc_slice_buffer_reset_and_unref_internal(&impl->header_sb);
467   }
468   return TSI_OK;
469 }
470 
fake_zero_copy_grpc_protector_destroy(tsi_zero_copy_grpc_protector * self)471 static void fake_zero_copy_grpc_protector_destroy(
472     tsi_zero_copy_grpc_protector* self) {
473   if (self == nullptr) return;
474   tsi_fake_zero_copy_grpc_protector* impl =
475       reinterpret_cast<tsi_fake_zero_copy_grpc_protector*>(self);
476   grpc_slice_buffer_destroy_internal(&impl->header_sb);
477   grpc_slice_buffer_destroy_internal(&impl->protected_sb);
478   gpr_free(impl);
479 }
480 
/* Method table of the fake zero-copy protector. The initializer is positional,
 * so the entries must follow the member order of
 * tsi_zero_copy_grpc_protector_vtable. */
static const tsi_zero_copy_grpc_protector_vtable
    zero_copy_grpc_protector_vtable = {
        fake_zero_copy_grpc_protector_protect,
        fake_zero_copy_grpc_protector_unprotect,
        fake_zero_copy_grpc_protector_destroy,
};
487 
/* --- tsi_handshaker_result methods implementation. ---*/

/* Result of a completed fake handshake. */
typedef struct {
  tsi_handshaker_result base; /* Must stay first (cast from base). */
  unsigned char* unused_bytes; /* Owned copy of received bytes not consumed by
                                  the handshake; null when there are none. */
  size_t unused_bytes_size;
} fake_handshaker_result;
495 
fake_handshaker_result_extract_peer(const tsi_handshaker_result * self,tsi_peer * peer)496 static tsi_result fake_handshaker_result_extract_peer(
497     const tsi_handshaker_result* self, tsi_peer* peer) {
498   /* Construct a tsi_peer with 1 property: certificate type. */
499   tsi_result result = tsi_construct_peer(1, peer);
500   if (result != TSI_OK) return result;
501   result = tsi_construct_string_peer_property_from_cstring(
502       TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_FAKE_CERTIFICATE_TYPE,
503       &peer->properties[0]);
504   if (result != TSI_OK) tsi_peer_destruct(peer);
505   return result;
506 }
507 
fake_handshaker_result_create_zero_copy_grpc_protector(const tsi_handshaker_result * self,size_t * max_output_protected_frame_size,tsi_zero_copy_grpc_protector ** protector)508 static tsi_result fake_handshaker_result_create_zero_copy_grpc_protector(
509     const tsi_handshaker_result* self, size_t* max_output_protected_frame_size,
510     tsi_zero_copy_grpc_protector** protector) {
511   *protector =
512       tsi_create_fake_zero_copy_grpc_protector(max_output_protected_frame_size);
513   return TSI_OK;
514 }
515 
fake_handshaker_result_create_frame_protector(const tsi_handshaker_result * self,size_t * max_output_protected_frame_size,tsi_frame_protector ** protector)516 static tsi_result fake_handshaker_result_create_frame_protector(
517     const tsi_handshaker_result* self, size_t* max_output_protected_frame_size,
518     tsi_frame_protector** protector) {
519   *protector = tsi_create_fake_frame_protector(max_output_protected_frame_size);
520   return TSI_OK;
521 }
522 
fake_handshaker_result_get_unused_bytes(const tsi_handshaker_result * self,const unsigned char ** bytes,size_t * bytes_size)523 static tsi_result fake_handshaker_result_get_unused_bytes(
524     const tsi_handshaker_result* self, const unsigned char** bytes,
525     size_t* bytes_size) {
526   fake_handshaker_result* result = (fake_handshaker_result*)self;
527   *bytes_size = result->unused_bytes_size;
528   *bytes = result->unused_bytes;
529   return TSI_OK;
530 }
531 
fake_handshaker_result_destroy(tsi_handshaker_result * self)532 static void fake_handshaker_result_destroy(tsi_handshaker_result* self) {
533   fake_handshaker_result* result =
534       reinterpret_cast<fake_handshaker_result*>(self);
535   gpr_free(result->unused_bytes);
536   gpr_free(self);
537 }
538 
/* Method table of the fake handshaker result. The initializer is positional,
 * so the entries must follow the member order of
 * tsi_handshaker_result_vtable. */
static const tsi_handshaker_result_vtable handshaker_result_vtable = {
    fake_handshaker_result_extract_peer,
    fake_handshaker_result_create_zero_copy_grpc_protector,
    fake_handshaker_result_create_frame_protector,
    fake_handshaker_result_get_unused_bytes,
    fake_handshaker_result_destroy,
};
546 
fake_handshaker_result_create(const unsigned char * unused_bytes,size_t unused_bytes_size,tsi_handshaker_result ** handshaker_result)547 static tsi_result fake_handshaker_result_create(
548     const unsigned char* unused_bytes, size_t unused_bytes_size,
549     tsi_handshaker_result** handshaker_result) {
550   if ((unused_bytes_size > 0 && unused_bytes == nullptr) ||
551       handshaker_result == nullptr) {
552     return TSI_INVALID_ARGUMENT;
553   }
554   fake_handshaker_result* result =
555       static_cast<fake_handshaker_result*>(gpr_zalloc(sizeof(*result)));
556   result->base.vtable = &handshaker_result_vtable;
557   if (unused_bytes_size > 0) {
558     result->unused_bytes =
559         static_cast<unsigned char*>(gpr_malloc(unused_bytes_size));
560     memcpy(result->unused_bytes, unused_bytes, unused_bytes_size);
561   }
562   result->unused_bytes_size = unused_bytes_size;
563   *handshaker_result = &result->base;
564   return TSI_OK;
565 }
566 
/* --- tsi_handshaker methods implementation. ---*/

/* Writes up to *bytes_size bytes of this side's next handshake message into
 * bytes.
 *
 * When no frame is pending, next_message_to_send is framed into
 * impl->outgoing_frame and next_message_to_send advances by 2 (this side's
 * following message), clamped to TSI_FAKE_HANDSHAKE_MESSAGE_MAX. If the frame
 * does not fit, TSI_INCOMPLETE_DATA is returned with the buffer filled, so the
 * caller can call again with more room. Once a frame is completely written,
 * the server is done after its last message; otherwise the handshaker waits
 * for the peer.
 *
 * Sets *bytes_size to 0 and returns TSI_OK when it is the peer's turn or the
 * handshake is already done. */
static tsi_result fake_handshaker_get_bytes_to_send_to_peer(
    tsi_handshaker* self, unsigned char* bytes, size_t* bytes_size) {
  tsi_fake_handshaker* impl = reinterpret_cast<tsi_fake_handshaker*>(self);
  tsi_result result = TSI_OK;
  if (impl->needs_incoming_message || impl->result == TSI_OK) {
    *bytes_size = 0;
    return TSI_OK;
  }
  if (!impl->outgoing_frame.needs_draining) {
    tsi_fake_handshake_message next_message_to_send =
        static_cast<tsi_fake_handshake_message>(impl->next_message_to_send + 2);
    const char* msg_string =
        tsi_fake_handshake_message_to_string(impl->next_message_to_send);
    /* NOTE(review): the cast drops const only because tsi_fake_frame_set_data
     * takes a non-const pointer; that function only reads from it. */
    result = tsi_fake_frame_set_data((unsigned char*)msg_string,
                                     strlen(msg_string), &impl->outgoing_frame);
    if (result != TSI_OK) return result;
    if (next_message_to_send > TSI_FAKE_HANDSHAKE_MESSAGE_MAX) {
      next_message_to_send = TSI_FAKE_HANDSHAKE_MESSAGE_MAX;
    }
    if (tsi_tracing_enabled.enabled()) {
      gpr_log(GPR_INFO, "%s prepared %s.",
              impl->is_client ? "Client" : "Server",
              tsi_fake_handshake_message_to_string(impl->next_message_to_send));
    }
    impl->next_message_to_send = next_message_to_send;
  }
  result = tsi_fake_frame_encode(bytes, bytes_size, &impl->outgoing_frame);
  if (result != TSI_OK) return result;
  if (!impl->is_client &&
      impl->next_message_to_send == TSI_FAKE_HANDSHAKE_MESSAGE_MAX) {
    /* We're done. */
    /* The server sends the last message, so it finishes without waiting. */
    if (tsi_tracing_enabled.enabled()) {
      gpr_log(GPR_INFO, "Server is done.");
    }
    impl->result = TSI_OK;
  } else {
    impl->needs_incoming_message = 1;
  }
  return TSI_OK;
}
609 
/* Feeds bytes from the peer into impl->incoming_frame.
 *
 * While the frame is incomplete (or on a decode error) the result of
 * tsi_fake_frame_decode is returned, with *bytes_size holding the bytes
 * consumed. For a complete frame, the message is parsed and compared with the
 * one this side expects (next_message_to_send - 1); a mismatch is only
 * logged. The handshake is marked done when this side has no message left to
 * send.
 *
 * Sets *bytes_size to 0 and returns TSI_OK when no message is expected or the
 * handshake is already done. */
static tsi_result fake_handshaker_process_bytes_from_peer(
    tsi_handshaker* self, const unsigned char* bytes, size_t* bytes_size) {
  tsi_result result = TSI_OK;
  tsi_fake_handshaker* impl = reinterpret_cast<tsi_fake_handshaker*>(self);
  /* Messages alternate between the sides, so the peer's message precedes
   * this side's next one. */
  tsi_fake_handshake_message expected_msg =
      static_cast<tsi_fake_handshake_message>(impl->next_message_to_send - 1);
  tsi_fake_handshake_message received_msg;

  if (!impl->needs_incoming_message || impl->result == TSI_OK) {
    *bytes_size = 0;
    return TSI_OK;
  }
  result = tsi_fake_frame_decode(bytes, bytes_size, &impl->incoming_frame);
  if (result != TSI_OK) return result;

  /* We now have a complete frame. */
  /* NOTE(review): the payload is not NUL-terminated and the lookup compares a
   * prefix of each candidate's length, so a payload shorter than a candidate
   * may be compared against stale bytes of the frame buffer. */
  result = tsi_fake_handshake_message_from_string(
      reinterpret_cast<const char*>(impl->incoming_frame.data) +
          TSI_FAKE_FRAME_HEADER_SIZE,
      &received_msg);
  if (result != TSI_OK) {
    impl->result = result;
    return result;
  }
  if (received_msg != expected_msg) {
    gpr_log(GPR_ERROR, "Invalid received message (%s instead of %s)",
            tsi_fake_handshake_message_to_string(received_msg),
            tsi_fake_handshake_message_to_string(expected_msg));
  }
  if (tsi_tracing_enabled.enabled()) {
    gpr_log(GPR_INFO, "%s received %s.", impl->is_client ? "Client" : "Server",
            tsi_fake_handshake_message_to_string(received_msg));
  }
  tsi_fake_frame_reset(&impl->incoming_frame, 0 /* needs_draining */);
  impl->needs_incoming_message = 0;
  if (impl->next_message_to_send == TSI_FAKE_HANDSHAKE_MESSAGE_MAX) {
    /* We're done. */
    if (tsi_tracing_enabled.enabled()) {
      gpr_log(GPR_INFO, "%s is done.", impl->is_client ? "Client" : "Server");
    }
    impl->result = TSI_OK;
  }
  return TSI_OK;
}
654 
fake_handshaker_get_result(tsi_handshaker * self)655 static tsi_result fake_handshaker_get_result(tsi_handshaker* self) {
656   tsi_fake_handshaker* impl = reinterpret_cast<tsi_fake_handshaker*>(self);
657   return impl->result;
658 }
659 
fake_handshaker_destroy(tsi_handshaker * self)660 static void fake_handshaker_destroy(tsi_handshaker* self) {
661   tsi_fake_handshaker* impl = reinterpret_cast<tsi_fake_handshaker*>(self);
662   tsi_fake_frame_destruct(&impl->incoming_frame);
663   tsi_fake_frame_destruct(&impl->outgoing_frame);
664   gpr_free(impl->outgoing_bytes_buffer);
665   gpr_free(self);
666 }
667 
/* Performs one step of the fake handshake. cb and user_data are not used; the
 * outcome is always returned directly.
 *
 * received_bytes are passed to fake_handshaker_process_bytes_from_peer, and
 * any non-TSI_OK result from it (including TSI_INCOMPLETE_DATA for a partial
 * frame) is returned as is. The pending outgoing message is then written into
 * handshaker->outgoing_bytes_buffer, which is doubled until the message fits;
 * *bytes_to_send points into that handshaker-owned buffer. When the handshake
 * is complete, a handshaker result holding the received bytes not consumed by
 * the handshake is created; otherwise *handshaker_result is set to null. */
static tsi_result fake_handshaker_next(
    tsi_handshaker* self, const unsigned char* received_bytes,
    size_t received_bytes_size, const unsigned char** bytes_to_send,
    size_t* bytes_to_send_size, tsi_handshaker_result** handshaker_result,
    tsi_handshaker_on_next_done_cb cb, void* user_data) {
  /* Sanity check the arguments. */
  if ((received_bytes_size > 0 && received_bytes == nullptr) ||
      bytes_to_send == nullptr || bytes_to_send_size == nullptr ||
      handshaker_result == nullptr) {
    return TSI_INVALID_ARGUMENT;
  }
  tsi_fake_handshaker* handshaker =
      reinterpret_cast<tsi_fake_handshaker*>(self);
  tsi_result result = TSI_OK;

  /* Decode and process a handshake frame from the peer. */
  /* On success consumed_bytes_size is the number of received bytes used. */
  size_t consumed_bytes_size = received_bytes_size;
  if (received_bytes_size > 0) {
    result = fake_handshaker_process_bytes_from_peer(self, received_bytes,
                                                     &consumed_bytes_size);
    if (result != TSI_OK) return result;
  }

  /* Create a handshake message to send to the peer and encode it as a fake
   * frame. */
  /* TSI_INCOMPLETE_DATA means the buffer filled up: double it and continue
   * writing after the bytes already produced. */
  size_t offset = 0;
  do {
    size_t sent_bytes_size = handshaker->outgoing_bytes_buffer_size - offset;
    result = fake_handshaker_get_bytes_to_send_to_peer(
        self, handshaker->outgoing_bytes_buffer + offset, &sent_bytes_size);
    offset += sent_bytes_size;
    if (result == TSI_INCOMPLETE_DATA) {
      handshaker->outgoing_bytes_buffer_size *= 2;
      handshaker->outgoing_bytes_buffer = static_cast<unsigned char*>(
          gpr_realloc(handshaker->outgoing_bytes_buffer,
                      handshaker->outgoing_bytes_buffer_size));
    }
  } while (result == TSI_INCOMPLETE_DATA);
  if (result != TSI_OK) return result;
  *bytes_to_send = handshaker->outgoing_bytes_buffer;
  *bytes_to_send_size = offset;

  /* Check if the handshake was completed. */
  if (fake_handshaker_get_result(self) == TSI_HANDSHAKE_IN_PROGRESS) {
    *handshaker_result = nullptr;
  } else {
    /* Calculate the unused bytes. */
    const unsigned char* unused_bytes = nullptr;
    size_t unused_bytes_size = received_bytes_size - consumed_bytes_size;
    if (unused_bytes_size > 0) {
      unused_bytes = received_bytes + consumed_bytes_size;
    }

    /* Create a handshaker_result containing the unused bytes. */
    result = fake_handshaker_result_create(unused_bytes, unused_bytes_size,
                                           handshaker_result);
    if (result == TSI_OK) {
      /* Indicate that the handshake has completed and that a handshaker_result
       * has been created. */
      self->handshaker_result_created = true;
    }
  }
  return result;
}
732 
733 static const tsi_handshaker_vtable handshaker_vtable = {
734     nullptr, /* get_bytes_to_send_to_peer -- deprecated */
735     nullptr, /* process_bytes_from_peer   -- deprecated */
736     nullptr, /* get_result                -- deprecated */
737     nullptr, /* extract_peer              -- deprecated */
738     nullptr, /* create_frame_protector    -- deprecated */
739     fake_handshaker_destroy,
740     fake_handshaker_next,
741     nullptr, /* shutdown */
742 };
743 
tsi_create_fake_handshaker(int is_client)744 tsi_handshaker* tsi_create_fake_handshaker(int is_client) {
745   tsi_fake_handshaker* impl =
746       static_cast<tsi_fake_handshaker*>(gpr_zalloc(sizeof(*impl)));
747   impl->base.vtable = &handshaker_vtable;
748   impl->is_client = is_client;
749   impl->result = TSI_HANDSHAKE_IN_PROGRESS;
750   impl->outgoing_bytes_buffer_size =
751       TSI_FAKE_HANDSHAKER_OUTGOING_BUFFER_INITIAL_SIZE;
752   impl->outgoing_bytes_buffer =
753       static_cast<unsigned char*>(gpr_malloc(impl->outgoing_bytes_buffer_size));
754   if (is_client) {
755     impl->needs_incoming_message = 0;
756     impl->next_message_to_send = TSI_FAKE_CLIENT_INIT;
757   } else {
758     impl->needs_incoming_message = 1;
759     impl->next_message_to_send = TSI_FAKE_SERVER_INIT;
760   }
761   return &impl->base;
762 }
763 
tsi_create_fake_frame_protector(size_t * max_protected_frame_size)764 tsi_frame_protector* tsi_create_fake_frame_protector(
765     size_t* max_protected_frame_size) {
766   tsi_fake_frame_protector* impl =
767       static_cast<tsi_fake_frame_protector*>(gpr_zalloc(sizeof(*impl)));
768   impl->max_frame_size = (max_protected_frame_size == nullptr)
769                              ? TSI_FAKE_DEFAULT_FRAME_SIZE
770                              : *max_protected_frame_size;
771   impl->base.vtable = &frame_protector_vtable;
772   return &impl->base;
773 }
774 
tsi_create_fake_zero_copy_grpc_protector(size_t * max_protected_frame_size)775 tsi_zero_copy_grpc_protector* tsi_create_fake_zero_copy_grpc_protector(
776     size_t* max_protected_frame_size) {
777   tsi_fake_zero_copy_grpc_protector* impl =
778       static_cast<tsi_fake_zero_copy_grpc_protector*>(
779           gpr_zalloc(sizeof(*impl)));
780   grpc_slice_buffer_init(&impl->header_sb);
781   grpc_slice_buffer_init(&impl->protected_sb);
782   impl->max_frame_size = (max_protected_frame_size == nullptr)
783                              ? TSI_FAKE_DEFAULT_FRAME_SIZE
784                              : *max_protected_frame_size;
785   impl->parsed_frame_size = 0;
786   impl->base.vtable = &zero_copy_grpc_protector_vtable;
787   return &impl->base;
788 }
789