1 //
2 //
3 // Copyright 2018 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18
19 #include "src/core/tsi/alts/zero_copy_frame_protector/alts_grpc_record_protocol_common.h"
20
21 #include <grpc/support/alloc.h>
22 #include <grpc/support/port_platform.h>
23 #include <string.h>
24
25 #include "absl/log/check.h"
26 #include "absl/log/log.h"
27 #include "src/core/lib/iomgr/exec_ctx.h"
28 #include "src/core/lib/slice/slice_internal.h"
29 #include "src/core/util/crash.h"
30 #include "src/core/util/useful.h"
31
// Number of iovec_t entries allocated for the iovec buffer when a record
// protocol is initialized; the buffer is grown later as needed.
constexpr size_t kInitialIovecBufferSize = 8;
33
34 // Makes sure iovec_buf in alts_grpc_record_protocol is large enough.
ensure_iovec_buf_size(alts_grpc_record_protocol * rp,const grpc_slice_buffer * sb)35 static void ensure_iovec_buf_size(alts_grpc_record_protocol* rp,
36 const grpc_slice_buffer* sb) {
37 CHECK(rp != nullptr);
38 CHECK_NE(sb, nullptr);
39 if (sb->count <= rp->iovec_buf_length) {
40 return;
41 }
42 // At least double the iovec buffer size.
43 rp->iovec_buf_length = std::max(sb->count, 2 * rp->iovec_buf_length);
44 rp->iovec_buf = static_cast<iovec_t*>(
45 gpr_realloc(rp->iovec_buf, rp->iovec_buf_length * sizeof(iovec_t)));
46 }
47
// --- Implementation of methods defined in alts_grpc_record_protocol_common.h.
// ---
50
alts_grpc_record_protocol_convert_slice_buffer_to_iovec(alts_grpc_record_protocol * rp,const grpc_slice_buffer * sb)51 void alts_grpc_record_protocol_convert_slice_buffer_to_iovec(
52 alts_grpc_record_protocol* rp, const grpc_slice_buffer* sb) {
53 CHECK(rp != nullptr);
54 CHECK_NE(sb, nullptr);
55 ensure_iovec_buf_size(rp, sb);
56 for (size_t i = 0; i < sb->count; i++) {
57 rp->iovec_buf[i].iov_base = GRPC_SLICE_START_PTR(sb->slices[i]);
58 rp->iovec_buf[i].iov_len = GRPC_SLICE_LENGTH(sb->slices[i]);
59 }
60 }
61
alts_grpc_record_protocol_copy_slice_buffer(const grpc_slice_buffer * src,unsigned char * dst)62 void alts_grpc_record_protocol_copy_slice_buffer(const grpc_slice_buffer* src,
63 unsigned char* dst) {
64 CHECK(src != nullptr);
65 CHECK_NE(dst, nullptr);
66 for (size_t i = 0; i < src->count; i++) {
67 size_t slice_length = GRPC_SLICE_LENGTH(src->slices[i]);
68 memcpy(dst, GRPC_SLICE_START_PTR(src->slices[i]), slice_length);
69 dst += slice_length;
70 }
71 }
72
alts_grpc_record_protocol_get_header_iovec(alts_grpc_record_protocol * rp)73 iovec_t alts_grpc_record_protocol_get_header_iovec(
74 alts_grpc_record_protocol* rp) {
75 iovec_t header_iovec = {nullptr, 0};
76 if (rp == nullptr) {
77 return header_iovec;
78 }
79 header_iovec.iov_len = rp->header_length;
80 if (rp->header_sb.count == 1) {
81 header_iovec.iov_base = GRPC_SLICE_START_PTR(rp->header_sb.slices[0]);
82 } else {
83 // Frame header is in multiple slices, copies the header bytes from slice
84 // buffer to a single flat buffer.
85 alts_grpc_record_protocol_copy_slice_buffer(&rp->header_sb, rp->header_buf);
86 header_iovec.iov_base = rp->header_buf;
87 }
88 return header_iovec;
89 }
90
alts_grpc_record_protocol_init(alts_grpc_record_protocol * rp,gsec_aead_crypter * crypter,size_t overflow_size,bool is_client,bool is_integrity_only,bool is_protect)91 tsi_result alts_grpc_record_protocol_init(alts_grpc_record_protocol* rp,
92 gsec_aead_crypter* crypter,
93 size_t overflow_size, bool is_client,
94 bool is_integrity_only,
95 bool is_protect) {
96 if (rp == nullptr || crypter == nullptr) {
97 LOG(ERROR)
98 << "Invalid nullptr arguments to alts_grpc_record_protocol init.";
99 return TSI_INVALID_ARGUMENT;
100 }
101 // Creates alts_iovec_record_protocol.
102 char* error_details = nullptr;
103 grpc_status_code status = alts_iovec_record_protocol_create(
104 crypter, overflow_size, is_client, is_integrity_only, is_protect,
105 &rp->iovec_rp, &error_details);
106 if (status != GRPC_STATUS_OK) {
107 LOG(ERROR) << "Failed to create alts_iovec_record_protocol, "
108 << error_details;
109 gpr_free(error_details);
110 return TSI_INTERNAL_ERROR;
111 }
112 // Allocates header slice buffer.
113 grpc_slice_buffer_init(&rp->header_sb);
114 // Allocates header buffer.
115 rp->header_length = alts_iovec_record_protocol_get_header_length();
116 rp->header_buf = static_cast<unsigned char*>(gpr_malloc(rp->header_length));
117 rp->tag_length = alts_iovec_record_protocol_get_tag_length(rp->iovec_rp);
118 // Allocates iovec buffer.
119 rp->iovec_buf_length = kInitialIovecBufferSize;
120 rp->iovec_buf =
121 static_cast<iovec_t*>(gpr_malloc(rp->iovec_buf_length * sizeof(iovec_t)));
122 return TSI_OK;
123 }
124
// --- Implementation of methods defined in alts_grpc_record_protocol.h. ---
alts_grpc_record_protocol_protect(alts_grpc_record_protocol * self,grpc_slice_buffer * unprotected_slices,grpc_slice_buffer * protected_slices)126 tsi_result alts_grpc_record_protocol_protect(
127 alts_grpc_record_protocol* self, grpc_slice_buffer* unprotected_slices,
128 grpc_slice_buffer* protected_slices) {
129 if (self == nullptr || self->vtable == nullptr ||
130 unprotected_slices == nullptr || protected_slices == nullptr) {
131 return TSI_INVALID_ARGUMENT;
132 }
133 if (self->vtable->protect == nullptr) {
134 return TSI_UNIMPLEMENTED;
135 }
136 return self->vtable->protect(self, unprotected_slices, protected_slices);
137 }
138
alts_grpc_record_protocol_unprotect(alts_grpc_record_protocol * self,grpc_slice_buffer * protected_slices,grpc_slice_buffer * unprotected_slices)139 tsi_result alts_grpc_record_protocol_unprotect(
140 alts_grpc_record_protocol* self, grpc_slice_buffer* protected_slices,
141 grpc_slice_buffer* unprotected_slices) {
142 if (self == nullptr || self->vtable == nullptr ||
143 protected_slices == nullptr || unprotected_slices == nullptr) {
144 return TSI_INVALID_ARGUMENT;
145 }
146 if (self->vtable->unprotect == nullptr) {
147 return TSI_UNIMPLEMENTED;
148 }
149 return self->vtable->unprotect(self, protected_slices, unprotected_slices);
150 }
151
alts_grpc_record_protocol_destroy(alts_grpc_record_protocol * self)152 void alts_grpc_record_protocol_destroy(alts_grpc_record_protocol* self) {
153 if (self == nullptr) {
154 return;
155 }
156 if (self->vtable->destruct != nullptr) {
157 self->vtable->destruct(self);
158 }
159 alts_iovec_record_protocol_destroy(self->iovec_rp);
160 grpc_slice_buffer_destroy(&self->header_sb);
161 gpr_free(self->header_buf);
162 gpr_free(self->iovec_buf);
163 gpr_free(self);
164 }
165
166 // Integrity-only and privacy-integrity share the same implementation. No need
167 // to call vtable.
alts_grpc_record_protocol_max_unprotected_data_size(const alts_grpc_record_protocol * self,size_t max_protected_frame_size)168 size_t alts_grpc_record_protocol_max_unprotected_data_size(
169 const alts_grpc_record_protocol* self, size_t max_protected_frame_size) {
170 if (self == nullptr) {
171 return 0;
172 }
173 return alts_iovec_record_protocol_max_unprotected_data_size(
174 self->iovec_rp, max_protected_frame_size);
175 }
176