1 //
2 //
3 // Copyright 2015 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18
19 #include <grpc/credentials.h>
20 #include <grpc/grpc.h>
21 #include <grpc/grpc_security.h>
22 #include <grpc/status.h>
23 #include <grpc/support/alloc.h>
24 #include <grpc/support/port_platform.h>
25
26 #include <algorithm>
27 #include <atomic>
28 #include <cstddef>
29 #include <functional>
30 #include <memory>
31 #include <utility>
32
33 #include "absl/log/check.h"
34 #include "absl/log/log.h"
35 #include "absl/status/status.h"
36 #include "absl/status/statusor.h"
37 #include "src/core/lib/channel/channel_args.h"
38 #include "src/core/lib/channel/channel_fwd.h"
39 #include "src/core/lib/channel/channel_stack.h"
40 #include "src/core/lib/channel/promise_based_filter.h"
41 #include "src/core/lib/debug/trace.h"
42 #include "src/core/lib/iomgr/error.h"
43 #include "src/core/lib/iomgr/exec_ctx.h"
44 #include "src/core/lib/promise/activity.h"
45 #include "src/core/lib/promise/arena_promise.h"
46 #include "src/core/lib/promise/context.h"
47 #include "src/core/lib/promise/poll.h"
48 #include "src/core/lib/promise/try_seq.h"
49 #include "src/core/lib/resource_quota/arena.h"
50 #include "src/core/lib/security/context/security_context.h"
51 #include "src/core/lib/security/credentials/credentials.h"
52 #include "src/core/lib/security/transport/auth_filters.h" // IWYU pragma: keep
53 #include "src/core/lib/slice/slice.h"
54 #include "src/core/lib/slice/slice_internal.h"
55 #include "src/core/lib/transport/metadata_batch.h"
56 #include "src/core/lib/transport/transport.h"
57 #include "src/core/util/debug_location.h"
58 #include "src/core/util/ref_counted_ptr.h"
59 #include "src/core/util/status_helper.h"
60
61 namespace grpc_core {
62
63 const grpc_channel_filter ServerAuthFilter::kFilter =
64 MakePromiseBasedFilter<ServerAuthFilter, FilterEndpoint::kServer>();
65
66 const NoInterceptor ServerAuthFilter::Call::OnClientToServerMessage;
67 const NoInterceptor ServerAuthFilter::Call::OnClientToServerHalfClose;
68 const NoInterceptor ServerAuthFilter::Call::OnServerToClientMessage;
69 const NoInterceptor ServerAuthFilter::Call::OnServerInitialMetadata;
70 const NoInterceptor ServerAuthFilter::Call::OnServerTrailingMetadata;
71 const NoInterceptor ServerAuthFilter::Call::OnFinalize;
72
73 namespace {
74
75 class ArrayEncoder {
76 public:
ArrayEncoder(grpc_metadata_array * result)77 explicit ArrayEncoder(grpc_metadata_array* result) : result_(result) {}
78
Encode(const Slice & key,const Slice & value)79 void Encode(const Slice& key, const Slice& value) {
80 Append(key.Ref(), value.Ref());
81 }
82
83 template <typename Which>
Encode(Which,const typename Which::ValueType & value)84 void Encode(Which, const typename Which::ValueType& value) {
85 Append(Slice(StaticSlice::FromStaticString(Which::key())),
86 Slice(Which::Encode(value)));
87 }
88
Encode(HttpMethodMetadata,const typename HttpMethodMetadata::ValueType &)89 void Encode(HttpMethodMetadata,
90 const typename HttpMethodMetadata::ValueType&) {}
91
92 private:
Append(Slice key,Slice value)93 void Append(Slice key, Slice value) {
94 if (result_->count == result_->capacity) {
95 result_->capacity =
96 std::max(result_->capacity + 8, result_->capacity * 2);
97 result_->metadata = static_cast<grpc_metadata*>(gpr_realloc(
98 result_->metadata, result_->capacity * sizeof(grpc_metadata)));
99 }
100 auto* usr_md = &result_->metadata[result_->count++];
101 usr_md->key = key.TakeCSlice();
102 usr_md->value = value.TakeCSlice();
103 }
104
105 grpc_metadata_array* result_;
106 };
107
108 // TODO(ctiller): seek out all users of this functionality and change API so
109 // that this unilateral format conversion IS NOT REQUIRED.
MetadataBatchToMetadataArray(const grpc_metadata_batch * batch)110 grpc_metadata_array MetadataBatchToMetadataArray(
111 const grpc_metadata_batch* batch) {
112 grpc_metadata_array result;
113 grpc_metadata_array_init(&result);
114 ArrayEncoder encoder(&result);
115 batch->Encode(&encoder);
116 return result;
117 }
118
119 } // namespace
120
121 struct ServerAuthFilter::RunApplicationCode::State {
Stategrpc_core::ServerAuthFilter::RunApplicationCode::State122 explicit State(ClientMetadata& client_metadata)
123 : client_metadata(&client_metadata) {}
124 Waker waker{GetContext<Activity>()->MakeOwningWaker()};
125 absl::StatusOr<ClientMetadata*> client_metadata;
126 grpc_metadata_array md = MetadataBatchToMetadataArray(*client_metadata);
127 std::atomic<bool> done{false};
128 };
129
RunApplicationCode(ServerAuthFilter * filter,ClientMetadata & metadata)130 ServerAuthFilter::RunApplicationCode::RunApplicationCode(
131 ServerAuthFilter* filter, ClientMetadata& metadata)
132 : state_(GetContext<Arena>()->ManagedNew<State>(metadata)) {
133 GRPC_TRACE_LOG(call, ERROR)
134 << GetContext<Activity>()->DebugTag()
135 << "[server-auth]: Delegate to application: filter=" << filter
136 << " this=" << this << " auth_ctx=" << filter->auth_context_.get();
137 filter->server_credentials_->auth_metadata_processor().process(
138 filter->server_credentials_->auth_metadata_processor().state,
139 filter->auth_context_.get(), state_->md.metadata, state_->md.count,
140 OnMdProcessingDone, state_);
141 }
142
operator ()()143 Poll<absl::Status> ServerAuthFilter::RunApplicationCode::operator()() {
144 if (state_->done.load(std::memory_order_acquire)) {
145 return Poll<absl::Status>(std::move(state_->client_metadata).status());
146 }
147 return Pending{};
148 }
149
OnMdProcessingDone(void * user_data,const grpc_metadata * consumed_md,size_t num_consumed_md,const grpc_metadata * response_md,size_t num_response_md,grpc_status_code status,const char * error_details)150 void ServerAuthFilter::RunApplicationCode::OnMdProcessingDone(
151 void* user_data, const grpc_metadata* consumed_md, size_t num_consumed_md,
152 const grpc_metadata* response_md, size_t num_response_md,
153 grpc_status_code status, const char* error_details) {
154 ApplicationCallbackExecCtx callback_exec_ctx;
155 ExecCtx exec_ctx;
156
157 auto* state = static_cast<State*>(user_data);
158
159 // TODO(ZhenLian): Implement support for response_md.
160 if (response_md != nullptr && num_response_md > 0) {
161 LOG(ERROR) << "response_md in auth metadata processing not supported for "
162 "now. Ignoring...";
163 }
164
165 if (status == GRPC_STATUS_OK) {
166 ClientMetadata& md = **state->client_metadata;
167 for (size_t i = 0; i < num_consumed_md; i++) {
168 md.Remove(StringViewFromSlice(consumed_md[i].key));
169 }
170 } else {
171 if (error_details == nullptr) {
172 error_details = "Authentication metadata processing failed.";
173 }
174 state->client_metadata = grpc_error_set_int(
175 absl::Status(static_cast<absl::StatusCode>(status), error_details),
176 StatusIntProperty::kRpcStatus, status);
177 }
178
179 // Clean up.
180 for (size_t i = 0; i < state->md.count; i++) {
181 CSliceUnref(state->md.metadata[i].key);
182 CSliceUnref(state->md.metadata[i].value);
183 }
184 grpc_metadata_array_destroy(&state->md);
185
186 auto waker = std::move(state->waker);
187 state->done.store(true, std::memory_order_release);
188 waker.Wakeup();
189 }
190
Call(ServerAuthFilter * filter)191 ServerAuthFilter::Call::Call(ServerAuthFilter* filter) {
192 // Create server security context. Set its auth context from channel
193 // data and save it in the call context.
194 grpc_server_security_context* server_ctx =
195 grpc_server_security_context_create(GetContext<Arena>());
196 server_ctx->auth_context =
197 filter->auth_context_->Ref(DEBUG_LOCATION, "server_auth_filter");
198 SetContext<SecurityContext>(server_ctx);
199 }
200
ServerAuthFilter(RefCountedPtr<grpc_server_credentials> server_credentials,RefCountedPtr<grpc_auth_context> auth_context)201 ServerAuthFilter::ServerAuthFilter(
202 RefCountedPtr<grpc_server_credentials> server_credentials,
203 RefCountedPtr<grpc_auth_context> auth_context)
204 : server_credentials_(server_credentials), auth_context_(auth_context) {}
205
Create(const ChannelArgs & args,ChannelFilter::Args)206 absl::StatusOr<std::unique_ptr<ServerAuthFilter>> ServerAuthFilter::Create(
207 const ChannelArgs& args, ChannelFilter::Args) {
208 auto auth_context = args.GetObjectRef<grpc_auth_context>();
209 CHECK(auth_context != nullptr);
210 auto creds = args.GetObjectRef<grpc_server_credentials>();
211 return std::make_unique<ServerAuthFilter>(std::move(creds),
212 std::move(auth_context));
213 }
214
215 } // namespace grpc_core
216