• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 //
3 // Copyright 2015 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #include <grpc/credentials.h>
20 #include <grpc/grpc.h>
21 #include <grpc/grpc_security.h>
22 #include <grpc/status.h>
23 #include <grpc/support/alloc.h>
24 #include <grpc/support/port_platform.h>
25 
26 #include <algorithm>
27 #include <atomic>
28 #include <cstddef>
29 #include <functional>
30 #include <memory>
31 #include <utility>
32 
33 #include "absl/log/check.h"
34 #include "absl/log/log.h"
35 #include "absl/status/status.h"
36 #include "absl/status/statusor.h"
37 #include "src/core/lib/channel/channel_args.h"
38 #include "src/core/lib/channel/channel_fwd.h"
39 #include "src/core/lib/channel/channel_stack.h"
40 #include "src/core/lib/channel/promise_based_filter.h"
41 #include "src/core/lib/debug/trace.h"
42 #include "src/core/lib/iomgr/error.h"
43 #include "src/core/lib/iomgr/exec_ctx.h"
44 #include "src/core/lib/promise/activity.h"
45 #include "src/core/lib/promise/arena_promise.h"
46 #include "src/core/lib/promise/context.h"
47 #include "src/core/lib/promise/poll.h"
48 #include "src/core/lib/promise/try_seq.h"
49 #include "src/core/lib/resource_quota/arena.h"
50 #include "src/core/lib/security/context/security_context.h"
51 #include "src/core/lib/security/credentials/credentials.h"
52 #include "src/core/lib/security/transport/auth_filters.h"  // IWYU pragma: keep
53 #include "src/core/lib/slice/slice.h"
54 #include "src/core/lib/slice/slice_internal.h"
55 #include "src/core/lib/transport/metadata_batch.h"
56 #include "src/core/lib/transport/transport.h"
57 #include "src/core/util/debug_location.h"
58 #include "src/core/util/ref_counted_ptr.h"
59 #include "src/core/util/status_helper.h"
60 
61 namespace grpc_core {
62 
63 const grpc_channel_filter ServerAuthFilter::kFilter =
64     MakePromiseBasedFilter<ServerAuthFilter, FilterEndpoint::kServer>();
65 
66 const NoInterceptor ServerAuthFilter::Call::OnClientToServerMessage;
67 const NoInterceptor ServerAuthFilter::Call::OnClientToServerHalfClose;
68 const NoInterceptor ServerAuthFilter::Call::OnServerToClientMessage;
69 const NoInterceptor ServerAuthFilter::Call::OnServerInitialMetadata;
70 const NoInterceptor ServerAuthFilter::Call::OnServerTrailingMetadata;
71 const NoInterceptor ServerAuthFilter::Call::OnFinalize;
72 
73 namespace {
74 
75 class ArrayEncoder {
76  public:
ArrayEncoder(grpc_metadata_array * result)77   explicit ArrayEncoder(grpc_metadata_array* result) : result_(result) {}
78 
Encode(const Slice & key,const Slice & value)79   void Encode(const Slice& key, const Slice& value) {
80     Append(key.Ref(), value.Ref());
81   }
82 
83   template <typename Which>
Encode(Which,const typename Which::ValueType & value)84   void Encode(Which, const typename Which::ValueType& value) {
85     Append(Slice(StaticSlice::FromStaticString(Which::key())),
86            Slice(Which::Encode(value)));
87   }
88 
Encode(HttpMethodMetadata,const typename HttpMethodMetadata::ValueType &)89   void Encode(HttpMethodMetadata,
90               const typename HttpMethodMetadata::ValueType&) {}
91 
92  private:
Append(Slice key,Slice value)93   void Append(Slice key, Slice value) {
94     if (result_->count == result_->capacity) {
95       result_->capacity =
96           std::max(result_->capacity + 8, result_->capacity * 2);
97       result_->metadata = static_cast<grpc_metadata*>(gpr_realloc(
98           result_->metadata, result_->capacity * sizeof(grpc_metadata)));
99     }
100     auto* usr_md = &result_->metadata[result_->count++];
101     usr_md->key = key.TakeCSlice();
102     usr_md->value = value.TakeCSlice();
103   }
104 
105   grpc_metadata_array* result_;
106 };
107 
108 // TODO(ctiller): seek out all users of this functionality and change API so
109 // that this unilateral format conversion IS NOT REQUIRED.
MetadataBatchToMetadataArray(const grpc_metadata_batch * batch)110 grpc_metadata_array MetadataBatchToMetadataArray(
111     const grpc_metadata_batch* batch) {
112   grpc_metadata_array result;
113   grpc_metadata_array_init(&result);
114   ArrayEncoder encoder(&result);
115   batch->Encode(&encoder);
116   return result;
117 }
118 
119 }  // namespace
120 
121 struct ServerAuthFilter::RunApplicationCode::State {
Stategrpc_core::ServerAuthFilter::RunApplicationCode::State122   explicit State(ClientMetadata& client_metadata)
123       : client_metadata(&client_metadata) {}
124   Waker waker{GetContext<Activity>()->MakeOwningWaker()};
125   absl::StatusOr<ClientMetadata*> client_metadata;
126   grpc_metadata_array md = MetadataBatchToMetadataArray(*client_metadata);
127   std::atomic<bool> done{false};
128 };
129 
RunApplicationCode(ServerAuthFilter * filter,ClientMetadata & metadata)130 ServerAuthFilter::RunApplicationCode::RunApplicationCode(
131     ServerAuthFilter* filter, ClientMetadata& metadata)
132     : state_(GetContext<Arena>()->ManagedNew<State>(metadata)) {
133   GRPC_TRACE_LOG(call, ERROR)
134       << GetContext<Activity>()->DebugTag()
135       << "[server-auth]: Delegate to application: filter=" << filter
136       << " this=" << this << " auth_ctx=" << filter->auth_context_.get();
137   filter->server_credentials_->auth_metadata_processor().process(
138       filter->server_credentials_->auth_metadata_processor().state,
139       filter->auth_context_.get(), state_->md.metadata, state_->md.count,
140       OnMdProcessingDone, state_);
141 }
142 
operator ()()143 Poll<absl::Status> ServerAuthFilter::RunApplicationCode::operator()() {
144   if (state_->done.load(std::memory_order_acquire)) {
145     return Poll<absl::Status>(std::move(state_->client_metadata).status());
146   }
147   return Pending{};
148 }
149 
OnMdProcessingDone(void * user_data,const grpc_metadata * consumed_md,size_t num_consumed_md,const grpc_metadata * response_md,size_t num_response_md,grpc_status_code status,const char * error_details)150 void ServerAuthFilter::RunApplicationCode::OnMdProcessingDone(
151     void* user_data, const grpc_metadata* consumed_md, size_t num_consumed_md,
152     const grpc_metadata* response_md, size_t num_response_md,
153     grpc_status_code status, const char* error_details) {
154   ApplicationCallbackExecCtx callback_exec_ctx;
155   ExecCtx exec_ctx;
156 
157   auto* state = static_cast<State*>(user_data);
158 
159   // TODO(ZhenLian): Implement support for response_md.
160   if (response_md != nullptr && num_response_md > 0) {
161     LOG(ERROR) << "response_md in auth metadata processing not supported for "
162                   "now. Ignoring...";
163   }
164 
165   if (status == GRPC_STATUS_OK) {
166     ClientMetadata& md = **state->client_metadata;
167     for (size_t i = 0; i < num_consumed_md; i++) {
168       md.Remove(StringViewFromSlice(consumed_md[i].key));
169     }
170   } else {
171     if (error_details == nullptr) {
172       error_details = "Authentication metadata processing failed.";
173     }
174     state->client_metadata = grpc_error_set_int(
175         absl::Status(static_cast<absl::StatusCode>(status), error_details),
176         StatusIntProperty::kRpcStatus, status);
177   }
178 
179   // Clean up.
180   for (size_t i = 0; i < state->md.count; i++) {
181     CSliceUnref(state->md.metadata[i].key);
182     CSliceUnref(state->md.metadata[i].value);
183   }
184   grpc_metadata_array_destroy(&state->md);
185 
186   auto waker = std::move(state->waker);
187   state->done.store(true, std::memory_order_release);
188   waker.Wakeup();
189 }
190 
Call(ServerAuthFilter * filter)191 ServerAuthFilter::Call::Call(ServerAuthFilter* filter) {
192   // Create server security context.  Set its auth context from channel
193   // data and save it in the call context.
194   grpc_server_security_context* server_ctx =
195       grpc_server_security_context_create(GetContext<Arena>());
196   server_ctx->auth_context =
197       filter->auth_context_->Ref(DEBUG_LOCATION, "server_auth_filter");
198   SetContext<SecurityContext>(server_ctx);
199 }
200 
ServerAuthFilter(RefCountedPtr<grpc_server_credentials> server_credentials,RefCountedPtr<grpc_auth_context> auth_context)201 ServerAuthFilter::ServerAuthFilter(
202     RefCountedPtr<grpc_server_credentials> server_credentials,
203     RefCountedPtr<grpc_auth_context> auth_context)
204     : server_credentials_(server_credentials), auth_context_(auth_context) {}
205 
Create(const ChannelArgs & args,ChannelFilter::Args)206 absl::StatusOr<std::unique_ptr<ServerAuthFilter>> ServerAuthFilter::Create(
207     const ChannelArgs& args, ChannelFilter::Args) {
208   auto auth_context = args.GetObjectRef<grpc_auth_context>();
209   CHECK(auth_context != nullptr);
210   auto creds = args.GetObjectRef<grpc_server_credentials>();
211   return std::make_unique<ServerAuthFilter>(std::move(creds),
212                                             std::move(auth_context));
213 }
214 
215 }  // namespace grpc_core
216