1 /*
2 *
3 * Copyright 2015-2016 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19 #include <grpc/support/port_platform.h>
20
21 #include "src/core/lib/security/credentials/composite/composite_credentials.h"
22
23 #include <string.h>
24
25 #include "src/core/lib/iomgr/polling_entity.h"
26 #include "src/core/lib/surface/api_trace.h"
27
28 #include <grpc/support/alloc.h>
29 #include <grpc/support/log.h>
30 #include <grpc/support/string_util.h>
31
32 /* -- Composite call credentials. -- */
33
34 typedef struct {
35 grpc_composite_call_credentials* composite_creds;
36 size_t creds_index;
37 grpc_polling_entity* pollent;
38 grpc_auth_metadata_context auth_md_context;
39 grpc_credentials_mdelem_array* md_array;
40 grpc_closure* on_request_metadata;
41 grpc_closure internal_on_request_metadata;
42 } grpc_composite_call_credentials_metadata_context;
43
composite_call_destruct(grpc_call_credentials * creds)44 static void composite_call_destruct(grpc_call_credentials* creds) {
45 grpc_composite_call_credentials* c =
46 reinterpret_cast<grpc_composite_call_credentials*>(creds);
47 for (size_t i = 0; i < c->inner.num_creds; i++) {
48 grpc_call_credentials_unref(c->inner.creds_array[i]);
49 }
50 gpr_free(c->inner.creds_array);
51 }
52
composite_call_metadata_cb(void * arg,grpc_error * error)53 static void composite_call_metadata_cb(void* arg, grpc_error* error) {
54 grpc_composite_call_credentials_metadata_context* ctx =
55 static_cast<grpc_composite_call_credentials_metadata_context*>(arg);
56 if (error == GRPC_ERROR_NONE) {
57 /* See if we need to get some more metadata. */
58 if (ctx->creds_index < ctx->composite_creds->inner.num_creds) {
59 grpc_call_credentials* inner_creds =
60 ctx->composite_creds->inner.creds_array[ctx->creds_index++];
61 if (grpc_call_credentials_get_request_metadata(
62 inner_creds, ctx->pollent, ctx->auth_md_context, ctx->md_array,
63 &ctx->internal_on_request_metadata, &error)) {
64 // Synchronous response, so call ourselves recursively.
65 composite_call_metadata_cb(arg, error);
66 GRPC_ERROR_UNREF(error);
67 }
68 return;
69 }
70 // We're done!
71 }
72 GRPC_CLOSURE_SCHED(ctx->on_request_metadata, GRPC_ERROR_REF(error));
73 gpr_free(ctx);
74 }
75
composite_call_get_request_metadata(grpc_call_credentials * creds,grpc_polling_entity * pollent,grpc_auth_metadata_context auth_md_context,grpc_credentials_mdelem_array * md_array,grpc_closure * on_request_metadata,grpc_error ** error)76 static bool composite_call_get_request_metadata(
77 grpc_call_credentials* creds, grpc_polling_entity* pollent,
78 grpc_auth_metadata_context auth_md_context,
79 grpc_credentials_mdelem_array* md_array, grpc_closure* on_request_metadata,
80 grpc_error** error) {
81 grpc_composite_call_credentials* c =
82 reinterpret_cast<grpc_composite_call_credentials*>(creds);
83 grpc_composite_call_credentials_metadata_context* ctx;
84 ctx = static_cast<grpc_composite_call_credentials_metadata_context*>(
85 gpr_zalloc(sizeof(grpc_composite_call_credentials_metadata_context)));
86 ctx->composite_creds = c;
87 ctx->pollent = pollent;
88 ctx->auth_md_context = auth_md_context;
89 ctx->md_array = md_array;
90 ctx->on_request_metadata = on_request_metadata;
91 GRPC_CLOSURE_INIT(&ctx->internal_on_request_metadata,
92 composite_call_metadata_cb, ctx, grpc_schedule_on_exec_ctx);
93 bool synchronous = true;
94 while (ctx->creds_index < ctx->composite_creds->inner.num_creds) {
95 grpc_call_credentials* inner_creds =
96 ctx->composite_creds->inner.creds_array[ctx->creds_index++];
97 if (grpc_call_credentials_get_request_metadata(
98 inner_creds, ctx->pollent, ctx->auth_md_context, ctx->md_array,
99 &ctx->internal_on_request_metadata, error)) {
100 if (*error != GRPC_ERROR_NONE) break;
101 } else {
102 synchronous = false; // Async return.
103 break;
104 }
105 }
106 if (synchronous) gpr_free(ctx);
107 return synchronous;
108 }
109
composite_call_cancel_get_request_metadata(grpc_call_credentials * creds,grpc_credentials_mdelem_array * md_array,grpc_error * error)110 static void composite_call_cancel_get_request_metadata(
111 grpc_call_credentials* creds, grpc_credentials_mdelem_array* md_array,
112 grpc_error* error) {
113 grpc_composite_call_credentials* c =
114 reinterpret_cast<grpc_composite_call_credentials*>(creds);
115 for (size_t i = 0; i < c->inner.num_creds; ++i) {
116 grpc_call_credentials_cancel_get_request_metadata(
117 c->inner.creds_array[i], md_array, GRPC_ERROR_REF(error));
118 }
119 GRPC_ERROR_UNREF(error);
120 }
121
122 static grpc_call_credentials_vtable composite_call_credentials_vtable = {
123 composite_call_destruct, composite_call_get_request_metadata,
124 composite_call_cancel_get_request_metadata};
125
get_creds_array(grpc_call_credentials ** creds_addr)126 static grpc_call_credentials_array get_creds_array(
127 grpc_call_credentials** creds_addr) {
128 grpc_call_credentials_array result;
129 grpc_call_credentials* creds = *creds_addr;
130 result.creds_array = creds_addr;
131 result.num_creds = 1;
132 if (strcmp(creds->type, GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0) {
133 result = *grpc_composite_call_credentials_get_credentials(creds);
134 }
135 return result;
136 }
137
grpc_composite_call_credentials_create(grpc_call_credentials * creds1,grpc_call_credentials * creds2,void * reserved)138 grpc_call_credentials* grpc_composite_call_credentials_create(
139 grpc_call_credentials* creds1, grpc_call_credentials* creds2,
140 void* reserved) {
141 size_t i;
142 size_t creds_array_byte_size;
143 grpc_call_credentials_array creds1_array;
144 grpc_call_credentials_array creds2_array;
145 grpc_composite_call_credentials* c;
146 GRPC_API_TRACE(
147 "grpc_composite_call_credentials_create(creds1=%p, creds2=%p, "
148 "reserved=%p)",
149 3, (creds1, creds2, reserved));
150 GPR_ASSERT(reserved == nullptr);
151 GPR_ASSERT(creds1 != nullptr);
152 GPR_ASSERT(creds2 != nullptr);
153 c = static_cast<grpc_composite_call_credentials*>(
154 gpr_zalloc(sizeof(grpc_composite_call_credentials)));
155 c->base.type = GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE;
156 c->base.vtable = &composite_call_credentials_vtable;
157 gpr_ref_init(&c->base.refcount, 1);
158 creds1_array = get_creds_array(&creds1);
159 creds2_array = get_creds_array(&creds2);
160 c->inner.num_creds = creds1_array.num_creds + creds2_array.num_creds;
161 creds_array_byte_size = c->inner.num_creds * sizeof(grpc_call_credentials*);
162 c->inner.creds_array =
163 static_cast<grpc_call_credentials**>(gpr_zalloc(creds_array_byte_size));
164 for (i = 0; i < creds1_array.num_creds; i++) {
165 grpc_call_credentials* cur_creds = creds1_array.creds_array[i];
166 c->inner.creds_array[i] = grpc_call_credentials_ref(cur_creds);
167 }
168 for (i = 0; i < creds2_array.num_creds; i++) {
169 grpc_call_credentials* cur_creds = creds2_array.creds_array[i];
170 c->inner.creds_array[i + creds1_array.num_creds] =
171 grpc_call_credentials_ref(cur_creds);
172 }
173 return &c->base;
174 }
175
176 const grpc_call_credentials_array*
grpc_composite_call_credentials_get_credentials(grpc_call_credentials * creds)177 grpc_composite_call_credentials_get_credentials(grpc_call_credentials* creds) {
178 const grpc_composite_call_credentials* c =
179 reinterpret_cast<const grpc_composite_call_credentials*>(creds);
180 GPR_ASSERT(strcmp(creds->type, GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0);
181 return &c->inner;
182 }
183
grpc_credentials_contains_type(grpc_call_credentials * creds,const char * type,grpc_call_credentials ** composite_creds)184 grpc_call_credentials* grpc_credentials_contains_type(
185 grpc_call_credentials* creds, const char* type,
186 grpc_call_credentials** composite_creds) {
187 size_t i;
188 if (strcmp(creds->type, type) == 0) {
189 if (composite_creds != nullptr) *composite_creds = nullptr;
190 return creds;
191 } else if (strcmp(creds->type, GRPC_CALL_CREDENTIALS_TYPE_COMPOSITE) == 0) {
192 const grpc_call_credentials_array* inner_creds_array =
193 grpc_composite_call_credentials_get_credentials(creds);
194 for (i = 0; i < inner_creds_array->num_creds; i++) {
195 if (strcmp(type, inner_creds_array->creds_array[i]->type) == 0) {
196 if (composite_creds != nullptr) *composite_creds = creds;
197 return inner_creds_array->creds_array[i];
198 }
199 }
200 }
201 return nullptr;
202 }
203
204 /* -- Composite channel credentials. -- */
205
composite_channel_destruct(grpc_channel_credentials * creds)206 static void composite_channel_destruct(grpc_channel_credentials* creds) {
207 grpc_composite_channel_credentials* c =
208 reinterpret_cast<grpc_composite_channel_credentials*>(creds);
209 grpc_channel_credentials_unref(c->inner_creds);
210 grpc_call_credentials_unref(c->call_creds);
211 }
212
composite_channel_create_security_connector(grpc_channel_credentials * creds,grpc_call_credentials * call_creds,const char * target,const grpc_channel_args * args,grpc_channel_security_connector ** sc,grpc_channel_args ** new_args)213 static grpc_security_status composite_channel_create_security_connector(
214 grpc_channel_credentials* creds, grpc_call_credentials* call_creds,
215 const char* target, const grpc_channel_args* args,
216 grpc_channel_security_connector** sc, grpc_channel_args** new_args) {
217 grpc_composite_channel_credentials* c =
218 reinterpret_cast<grpc_composite_channel_credentials*>(creds);
219 grpc_security_status status = GRPC_SECURITY_ERROR;
220
221 GPR_ASSERT(c->inner_creds != nullptr && c->call_creds != nullptr &&
222 c->inner_creds->vtable != nullptr &&
223 c->inner_creds->vtable->create_security_connector != nullptr);
224 /* If we are passed a call_creds, create a call composite to pass it
225 downstream. */
226 if (call_creds != nullptr) {
227 grpc_call_credentials* composite_call_creds =
228 grpc_composite_call_credentials_create(c->call_creds, call_creds,
229 nullptr);
230 status = c->inner_creds->vtable->create_security_connector(
231 c->inner_creds, composite_call_creds, target, args, sc, new_args);
232 grpc_call_credentials_unref(composite_call_creds);
233 } else {
234 status = c->inner_creds->vtable->create_security_connector(
235 c->inner_creds, c->call_creds, target, args, sc, new_args);
236 }
237 return status;
238 }
239
240 static grpc_channel_credentials*
composite_channel_duplicate_without_call_credentials(grpc_channel_credentials * creds)241 composite_channel_duplicate_without_call_credentials(
242 grpc_channel_credentials* creds) {
243 grpc_composite_channel_credentials* c =
244 reinterpret_cast<grpc_composite_channel_credentials*>(creds);
245 return grpc_channel_credentials_ref(c->inner_creds);
246 }
247
248 static grpc_channel_credentials_vtable composite_channel_credentials_vtable = {
249 composite_channel_destruct, composite_channel_create_security_connector,
250 composite_channel_duplicate_without_call_credentials};
251
grpc_composite_channel_credentials_create(grpc_channel_credentials * channel_creds,grpc_call_credentials * call_creds,void * reserved)252 grpc_channel_credentials* grpc_composite_channel_credentials_create(
253 grpc_channel_credentials* channel_creds, grpc_call_credentials* call_creds,
254 void* reserved) {
255 grpc_composite_channel_credentials* c =
256 static_cast<grpc_composite_channel_credentials*>(gpr_zalloc(sizeof(*c)));
257 GPR_ASSERT(channel_creds != nullptr && call_creds != nullptr &&
258 reserved == nullptr);
259 GRPC_API_TRACE(
260 "grpc_composite_channel_credentials_create(channel_creds=%p, "
261 "call_creds=%p, reserved=%p)",
262 3, (channel_creds, call_creds, reserved));
263 c->base.type = channel_creds->type;
264 c->base.vtable = &composite_channel_credentials_vtable;
265 gpr_ref_init(&c->base.refcount, 1);
266 c->inner_creds = grpc_channel_credentials_ref(channel_creds);
267 c->call_creds = grpc_call_credentials_ref(call_creds);
268 return &c->base;
269 }
270