1 /*
2 *
3 * Copyright 2015 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19 #include <grpc/support/port_platform.h>
20
21 #include "src/core/tsi/fake_transport_security.h"
22
23 #include <stdlib.h>
24 #include <string.h>
25
26 #include <grpc/support/alloc.h>
27 #include <grpc/support/log.h>
28
29 #include "src/core/lib/gpr/useful.h"
30 #include "src/core/lib/slice/slice_internal.h"
31 #include "src/core/tsi/transport_security_grpc.h"
32
/* --- Constants. ---*/

/* Size in bytes of the little-endian length field that starts every fake
 * frame. */
static constexpr size_t TSI_FAKE_FRAME_HEADER_SIZE = 4;
/* Capacity given to a frame buffer when the decoder first allocates it. */
static constexpr size_t TSI_FAKE_FRAME_INITIAL_ALLOCATED_SIZE = 64;
/* Frame size used by the protectors when the caller does not supply one. */
static constexpr size_t TSI_FAKE_DEFAULT_FRAME_SIZE = 16384;
/* Initial capacity of the handshaker's outgoing byte buffer; the buffer is
 * doubled whenever a handshake frame does not fit. */
static constexpr size_t TSI_FAKE_HANDSHAKER_OUTGOING_BUFFER_INITIAL_SIZE = 256;
38
/* --- Structure definitions. ---*/

/* A fake frame is laid out on the wire as:
 *   | size | data |
 * where `size` is a 4-byte little-endian integer whose value counts both the
 * size field itself and the data that follows it. */
struct tsi_fake_frame {
  /* Buffer holding the whole frame, header included. */
  unsigned char* data;
  /* Total frame length in bytes (header + data). */
  size_t size;
  /* Capacity of `data` in bytes. */
  size_t allocated_size;
  /* Cursor into `data`: the write position while decoding, the read position
   * while encoding. */
  size_t offset;
  /* Non-zero once the frame is complete and waiting to be written out. */
  int needs_draining;
};
52
/* Messages of the fake handshake, numbered in the order they are exchanged
 * (client and server alternate). TSI_FAKE_HANDSHAKE_MESSAGE_MAX is a sentinel
 * marking "no more messages", not a real message. */
enum tsi_fake_handshake_message {
  TSI_FAKE_CLIENT_INIT = 0,
  TSI_FAKE_SERVER_INIT = 1,
  TSI_FAKE_CLIENT_FINISHED = 2,
  TSI_FAKE_SERVER_FINISHED = 3,
  TSI_FAKE_HANDSHAKE_MESSAGE_MAX = 4
};
60
61 typedef struct {
62 tsi_handshaker base;
63 int is_client;
64 tsi_fake_handshake_message next_message_to_send;
65 int needs_incoming_message;
66 tsi_fake_frame incoming_frame;
67 tsi_fake_frame outgoing_frame;
68 unsigned char* outgoing_bytes_buffer;
69 size_t outgoing_bytes_buffer_size;
70 tsi_result result;
71 } tsi_fake_handshaker;
72
73 typedef struct {
74 tsi_frame_protector base;
75 tsi_fake_frame protect_frame;
76 tsi_fake_frame unprotect_frame;
77 size_t max_frame_size;
78 } tsi_fake_frame_protector;
79
80 typedef struct {
81 tsi_zero_copy_grpc_protector base;
82 grpc_slice_buffer header_sb;
83 grpc_slice_buffer protected_sb;
84 size_t max_frame_size;
85 size_t parsed_frame_size;
86 } tsi_fake_zero_copy_grpc_protector;
87
/* --- Utils. ---*/

/* Wire names of the handshake messages, indexed by their
 * tsi_fake_handshake_message value. */
static const char* tsi_fake_handshake_message_strings[] = {
    "CLIENT_INIT",
    "SERVER_INIT",
    "CLIENT_FINISHED",
    "SERVER_FINISHED",
};
92
tsi_fake_handshake_message_to_string(int msg)93 static const char* tsi_fake_handshake_message_to_string(int msg) {
94 if (msg < 0 || msg >= TSI_FAKE_HANDSHAKE_MESSAGE_MAX) {
95 gpr_log(GPR_ERROR, "Invalid message %d", msg);
96 return "UNKNOWN";
97 }
98 return tsi_fake_handshake_message_strings[msg];
99 }
100
tsi_fake_handshake_message_from_string(const char * msg_string,tsi_fake_handshake_message * msg)101 static tsi_result tsi_fake_handshake_message_from_string(
102 const char* msg_string, tsi_fake_handshake_message* msg) {
103 for (int i = 0; i < TSI_FAKE_HANDSHAKE_MESSAGE_MAX; i++) {
104 if (strncmp(msg_string, tsi_fake_handshake_message_strings[i],
105 strlen(tsi_fake_handshake_message_strings[i])) == 0) {
106 *msg = static_cast<tsi_fake_handshake_message>(i);
107 return TSI_OK;
108 }
109 }
110 gpr_log(GPR_ERROR, "Invalid handshake message.");
111 return TSI_DATA_CORRUPTED;
112 }
113
/* Decodes the 32-bit unsigned integer stored little-endian in buf[0..3].
 * Each byte is widened to uint32_t *before* it is shifted. Shifting the
 * int-promoted byte left by 24 overflows a signed int whenever the byte is
 * >= 0x80, which is undefined behavior before C++20. */
static uint32_t load32_little_endian(const unsigned char* buf) {
  return static_cast<uint32_t>(buf[0]) |
         (static_cast<uint32_t>(buf[1]) << 8) |
         (static_cast<uint32_t>(buf[2]) << 16) |
         (static_cast<uint32_t>(buf[3]) << 24);
}
119
/* Writes `value` into buf[0..3] in little-endian byte order (least significant
 * byte first). */
static void store32_little_endian(uint32_t value, unsigned char* buf) {
  for (int byte_index = 0; byte_index < 4; ++byte_index) {
    buf[byte_index] =
        static_cast<unsigned char>((value >> (8 * byte_index)) & 0xFF);
  }
}
126
read_frame_size(const grpc_slice_buffer * sb)127 static uint32_t read_frame_size(const grpc_slice_buffer* sb) {
128 GPR_ASSERT(sb != nullptr && sb->length >= TSI_FAKE_FRAME_HEADER_SIZE);
129 uint8_t frame_size_buffer[TSI_FAKE_FRAME_HEADER_SIZE];
130 uint8_t* buf = frame_size_buffer;
131 /* Copies the first 4 bytes to a temporary buffer. */
132 size_t remaining = TSI_FAKE_FRAME_HEADER_SIZE;
133 for (size_t i = 0; i < sb->count; i++) {
134 size_t slice_length = GRPC_SLICE_LENGTH(sb->slices[i]);
135 if (remaining <= slice_length) {
136 memcpy(buf, GRPC_SLICE_START_PTR(sb->slices[i]), remaining);
137 remaining = 0;
138 break;
139 } else {
140 memcpy(buf, GRPC_SLICE_START_PTR(sb->slices[i]), slice_length);
141 buf += slice_length;
142 remaining -= slice_length;
143 }
144 }
145 GPR_ASSERT(remaining == 0);
146 return load32_little_endian(frame_size_buffer);
147 }
148
tsi_fake_frame_reset(tsi_fake_frame * frame,int needs_draining)149 static void tsi_fake_frame_reset(tsi_fake_frame* frame, int needs_draining) {
150 frame->offset = 0;
151 frame->needs_draining = needs_draining;
152 if (!needs_draining) frame->size = 0;
153 }
154
155 /* Checks if the frame's allocated size is at least frame->size, and reallocs
156 * more memory if necessary. */
tsi_fake_frame_ensure_size(tsi_fake_frame * frame)157 static void tsi_fake_frame_ensure_size(tsi_fake_frame* frame) {
158 if (frame->data == nullptr) {
159 frame->allocated_size = frame->size;
160 frame->data =
161 static_cast<unsigned char*>(gpr_malloc(frame->allocated_size));
162 } else if (frame->size > frame->allocated_size) {
163 unsigned char* new_data =
164 static_cast<unsigned char*>(gpr_realloc(frame->data, frame->size));
165 frame->data = new_data;
166 frame->allocated_size = frame->size;
167 }
168 }
169
170 /* Decodes the serialized fake frame contained in incoming_bytes, and fills
171 * frame with the contents of the decoded frame.
172 * This method should not be called if frame->needs_framing is not 0. */
tsi_fake_frame_decode(const unsigned char * incoming_bytes,size_t * incoming_bytes_size,tsi_fake_frame * frame)173 static tsi_result tsi_fake_frame_decode(const unsigned char* incoming_bytes,
174 size_t* incoming_bytes_size,
175 tsi_fake_frame* frame) {
176 size_t available_size = *incoming_bytes_size;
177 size_t to_read_size = 0;
178 const unsigned char* bytes_cursor = incoming_bytes;
179
180 if (frame->needs_draining) return TSI_INTERNAL_ERROR;
181 if (frame->data == nullptr) {
182 frame->allocated_size = TSI_FAKE_FRAME_INITIAL_ALLOCATED_SIZE;
183 frame->data =
184 static_cast<unsigned char*>(gpr_malloc(frame->allocated_size));
185 }
186
187 if (frame->offset < TSI_FAKE_FRAME_HEADER_SIZE) {
188 to_read_size = TSI_FAKE_FRAME_HEADER_SIZE - frame->offset;
189 if (to_read_size > available_size) {
190 /* Just fill what we can and exit. */
191 memcpy(frame->data + frame->offset, bytes_cursor, available_size);
192 bytes_cursor += available_size;
193 frame->offset += available_size;
194 *incoming_bytes_size = static_cast<size_t>(bytes_cursor - incoming_bytes);
195 return TSI_INCOMPLETE_DATA;
196 }
197 memcpy(frame->data + frame->offset, bytes_cursor, to_read_size);
198 bytes_cursor += to_read_size;
199 frame->offset += to_read_size;
200 available_size -= to_read_size;
201 frame->size = load32_little_endian(frame->data);
202 tsi_fake_frame_ensure_size(frame);
203 }
204
205 to_read_size = frame->size - frame->offset;
206 if (to_read_size > available_size) {
207 memcpy(frame->data + frame->offset, bytes_cursor, available_size);
208 frame->offset += available_size;
209 bytes_cursor += available_size;
210 *incoming_bytes_size = static_cast<size_t>(bytes_cursor - incoming_bytes);
211 return TSI_INCOMPLETE_DATA;
212 }
213 memcpy(frame->data + frame->offset, bytes_cursor, to_read_size);
214 bytes_cursor += to_read_size;
215 *incoming_bytes_size = static_cast<size_t>(bytes_cursor - incoming_bytes);
216 tsi_fake_frame_reset(frame, 1 /* needs_draining */);
217 return TSI_OK;
218 }
219
220 /* Encodes a fake frame into its wire format and places the result in
221 * outgoing_bytes. outgoing_bytes_size indicates the size of the encoded frame.
222 * This method should not be called if frame->needs_framing is 0. */
tsi_fake_frame_encode(unsigned char * outgoing_bytes,size_t * outgoing_bytes_size,tsi_fake_frame * frame)223 static tsi_result tsi_fake_frame_encode(unsigned char* outgoing_bytes,
224 size_t* outgoing_bytes_size,
225 tsi_fake_frame* frame) {
226 size_t to_write_size = frame->size - frame->offset;
227 if (!frame->needs_draining) return TSI_INTERNAL_ERROR;
228 if (*outgoing_bytes_size < to_write_size) {
229 memcpy(outgoing_bytes, frame->data + frame->offset, *outgoing_bytes_size);
230 frame->offset += *outgoing_bytes_size;
231 return TSI_INCOMPLETE_DATA;
232 }
233 memcpy(outgoing_bytes, frame->data + frame->offset, to_write_size);
234 *outgoing_bytes_size = to_write_size;
235 tsi_fake_frame_reset(frame, 0 /* needs_draining */);
236 return TSI_OK;
237 }
238
239 /* Sets the payload of a fake frame to contain the given data blob, where
240 * data_size indicates the size of data. */
tsi_fake_frame_set_data(unsigned char * data,size_t data_size,tsi_fake_frame * frame)241 static tsi_result tsi_fake_frame_set_data(unsigned char* data, size_t data_size,
242 tsi_fake_frame* frame) {
243 frame->offset = 0;
244 frame->size = data_size + TSI_FAKE_FRAME_HEADER_SIZE;
245 tsi_fake_frame_ensure_size(frame);
246 store32_little_endian(static_cast<uint32_t>(frame->size), frame->data);
247 memcpy(frame->data + TSI_FAKE_FRAME_HEADER_SIZE, data, data_size);
248 tsi_fake_frame_reset(frame, 1 /* needs draining */);
249 return TSI_OK;
250 }
251
252 /* Destroys the contents of a fake frame. */
tsi_fake_frame_destruct(tsi_fake_frame * frame)253 static void tsi_fake_frame_destruct(tsi_fake_frame* frame) {
254 if (frame->data != nullptr) gpr_free(frame->data);
255 }
256
257 /* --- tsi_frame_protector methods implementation. ---*/
258
fake_protector_protect(tsi_frame_protector * self,const unsigned char * unprotected_bytes,size_t * unprotected_bytes_size,unsigned char * protected_output_frames,size_t * protected_output_frames_size)259 static tsi_result fake_protector_protect(tsi_frame_protector* self,
260 const unsigned char* unprotected_bytes,
261 size_t* unprotected_bytes_size,
262 unsigned char* protected_output_frames,
263 size_t* protected_output_frames_size) {
264 tsi_result result = TSI_OK;
265 tsi_fake_frame_protector* impl =
266 reinterpret_cast<tsi_fake_frame_protector*>(self);
267 unsigned char frame_header[TSI_FAKE_FRAME_HEADER_SIZE];
268 tsi_fake_frame* frame = &impl->protect_frame;
269 size_t saved_output_size = *protected_output_frames_size;
270 size_t drained_size = 0;
271 size_t* num_bytes_written = protected_output_frames_size;
272 *num_bytes_written = 0;
273
274 /* Try to drain first. */
275 if (frame->needs_draining) {
276 drained_size = saved_output_size - *num_bytes_written;
277 result =
278 tsi_fake_frame_encode(protected_output_frames, &drained_size, frame);
279 *num_bytes_written += drained_size;
280 protected_output_frames += drained_size;
281 if (result != TSI_OK) {
282 if (result == TSI_INCOMPLETE_DATA) {
283 *unprotected_bytes_size = 0;
284 result = TSI_OK;
285 }
286 return result;
287 }
288 }
289
290 /* Now process the unprotected_bytes. */
291 if (frame->needs_draining) return TSI_INTERNAL_ERROR;
292 if (frame->size == 0) {
293 /* New frame, create a header. */
294 size_t written_in_frame_size = 0;
295 store32_little_endian(static_cast<uint32_t>(impl->max_frame_size),
296 frame_header);
297 written_in_frame_size = TSI_FAKE_FRAME_HEADER_SIZE;
298 result = tsi_fake_frame_decode(frame_header, &written_in_frame_size, frame);
299 if (result != TSI_INCOMPLETE_DATA) {
300 gpr_log(GPR_ERROR, "tsi_fake_frame_decode returned %s",
301 tsi_result_to_string(result));
302 return result;
303 }
304 }
305 result =
306 tsi_fake_frame_decode(unprotected_bytes, unprotected_bytes_size, frame);
307 if (result != TSI_OK) {
308 if (result == TSI_INCOMPLETE_DATA) result = TSI_OK;
309 return result;
310 }
311
312 /* Try to drain again. */
313 if (!frame->needs_draining) return TSI_INTERNAL_ERROR;
314 if (frame->offset != 0) return TSI_INTERNAL_ERROR;
315 drained_size = saved_output_size - *num_bytes_written;
316 result = tsi_fake_frame_encode(protected_output_frames, &drained_size, frame);
317 *num_bytes_written += drained_size;
318 if (result == TSI_INCOMPLETE_DATA) result = TSI_OK;
319 return result;
320 }
321
fake_protector_protect_flush(tsi_frame_protector * self,unsigned char * protected_output_frames,size_t * protected_output_frames_size,size_t * still_pending_size)322 static tsi_result fake_protector_protect_flush(
323 tsi_frame_protector* self, unsigned char* protected_output_frames,
324 size_t* protected_output_frames_size, size_t* still_pending_size) {
325 tsi_result result = TSI_OK;
326 tsi_fake_frame_protector* impl =
327 reinterpret_cast<tsi_fake_frame_protector*>(self);
328 tsi_fake_frame* frame = &impl->protect_frame;
329 if (!frame->needs_draining) {
330 /* Create a short frame. */
331 frame->size = frame->offset;
332 frame->offset = 0;
333 frame->needs_draining = 1;
334 store32_little_endian(static_cast<uint32_t>(frame->size),
335 frame->data); /* Overwrite header. */
336 }
337 result = tsi_fake_frame_encode(protected_output_frames,
338 protected_output_frames_size, frame);
339 if (result == TSI_INCOMPLETE_DATA) result = TSI_OK;
340 *still_pending_size = frame->size - frame->offset;
341 return result;
342 }
343
fake_protector_unprotect(tsi_frame_protector * self,const unsigned char * protected_frames_bytes,size_t * protected_frames_bytes_size,unsigned char * unprotected_bytes,size_t * unprotected_bytes_size)344 static tsi_result fake_protector_unprotect(
345 tsi_frame_protector* self, const unsigned char* protected_frames_bytes,
346 size_t* protected_frames_bytes_size, unsigned char* unprotected_bytes,
347 size_t* unprotected_bytes_size) {
348 tsi_result result = TSI_OK;
349 tsi_fake_frame_protector* impl =
350 reinterpret_cast<tsi_fake_frame_protector*>(self);
351 tsi_fake_frame* frame = &impl->unprotect_frame;
352 size_t saved_output_size = *unprotected_bytes_size;
353 size_t drained_size = 0;
354 size_t* num_bytes_written = unprotected_bytes_size;
355 *num_bytes_written = 0;
356
357 /* Try to drain first. */
358 if (frame->needs_draining) {
359 /* Go past the header if needed. */
360 if (frame->offset == 0) frame->offset = TSI_FAKE_FRAME_HEADER_SIZE;
361 drained_size = saved_output_size - *num_bytes_written;
362 result = tsi_fake_frame_encode(unprotected_bytes, &drained_size, frame);
363 unprotected_bytes += drained_size;
364 *num_bytes_written += drained_size;
365 if (result != TSI_OK) {
366 if (result == TSI_INCOMPLETE_DATA) {
367 *protected_frames_bytes_size = 0;
368 result = TSI_OK;
369 }
370 return result;
371 }
372 }
373
374 /* Now process the protected_bytes. */
375 if (frame->needs_draining) return TSI_INTERNAL_ERROR;
376 result = tsi_fake_frame_decode(protected_frames_bytes,
377 protected_frames_bytes_size, frame);
378 if (result != TSI_OK) {
379 if (result == TSI_INCOMPLETE_DATA) result = TSI_OK;
380 return result;
381 }
382
383 /* Try to drain again. */
384 if (!frame->needs_draining) return TSI_INTERNAL_ERROR;
385 if (frame->offset != 0) return TSI_INTERNAL_ERROR;
386 frame->offset = TSI_FAKE_FRAME_HEADER_SIZE; /* Go past the header. */
387 drained_size = saved_output_size - *num_bytes_written;
388 result = tsi_fake_frame_encode(unprotected_bytes, &drained_size, frame);
389 *num_bytes_written += drained_size;
390 if (result == TSI_INCOMPLETE_DATA) result = TSI_OK;
391 return result;
392 }
393
fake_protector_destroy(tsi_frame_protector * self)394 static void fake_protector_destroy(tsi_frame_protector* self) {
395 tsi_fake_frame_protector* impl =
396 reinterpret_cast<tsi_fake_frame_protector*>(self);
397 tsi_fake_frame_destruct(&impl->protect_frame);
398 tsi_fake_frame_destruct(&impl->unprotect_frame);
399 gpr_free(self);
400 }
401
402 static const tsi_frame_protector_vtable frame_protector_vtable = {
403 fake_protector_protect,
404 fake_protector_protect_flush,
405 fake_protector_unprotect,
406 fake_protector_destroy,
407 };
408
409 /* --- tsi_zero_copy_grpc_protector methods implementation. ---*/
410
fake_zero_copy_grpc_protector_protect(tsi_zero_copy_grpc_protector * self,grpc_slice_buffer * unprotected_slices,grpc_slice_buffer * protected_slices)411 static tsi_result fake_zero_copy_grpc_protector_protect(
412 tsi_zero_copy_grpc_protector* self, grpc_slice_buffer* unprotected_slices,
413 grpc_slice_buffer* protected_slices) {
414 if (self == nullptr || unprotected_slices == nullptr ||
415 protected_slices == nullptr) {
416 return TSI_INVALID_ARGUMENT;
417 }
418 tsi_fake_zero_copy_grpc_protector* impl =
419 reinterpret_cast<tsi_fake_zero_copy_grpc_protector*>(self);
420 /* Protects each frame. */
421 while (unprotected_slices->length > 0) {
422 size_t frame_length =
423 GPR_MIN(impl->max_frame_size,
424 unprotected_slices->length + TSI_FAKE_FRAME_HEADER_SIZE);
425 grpc_slice slice = GRPC_SLICE_MALLOC(TSI_FAKE_FRAME_HEADER_SIZE);
426 store32_little_endian(static_cast<uint32_t>(frame_length),
427 GRPC_SLICE_START_PTR(slice));
428 grpc_slice_buffer_add(protected_slices, slice);
429 size_t data_length = frame_length - TSI_FAKE_FRAME_HEADER_SIZE;
430 grpc_slice_buffer_move_first(unprotected_slices, data_length,
431 protected_slices);
432 }
433 return TSI_OK;
434 }
435
fake_zero_copy_grpc_protector_unprotect(tsi_zero_copy_grpc_protector * self,grpc_slice_buffer * protected_slices,grpc_slice_buffer * unprotected_slices)436 static tsi_result fake_zero_copy_grpc_protector_unprotect(
437 tsi_zero_copy_grpc_protector* self, grpc_slice_buffer* protected_slices,
438 grpc_slice_buffer* unprotected_slices) {
439 if (self == nullptr || unprotected_slices == nullptr ||
440 protected_slices == nullptr) {
441 return TSI_INVALID_ARGUMENT;
442 }
443 tsi_fake_zero_copy_grpc_protector* impl =
444 reinterpret_cast<tsi_fake_zero_copy_grpc_protector*>(self);
445 grpc_slice_buffer_move_into(protected_slices, &impl->protected_sb);
446 /* Unprotect each frame, if we get a full frame. */
447 while (impl->protected_sb.length >= TSI_FAKE_FRAME_HEADER_SIZE) {
448 if (impl->parsed_frame_size == 0) {
449 impl->parsed_frame_size = read_frame_size(&impl->protected_sb);
450 if (impl->parsed_frame_size <= 4) {
451 gpr_log(GPR_ERROR, "Invalid frame size.");
452 return TSI_DATA_CORRUPTED;
453 }
454 }
455 /* If we do not have a full frame, return with OK status. */
456 if (impl->protected_sb.length < impl->parsed_frame_size) break;
457 /* Strips header bytes. */
458 grpc_slice_buffer_move_first(&impl->protected_sb,
459 TSI_FAKE_FRAME_HEADER_SIZE, &impl->header_sb);
460 /* Moves data to unprotected slices. */
461 grpc_slice_buffer_move_first(
462 &impl->protected_sb,
463 impl->parsed_frame_size - TSI_FAKE_FRAME_HEADER_SIZE,
464 unprotected_slices);
465 impl->parsed_frame_size = 0;
466 grpc_slice_buffer_reset_and_unref_internal(&impl->header_sb);
467 }
468 return TSI_OK;
469 }
470
fake_zero_copy_grpc_protector_destroy(tsi_zero_copy_grpc_protector * self)471 static void fake_zero_copy_grpc_protector_destroy(
472 tsi_zero_copy_grpc_protector* self) {
473 if (self == nullptr) return;
474 tsi_fake_zero_copy_grpc_protector* impl =
475 reinterpret_cast<tsi_fake_zero_copy_grpc_protector*>(self);
476 grpc_slice_buffer_destroy_internal(&impl->header_sb);
477 grpc_slice_buffer_destroy_internal(&impl->protected_sb);
478 gpr_free(impl);
479 }
480
481 static const tsi_zero_copy_grpc_protector_vtable
482 zero_copy_grpc_protector_vtable = {
483 fake_zero_copy_grpc_protector_protect,
484 fake_zero_copy_grpc_protector_unprotect,
485 fake_zero_copy_grpc_protector_destroy,
486 };
487
488 /* --- tsi_handshaker_result methods implementation. ---*/
489
490 typedef struct {
491 tsi_handshaker_result base;
492 unsigned char* unused_bytes;
493 size_t unused_bytes_size;
494 } fake_handshaker_result;
495
fake_handshaker_result_extract_peer(const tsi_handshaker_result * self,tsi_peer * peer)496 static tsi_result fake_handshaker_result_extract_peer(
497 const tsi_handshaker_result* self, tsi_peer* peer) {
498 /* Construct a tsi_peer with 1 property: certificate type. */
499 tsi_result result = tsi_construct_peer(1, peer);
500 if (result != TSI_OK) return result;
501 result = tsi_construct_string_peer_property_from_cstring(
502 TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_FAKE_CERTIFICATE_TYPE,
503 &peer->properties[0]);
504 if (result != TSI_OK) tsi_peer_destruct(peer);
505 return result;
506 }
507
fake_handshaker_result_create_zero_copy_grpc_protector(const tsi_handshaker_result * self,size_t * max_output_protected_frame_size,tsi_zero_copy_grpc_protector ** protector)508 static tsi_result fake_handshaker_result_create_zero_copy_grpc_protector(
509 const tsi_handshaker_result* self, size_t* max_output_protected_frame_size,
510 tsi_zero_copy_grpc_protector** protector) {
511 *protector =
512 tsi_create_fake_zero_copy_grpc_protector(max_output_protected_frame_size);
513 return TSI_OK;
514 }
515
fake_handshaker_result_create_frame_protector(const tsi_handshaker_result * self,size_t * max_output_protected_frame_size,tsi_frame_protector ** protector)516 static tsi_result fake_handshaker_result_create_frame_protector(
517 const tsi_handshaker_result* self, size_t* max_output_protected_frame_size,
518 tsi_frame_protector** protector) {
519 *protector = tsi_create_fake_frame_protector(max_output_protected_frame_size);
520 return TSI_OK;
521 }
522
fake_handshaker_result_get_unused_bytes(const tsi_handshaker_result * self,const unsigned char ** bytes,size_t * bytes_size)523 static tsi_result fake_handshaker_result_get_unused_bytes(
524 const tsi_handshaker_result* self, const unsigned char** bytes,
525 size_t* bytes_size) {
526 fake_handshaker_result* result = (fake_handshaker_result*)self;
527 *bytes_size = result->unused_bytes_size;
528 *bytes = result->unused_bytes;
529 return TSI_OK;
530 }
531
fake_handshaker_result_destroy(tsi_handshaker_result * self)532 static void fake_handshaker_result_destroy(tsi_handshaker_result* self) {
533 fake_handshaker_result* result =
534 reinterpret_cast<fake_handshaker_result*>(self);
535 gpr_free(result->unused_bytes);
536 gpr_free(self);
537 }
538
539 static const tsi_handshaker_result_vtable handshaker_result_vtable = {
540 fake_handshaker_result_extract_peer,
541 fake_handshaker_result_create_zero_copy_grpc_protector,
542 fake_handshaker_result_create_frame_protector,
543 fake_handshaker_result_get_unused_bytes,
544 fake_handshaker_result_destroy,
545 };
546
fake_handshaker_result_create(const unsigned char * unused_bytes,size_t unused_bytes_size,tsi_handshaker_result ** handshaker_result)547 static tsi_result fake_handshaker_result_create(
548 const unsigned char* unused_bytes, size_t unused_bytes_size,
549 tsi_handshaker_result** handshaker_result) {
550 if ((unused_bytes_size > 0 && unused_bytes == nullptr) ||
551 handshaker_result == nullptr) {
552 return TSI_INVALID_ARGUMENT;
553 }
554 fake_handshaker_result* result =
555 static_cast<fake_handshaker_result*>(gpr_zalloc(sizeof(*result)));
556 result->base.vtable = &handshaker_result_vtable;
557 if (unused_bytes_size > 0) {
558 result->unused_bytes =
559 static_cast<unsigned char*>(gpr_malloc(unused_bytes_size));
560 memcpy(result->unused_bytes, unused_bytes, unused_bytes_size);
561 }
562 result->unused_bytes_size = unused_bytes_size;
563 *handshaker_result = &result->base;
564 return TSI_OK;
565 }
566
567 /* --- tsi_handshaker methods implementation. ---*/
568
fake_handshaker_get_bytes_to_send_to_peer(tsi_handshaker * self,unsigned char * bytes,size_t * bytes_size)569 static tsi_result fake_handshaker_get_bytes_to_send_to_peer(
570 tsi_handshaker* self, unsigned char* bytes, size_t* bytes_size) {
571 tsi_fake_handshaker* impl = reinterpret_cast<tsi_fake_handshaker*>(self);
572 tsi_result result = TSI_OK;
573 if (impl->needs_incoming_message || impl->result == TSI_OK) {
574 *bytes_size = 0;
575 return TSI_OK;
576 }
577 if (!impl->outgoing_frame.needs_draining) {
578 tsi_fake_handshake_message next_message_to_send =
579 static_cast<tsi_fake_handshake_message>(impl->next_message_to_send + 2);
580 const char* msg_string =
581 tsi_fake_handshake_message_to_string(impl->next_message_to_send);
582 result = tsi_fake_frame_set_data((unsigned char*)msg_string,
583 strlen(msg_string), &impl->outgoing_frame);
584 if (result != TSI_OK) return result;
585 if (next_message_to_send > TSI_FAKE_HANDSHAKE_MESSAGE_MAX) {
586 next_message_to_send = TSI_FAKE_HANDSHAKE_MESSAGE_MAX;
587 }
588 if (tsi_tracing_enabled.enabled()) {
589 gpr_log(GPR_INFO, "%s prepared %s.",
590 impl->is_client ? "Client" : "Server",
591 tsi_fake_handshake_message_to_string(impl->next_message_to_send));
592 }
593 impl->next_message_to_send = next_message_to_send;
594 }
595 result = tsi_fake_frame_encode(bytes, bytes_size, &impl->outgoing_frame);
596 if (result != TSI_OK) return result;
597 if (!impl->is_client &&
598 impl->next_message_to_send == TSI_FAKE_HANDSHAKE_MESSAGE_MAX) {
599 /* We're done. */
600 if (tsi_tracing_enabled.enabled()) {
601 gpr_log(GPR_INFO, "Server is done.");
602 }
603 impl->result = TSI_OK;
604 } else {
605 impl->needs_incoming_message = 1;
606 }
607 return TSI_OK;
608 }
609
fake_handshaker_process_bytes_from_peer(tsi_handshaker * self,const unsigned char * bytes,size_t * bytes_size)610 static tsi_result fake_handshaker_process_bytes_from_peer(
611 tsi_handshaker* self, const unsigned char* bytes, size_t* bytes_size) {
612 tsi_result result = TSI_OK;
613 tsi_fake_handshaker* impl = reinterpret_cast<tsi_fake_handshaker*>(self);
614 tsi_fake_handshake_message expected_msg =
615 static_cast<tsi_fake_handshake_message>(impl->next_message_to_send - 1);
616 tsi_fake_handshake_message received_msg;
617
618 if (!impl->needs_incoming_message || impl->result == TSI_OK) {
619 *bytes_size = 0;
620 return TSI_OK;
621 }
622 result = tsi_fake_frame_decode(bytes, bytes_size, &impl->incoming_frame);
623 if (result != TSI_OK) return result;
624
625 /* We now have a complete frame. */
626 result = tsi_fake_handshake_message_from_string(
627 reinterpret_cast<const char*>(impl->incoming_frame.data) +
628 TSI_FAKE_FRAME_HEADER_SIZE,
629 &received_msg);
630 if (result != TSI_OK) {
631 impl->result = result;
632 return result;
633 }
634 if (received_msg != expected_msg) {
635 gpr_log(GPR_ERROR, "Invalid received message (%s instead of %s)",
636 tsi_fake_handshake_message_to_string(received_msg),
637 tsi_fake_handshake_message_to_string(expected_msg));
638 }
639 if (tsi_tracing_enabled.enabled()) {
640 gpr_log(GPR_INFO, "%s received %s.", impl->is_client ? "Client" : "Server",
641 tsi_fake_handshake_message_to_string(received_msg));
642 }
643 tsi_fake_frame_reset(&impl->incoming_frame, 0 /* needs_draining */);
644 impl->needs_incoming_message = 0;
645 if (impl->next_message_to_send == TSI_FAKE_HANDSHAKE_MESSAGE_MAX) {
646 /* We're done. */
647 if (tsi_tracing_enabled.enabled()) {
648 gpr_log(GPR_INFO, "%s is done.", impl->is_client ? "Client" : "Server");
649 }
650 impl->result = TSI_OK;
651 }
652 return TSI_OK;
653 }
654
fake_handshaker_get_result(tsi_handshaker * self)655 static tsi_result fake_handshaker_get_result(tsi_handshaker* self) {
656 tsi_fake_handshaker* impl = reinterpret_cast<tsi_fake_handshaker*>(self);
657 return impl->result;
658 }
659
fake_handshaker_destroy(tsi_handshaker * self)660 static void fake_handshaker_destroy(tsi_handshaker* self) {
661 tsi_fake_handshaker* impl = reinterpret_cast<tsi_fake_handshaker*>(self);
662 tsi_fake_frame_destruct(&impl->incoming_frame);
663 tsi_fake_frame_destruct(&impl->outgoing_frame);
664 gpr_free(impl->outgoing_bytes_buffer);
665 gpr_free(self);
666 }
667
fake_handshaker_next(tsi_handshaker * self,const unsigned char * received_bytes,size_t received_bytes_size,const unsigned char ** bytes_to_send,size_t * bytes_to_send_size,tsi_handshaker_result ** handshaker_result,tsi_handshaker_on_next_done_cb cb,void * user_data)668 static tsi_result fake_handshaker_next(
669 tsi_handshaker* self, const unsigned char* received_bytes,
670 size_t received_bytes_size, const unsigned char** bytes_to_send,
671 size_t* bytes_to_send_size, tsi_handshaker_result** handshaker_result,
672 tsi_handshaker_on_next_done_cb cb, void* user_data) {
673 /* Sanity check the arguments. */
674 if ((received_bytes_size > 0 && received_bytes == nullptr) ||
675 bytes_to_send == nullptr || bytes_to_send_size == nullptr ||
676 handshaker_result == nullptr) {
677 return TSI_INVALID_ARGUMENT;
678 }
679 tsi_fake_handshaker* handshaker =
680 reinterpret_cast<tsi_fake_handshaker*>(self);
681 tsi_result result = TSI_OK;
682
683 /* Decode and process a handshake frame from the peer. */
684 size_t consumed_bytes_size = received_bytes_size;
685 if (received_bytes_size > 0) {
686 result = fake_handshaker_process_bytes_from_peer(self, received_bytes,
687 &consumed_bytes_size);
688 if (result != TSI_OK) return result;
689 }
690
691 /* Create a handshake message to send to the peer and encode it as a fake
692 * frame. */
693 size_t offset = 0;
694 do {
695 size_t sent_bytes_size = handshaker->outgoing_bytes_buffer_size - offset;
696 result = fake_handshaker_get_bytes_to_send_to_peer(
697 self, handshaker->outgoing_bytes_buffer + offset, &sent_bytes_size);
698 offset += sent_bytes_size;
699 if (result == TSI_INCOMPLETE_DATA) {
700 handshaker->outgoing_bytes_buffer_size *= 2;
701 handshaker->outgoing_bytes_buffer = static_cast<unsigned char*>(
702 gpr_realloc(handshaker->outgoing_bytes_buffer,
703 handshaker->outgoing_bytes_buffer_size));
704 }
705 } while (result == TSI_INCOMPLETE_DATA);
706 if (result != TSI_OK) return result;
707 *bytes_to_send = handshaker->outgoing_bytes_buffer;
708 *bytes_to_send_size = offset;
709
710 /* Check if the handshake was completed. */
711 if (fake_handshaker_get_result(self) == TSI_HANDSHAKE_IN_PROGRESS) {
712 *handshaker_result = nullptr;
713 } else {
714 /* Calculate the unused bytes. */
715 const unsigned char* unused_bytes = nullptr;
716 size_t unused_bytes_size = received_bytes_size - consumed_bytes_size;
717 if (unused_bytes_size > 0) {
718 unused_bytes = received_bytes + consumed_bytes_size;
719 }
720
721 /* Create a handshaker_result containing the unused bytes. */
722 result = fake_handshaker_result_create(unused_bytes, unused_bytes_size,
723 handshaker_result);
724 if (result == TSI_OK) {
725 /* Indicate that the handshake has completed and that a handshaker_result
726 * has been created. */
727 self->handshaker_result_created = true;
728 }
729 }
730 return result;
731 }
732
733 static const tsi_handshaker_vtable handshaker_vtable = {
734 nullptr, /* get_bytes_to_send_to_peer -- deprecated */
735 nullptr, /* process_bytes_from_peer -- deprecated */
736 nullptr, /* get_result -- deprecated */
737 nullptr, /* extract_peer -- deprecated */
738 nullptr, /* create_frame_protector -- deprecated */
739 fake_handshaker_destroy,
740 fake_handshaker_next,
741 nullptr, /* shutdown */
742 };
743
tsi_create_fake_handshaker(int is_client)744 tsi_handshaker* tsi_create_fake_handshaker(int is_client) {
745 tsi_fake_handshaker* impl =
746 static_cast<tsi_fake_handshaker*>(gpr_zalloc(sizeof(*impl)));
747 impl->base.vtable = &handshaker_vtable;
748 impl->is_client = is_client;
749 impl->result = TSI_HANDSHAKE_IN_PROGRESS;
750 impl->outgoing_bytes_buffer_size =
751 TSI_FAKE_HANDSHAKER_OUTGOING_BUFFER_INITIAL_SIZE;
752 impl->outgoing_bytes_buffer =
753 static_cast<unsigned char*>(gpr_malloc(impl->outgoing_bytes_buffer_size));
754 if (is_client) {
755 impl->needs_incoming_message = 0;
756 impl->next_message_to_send = TSI_FAKE_CLIENT_INIT;
757 } else {
758 impl->needs_incoming_message = 1;
759 impl->next_message_to_send = TSI_FAKE_SERVER_INIT;
760 }
761 return &impl->base;
762 }
763
tsi_create_fake_frame_protector(size_t * max_protected_frame_size)764 tsi_frame_protector* tsi_create_fake_frame_protector(
765 size_t* max_protected_frame_size) {
766 tsi_fake_frame_protector* impl =
767 static_cast<tsi_fake_frame_protector*>(gpr_zalloc(sizeof(*impl)));
768 impl->max_frame_size = (max_protected_frame_size == nullptr)
769 ? TSI_FAKE_DEFAULT_FRAME_SIZE
770 : *max_protected_frame_size;
771 impl->base.vtable = &frame_protector_vtable;
772 return &impl->base;
773 }
774
tsi_create_fake_zero_copy_grpc_protector(size_t * max_protected_frame_size)775 tsi_zero_copy_grpc_protector* tsi_create_fake_zero_copy_grpc_protector(
776 size_t* max_protected_frame_size) {
777 tsi_fake_zero_copy_grpc_protector* impl =
778 static_cast<tsi_fake_zero_copy_grpc_protector*>(
779 gpr_zalloc(sizeof(*impl)));
780 grpc_slice_buffer_init(&impl->header_sb);
781 grpc_slice_buffer_init(&impl->protected_sb);
782 impl->max_frame_size = (max_protected_frame_size == nullptr)
783 ? TSI_FAKE_DEFAULT_FRAME_SIZE
784 : *max_protected_frame_size;
785 impl->parsed_frame_size = 0;
786 impl->base.vtable = &zero_copy_grpc_protector_vtable;
787 return &impl->base;
788 }
789