1 //
2 //
3 // Copyright 2015 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18
19 #include "src/core/handshaker/security/secure_endpoint.h"
20
21 #include <fcntl.h>
22 #include <grpc/grpc.h>
23 #include <grpc/support/alloc.h>
24 #include <gtest/gtest.h>
25 #include <sys/types.h>
26
27 #include "absl/log/log.h"
28 #include "src/core/lib/iomgr/endpoint_pair.h"
29 #include "src/core/lib/iomgr/iomgr.h"
30 #include "src/core/lib/slice/slice_internal.h"
31 #include "src/core/tsi/fake_transport_security.h"
32 #include "src/core/util/crash.h"
33 #include "src/core/util/useful.h"
34 #include "test/core/iomgr/endpoint_tests.h"
35 #include "test/core/test_util/test_config.h"
36
37 static gpr_mu* g_mu;
38 static grpc_pollset* g_pollset;
39
40 #define TSI_FAKE_FRAME_HEADER_SIZE 4
41
42 typedef struct intercept_endpoint {
43 grpc_endpoint base;
44 grpc_endpoint* wrapped_ep;
45 grpc_slice_buffer staging_buffer;
46 } intercept_endpoint;
47
me_read(grpc_endpoint * ep,grpc_slice_buffer * slices,grpc_closure * cb,bool urgent,int min_progress_size)48 static void me_read(grpc_endpoint* ep, grpc_slice_buffer* slices,
49 grpc_closure* cb, bool urgent, int min_progress_size) {
50 intercept_endpoint* m = reinterpret_cast<intercept_endpoint*>(ep);
51 grpc_endpoint_read(m->wrapped_ep, slices, cb, urgent, min_progress_size);
52 }
53
me_write(grpc_endpoint * ep,grpc_slice_buffer * slices,grpc_closure * cb,void * arg,int max_frame_size)54 static void me_write(grpc_endpoint* ep, grpc_slice_buffer* slices,
55 grpc_closure* cb, void* arg, int max_frame_size) {
56 intercept_endpoint* m = reinterpret_cast<intercept_endpoint*>(ep);
57 int remaining = slices->length;
58 while (remaining > 0) {
59 // Estimate the frame size of the next frame.
60 int next_frame_size =
61 tsi_fake_zero_copy_grpc_protector_next_frame_size(slices);
62 ASSERT_GT(next_frame_size, TSI_FAKE_FRAME_HEADER_SIZE);
63 // Ensure the protected data size does not exceed the max_frame_size.
64 ASSERT_LE(next_frame_size - TSI_FAKE_FRAME_HEADER_SIZE, max_frame_size);
65 // Move this frame into a staging buffer and repeat.
66 grpc_slice_buffer_move_first(slices, next_frame_size, &m->staging_buffer);
67 remaining -= next_frame_size;
68 }
69 grpc_slice_buffer_swap(&m->staging_buffer, slices);
70 grpc_endpoint_write(m->wrapped_ep, slices, cb, arg, max_frame_size);
71 }
72
me_add_to_pollset(grpc_endpoint *,grpc_pollset *)73 static void me_add_to_pollset(grpc_endpoint* /*ep*/,
74 grpc_pollset* /*pollset*/) {}
75
me_add_to_pollset_set(grpc_endpoint *,grpc_pollset_set *)76 static void me_add_to_pollset_set(grpc_endpoint* /*ep*/,
77 grpc_pollset_set* /*pollset*/) {}
78
me_delete_from_pollset_set(grpc_endpoint *,grpc_pollset_set *)79 static void me_delete_from_pollset_set(grpc_endpoint* /*ep*/,
80 grpc_pollset_set* /*pollset*/) {}
81
me_destroy(grpc_endpoint * ep)82 static void me_destroy(grpc_endpoint* ep) {
83 intercept_endpoint* m = reinterpret_cast<intercept_endpoint*>(ep);
84 grpc_endpoint_destroy(m->wrapped_ep);
85 grpc_slice_buffer_destroy(&m->staging_buffer);
86 gpr_free(m);
87 }
88
me_get_peer(grpc_endpoint *)89 static absl::string_view me_get_peer(grpc_endpoint* /*ep*/) {
90 return "fake:intercept-endpoint";
91 }
92
me_get_local_address(grpc_endpoint *)93 static absl::string_view me_get_local_address(grpc_endpoint* /*ep*/) {
94 return "fake:intercept-endpoint";
95 }
96
me_get_fd(grpc_endpoint *)97 static int me_get_fd(grpc_endpoint* /*ep*/) { return -1; }
98
me_can_track_err(grpc_endpoint *)99 static bool me_can_track_err(grpc_endpoint* /*ep*/) { return false; }
100
101 static const grpc_endpoint_vtable vtable = {me_read,
102 me_write,
103 me_add_to_pollset,
104 me_add_to_pollset_set,
105 me_delete_from_pollset_set,
106 me_destroy,
107 me_get_peer,
108 me_get_local_address,
109 me_get_fd,
110 me_can_track_err};
111
wrap_with_intercept_endpoint(grpc_endpoint * wrapped_ep)112 grpc_endpoint* wrap_with_intercept_endpoint(grpc_endpoint* wrapped_ep) {
113 intercept_endpoint* m =
114 static_cast<intercept_endpoint*>(gpr_malloc(sizeof(*m)));
115 m->base.vtable = &vtable;
116 m->wrapped_ep = wrapped_ep;
117 grpc_slice_buffer_init(&m->staging_buffer);
118 return &m->base;
119 }
120
secure_endpoint_create_fixture_tcp_socketpair(size_t slice_size,grpc_slice * leftover_slices,size_t leftover_nslices,bool use_zero_copy_protector)121 static grpc_endpoint_test_fixture secure_endpoint_create_fixture_tcp_socketpair(
122 size_t slice_size, grpc_slice* leftover_slices, size_t leftover_nslices,
123 bool use_zero_copy_protector) {
124 grpc_core::ExecCtx exec_ctx;
125 tsi_frame_protector* fake_read_protector =
126 tsi_create_fake_frame_protector(nullptr);
127 tsi_frame_protector* fake_write_protector =
128 tsi_create_fake_frame_protector(nullptr);
129 tsi_zero_copy_grpc_protector* fake_read_zero_copy_protector =
130 use_zero_copy_protector
131 ? tsi_create_fake_zero_copy_grpc_protector(nullptr)
132 : nullptr;
133 tsi_zero_copy_grpc_protector* fake_write_zero_copy_protector =
134 use_zero_copy_protector
135 ? tsi_create_fake_zero_copy_grpc_protector(nullptr)
136 : nullptr;
137 grpc_endpoint_test_fixture f;
138 grpc_endpoint_pair tcp;
139
140 grpc_arg a[2];
141 a[0].key = const_cast<char*>(GRPC_ARG_TCP_READ_CHUNK_SIZE);
142 a[0].type = GRPC_ARG_INTEGER;
143 a[0].value.integer = static_cast<int>(slice_size);
144 a[1].key = const_cast<char*>(GRPC_ARG_RESOURCE_QUOTA);
145 a[1].type = GRPC_ARG_POINTER;
146 a[1].value.pointer.p = grpc_resource_quota_create("test");
147 a[1].value.pointer.vtable = grpc_resource_quota_arg_vtable();
148 grpc_channel_args args = {GPR_ARRAY_SIZE(a), a};
149 tcp = grpc_iomgr_create_endpoint_pair("fixture", &args);
150 grpc_endpoint_add_to_pollset(tcp.client, g_pollset);
151 grpc_endpoint_add_to_pollset(tcp.server, g_pollset);
152
153 // TODO(vigneshbabu): Extend the intercept endpoint logic to cover non-zero
154 // copy based frame protectors as well.
155 if (use_zero_copy_protector && leftover_nslices == 0) {
156 tcp.client = wrap_with_intercept_endpoint(tcp.client);
157 tcp.server = wrap_with_intercept_endpoint(tcp.server);
158 }
159
160 if (leftover_nslices == 0) {
161 f.client_ep = grpc_secure_endpoint_create(
162 fake_read_protector, fake_read_zero_copy_protector,
163 grpc_core::OrphanablePtr<grpc_endpoint>(tcp.client),
164 nullptr, &args, 0)
165 .release();
166 } else {
167 unsigned i;
168 tsi_result result;
169 size_t still_pending_size;
170 size_t total_buffer_size = 8192;
171 size_t buffer_size = total_buffer_size;
172 uint8_t* encrypted_buffer = static_cast<uint8_t*>(gpr_malloc(buffer_size));
173 uint8_t* cur = encrypted_buffer;
174 grpc_slice encrypted_leftover;
175 for (i = 0; i < leftover_nslices; i++) {
176 grpc_slice plain = leftover_slices[i];
177 uint8_t* message_bytes = GRPC_SLICE_START_PTR(plain);
178 size_t message_size = GRPC_SLICE_LENGTH(plain);
179 while (message_size > 0) {
180 size_t protected_buffer_size_to_send = buffer_size;
181 size_t processed_message_size = message_size;
182 result = tsi_frame_protector_protect(
183 fake_write_protector, message_bytes, &processed_message_size, cur,
184 &protected_buffer_size_to_send);
185 EXPECT_EQ(result, TSI_OK);
186 message_bytes += processed_message_size;
187 message_size -= processed_message_size;
188 cur += protected_buffer_size_to_send;
189 EXPECT_GE(buffer_size, protected_buffer_size_to_send);
190 buffer_size -= protected_buffer_size_to_send;
191 }
192 grpc_slice_unref(plain);
193 }
194 do {
195 size_t protected_buffer_size_to_send = buffer_size;
196 result = tsi_frame_protector_protect_flush(fake_write_protector, cur,
197 &protected_buffer_size_to_send,
198 &still_pending_size);
199 EXPECT_EQ(result, TSI_OK);
200 cur += protected_buffer_size_to_send;
201 EXPECT_GE(buffer_size, protected_buffer_size_to_send);
202 buffer_size -= protected_buffer_size_to_send;
203 } while (still_pending_size > 0);
204 encrypted_leftover = grpc_slice_from_copied_buffer(
205 reinterpret_cast<const char*>(encrypted_buffer),
206 total_buffer_size - buffer_size);
207 f.client_ep = grpc_secure_endpoint_create(
208 fake_read_protector, fake_read_zero_copy_protector,
209 grpc_core::OrphanablePtr<grpc_endpoint>(tcp.client),
210 &encrypted_leftover, &args, 1)
211 .release();
212 grpc_slice_unref(encrypted_leftover);
213 gpr_free(encrypted_buffer);
214 }
215
216 f.server_ep = grpc_secure_endpoint_create(
217 fake_write_protector, fake_write_zero_copy_protector,
218 grpc_core::OrphanablePtr<grpc_endpoint>(tcp.server),
219 nullptr, &args, 0)
220 .release();
221 grpc_resource_quota_unref(
222 static_cast<grpc_resource_quota*>(a[1].value.pointer.p));
223 return f;
224 }
225
226 static grpc_endpoint_test_fixture
secure_endpoint_create_fixture_tcp_socketpair_noleftover(size_t slice_size)227 secure_endpoint_create_fixture_tcp_socketpair_noleftover(size_t slice_size) {
228 return secure_endpoint_create_fixture_tcp_socketpair(slice_size, nullptr, 0,
229 false);
230 }
231
232 static grpc_endpoint_test_fixture
secure_endpoint_create_fixture_tcp_socketpair_noleftover_zero_copy(size_t slice_size)233 secure_endpoint_create_fixture_tcp_socketpair_noleftover_zero_copy(
234 size_t slice_size) {
235 return secure_endpoint_create_fixture_tcp_socketpair(slice_size, nullptr, 0,
236 true);
237 }
238
239 static grpc_endpoint_test_fixture
secure_endpoint_create_fixture_tcp_socketpair_leftover(size_t slice_size)240 secure_endpoint_create_fixture_tcp_socketpair_leftover(size_t slice_size) {
241 grpc_slice s =
242 grpc_slice_from_copied_string("hello world 12345678900987654321");
243 return secure_endpoint_create_fixture_tcp_socketpair(slice_size, &s, 1,
244 false);
245 }
246
247 static grpc_endpoint_test_fixture
secure_endpoint_create_fixture_tcp_socketpair_leftover_zero_copy(size_t slice_size)248 secure_endpoint_create_fixture_tcp_socketpair_leftover_zero_copy(
249 size_t slice_size) {
250 grpc_slice s =
251 grpc_slice_from_copied_string("hello world 12345678900987654321");
252 return secure_endpoint_create_fixture_tcp_socketpair(slice_size, &s, 1, true);
253 }
254
// Per-config cleanup hook; these fixtures have nothing extra to release.
static void clean_up() {
  // Intentionally empty.
}
256
257 static grpc_endpoint_test_config configs[] = {
258 {"secure_ep/tcp_socketpair",
259 secure_endpoint_create_fixture_tcp_socketpair_noleftover, clean_up},
260 {"secure_ep/tcp_socketpair_zero_copy",
261 secure_endpoint_create_fixture_tcp_socketpair_noleftover_zero_copy,
262 clean_up},
263 {"secure_ep/tcp_socketpair_leftover",
264 secure_endpoint_create_fixture_tcp_socketpair_leftover, clean_up},
265 {"secure_ep/tcp_socketpair_leftover_zero_copy",
266 secure_endpoint_create_fixture_tcp_socketpair_leftover_zero_copy,
267 clean_up},
268 };
269
inc_call_ctr(void * arg,grpc_error_handle)270 static void inc_call_ctr(void* arg, grpc_error_handle /*error*/) {
271 ++*static_cast<int*>(arg);
272 }
273
test_leftover(grpc_endpoint_test_config config,size_t slice_size)274 static void test_leftover(grpc_endpoint_test_config config, size_t slice_size) {
275 grpc_endpoint_test_fixture f = config.create_fixture(slice_size);
276 grpc_slice_buffer incoming;
277 grpc_slice s =
278 grpc_slice_from_copied_string("hello world 12345678900987654321");
279 grpc_core::ExecCtx exec_ctx;
280 int n = 0;
281 grpc_closure done_closure;
282 LOG(INFO) << "Start test left over";
283
284 grpc_slice_buffer_init(&incoming);
285 GRPC_CLOSURE_INIT(&done_closure, inc_call_ctr, &n, grpc_schedule_on_exec_ctx);
286 grpc_endpoint_read(f.client_ep, &incoming, &done_closure, /*urgent=*/false,
287 /*min_progress_size=*/1);
288
289 grpc_core::ExecCtx::Get()->Flush();
290 ASSERT_EQ(n, 1);
291 ASSERT_EQ(incoming.count, 1);
292 ASSERT_TRUE(grpc_slice_eq(s, incoming.slices[0]));
293
294 grpc_endpoint_destroy(f.client_ep);
295 grpc_endpoint_destroy(f.server_ep);
296
297 grpc_slice_unref(s);
298 grpc_slice_buffer_destroy(&incoming);
299
300 clean_up();
301 }
302
destroy_pollset(void * p,grpc_error_handle)303 static void destroy_pollset(void* p, grpc_error_handle /*error*/) {
304 grpc_pollset_destroy(static_cast<grpc_pollset*>(p));
305 }
306
// Runs the generic endpoint suite on the no-leftover configs, then the leftover
// check on the leftover configs, all sharing one pollset.
TEST(SecureEndpointTest, MainTest) {
  grpc_init();
  grpc_closure destroyed;

  {
    grpc_core::ExecCtx exec_ctx;
    g_pollset = static_cast<grpc_pollset*>(gpr_zalloc(grpc_pollset_size()));
    grpc_pollset_init(g_pollset, &g_mu);

    // configs[0] and configs[1]: full endpoint test suite.
    for (size_t idx = 0; idx < 2; ++idx) {
      grpc_endpoint_tests(configs[idx], g_pollset, g_mu);
    }
    // configs[2] and configs[3]: leftover-bytes check with slice_size 1.
    for (size_t idx = 2; idx < 4; ++idx) {
      test_leftover(configs[idx], 1);
    }

    GRPC_CLOSURE_INIT(&destroyed, destroy_pollset, g_pollset,
                      grpc_schedule_on_exec_ctx);
    grpc_pollset_shutdown(g_pollset, &destroyed);
  }

  grpc_shutdown();
  gpr_free(g_pollset);
}
328
main(int argc,char ** argv)329 int main(int argc, char** argv) {
330 grpc::testing::TestEnvironment env(&argc, argv);
331 ::testing::InitGoogleTest(&argc, argv);
332 return RUN_ALL_TESTS();
333 }
334