• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2022 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
15 #include "src/core/ext/filters/http/message_compress/compression_filter.h"
16 
17 #include <grpc/compression.h>
18 #include <grpc/grpc.h>
19 #include <grpc/impl/channel_arg_names.h>
20 #include <grpc/impl/compression_types.h>
21 #include <grpc/support/port_platform.h>
22 #include <inttypes.h>
23 
24 #include <functional>
25 #include <memory>
26 #include <utility>
27 
28 #include "absl/log/check.h"
29 #include "absl/status/status.h"
30 #include "absl/strings/str_cat.h"
31 #include "absl/strings/str_format.h"
32 #include "absl/types/optional.h"
33 #include "src/core/ext/filters/message_size/message_size_filter.h"
34 #include "src/core/lib/channel/channel_args.h"
35 #include "src/core/lib/channel/channel_stack.h"
36 #include "src/core/lib/channel/promise_based_filter.h"
37 #include "src/core/lib/compression/compression_internal.h"
38 #include "src/core/lib/compression/message_compress.h"
39 #include "src/core/lib/debug/trace.h"
40 #include "src/core/lib/promise/activity.h"
41 #include "src/core/lib/promise/context.h"
42 #include "src/core/lib/promise/latch.h"
43 #include "src/core/lib/promise/pipe.h"
44 #include "src/core/lib/promise/prioritized_race.h"
45 #include "src/core/lib/resource_quota/arena.h"
46 #include "src/core/lib/slice/slice_buffer.h"
47 #include "src/core/lib/surface/call.h"
48 #include "src/core/lib/transport/metadata_batch.h"
49 #include "src/core/lib/transport/transport.h"
50 #include "src/core/telemetry/call_tracer.h"
51 #include "src/core/util/latent_see.h"
52 
53 namespace grpc_core {
54 
55 const NoInterceptor ServerCompressionFilter::Call::OnClientToServerHalfClose;
56 const NoInterceptor ServerCompressionFilter::Call::OnServerTrailingMetadata;
57 const NoInterceptor ServerCompressionFilter::Call::OnFinalize;
58 const NoInterceptor ClientCompressionFilter::Call::OnClientToServerHalfClose;
59 const NoInterceptor ClientCompressionFilter::Call::OnServerTrailingMetadata;
60 const NoInterceptor ClientCompressionFilter::Call::OnFinalize;
61 
62 const grpc_channel_filter ClientCompressionFilter::kFilter =
63     MakePromiseBasedFilter<ClientCompressionFilter, FilterEndpoint::kClient,
64                            kFilterExaminesServerInitialMetadata |
65                                kFilterExaminesInboundMessages |
66                                kFilterExaminesOutboundMessages>();
67 const grpc_channel_filter ServerCompressionFilter::kFilter =
68     MakePromiseBasedFilter<ServerCompressionFilter, FilterEndpoint::kServer,
69                            kFilterExaminesServerInitialMetadata |
70                                kFilterExaminesInboundMessages |
71                                kFilterExaminesOutboundMessages>();
72 
73 absl::StatusOr<std::unique_ptr<ClientCompressionFilter>>
Create(const ChannelArgs & args,ChannelFilter::Args)74 ClientCompressionFilter::Create(const ChannelArgs& args, ChannelFilter::Args) {
75   return std::make_unique<ClientCompressionFilter>(args);
76 }
77 
78 absl::StatusOr<std::unique_ptr<ServerCompressionFilter>>
Create(const ChannelArgs & args,ChannelFilter::Args)79 ServerCompressionFilter::Create(const ChannelArgs& args, ChannelFilter::Args) {
80   return std::make_unique<ServerCompressionFilter>(args);
81 }
82 
ChannelCompression(const ChannelArgs & args)83 ChannelCompression::ChannelCompression(const ChannelArgs& args)
84     : max_recv_size_(GetMaxRecvSizeFromChannelArgs(args)),
85       message_size_service_config_parser_index_(
86           MessageSizeParser::ParserIndex()),
87       default_compression_algorithm_(
88           DefaultCompressionAlgorithmFromChannelArgs(args).value_or(
89               GRPC_COMPRESS_NONE)),
90       enabled_compression_algorithms_(
91           CompressionAlgorithmSet::FromChannelArgs(args)),
92       enable_compression_(
93           args.GetBool(GRPC_ARG_ENABLE_PER_MESSAGE_COMPRESSION).value_or(true)),
94       enable_decompression_(
95           args.GetBool(GRPC_ARG_ENABLE_PER_MESSAGE_DECOMPRESSION)
96               .value_or(true)) {
97   // Make sure the default is enabled.
98   if (!enabled_compression_algorithms_.IsSet(default_compression_algorithm_)) {
99     const char* name;
100     if (!grpc_compression_algorithm_name(default_compression_algorithm_,
101                                          &name)) {
102       name = "<unknown>";
103     }
104     LOG(ERROR) << "default compression algorithm " << name
105                << " not enabled: switching to none";
106     default_compression_algorithm_ = GRPC_COMPRESS_NONE;
107   }
108 }
109 
CompressMessage(MessageHandle message,grpc_compression_algorithm algorithm) const110 MessageHandle ChannelCompression::CompressMessage(
111     MessageHandle message, grpc_compression_algorithm algorithm) const {
112   GRPC_TRACE_LOG(compression, INFO)
113       << "CompressMessage: len=" << message->payload()->Length()
114       << " alg=" << algorithm << " flags=" << message->flags();
115   auto* call_tracer = MaybeGetContext<CallTracerInterface>();
116   if (call_tracer != nullptr) {
117     call_tracer->RecordSendMessage(*message->payload());
118   }
119   // Check if we're allowed to compress this message
120   // (apps might want to disable compression for certain messages to avoid
121   // crime/beast like vulns).
122   uint32_t& flags = message->mutable_flags();
123   if (algorithm == GRPC_COMPRESS_NONE || !enable_compression_ ||
124       (flags & (GRPC_WRITE_NO_COMPRESS | GRPC_WRITE_INTERNAL_COMPRESS))) {
125     return message;
126   }
127   // Try to compress the payload.
128   SliceBuffer tmp;
129   SliceBuffer* payload = message->payload();
130   bool did_compress = grpc_msg_compress(algorithm, payload->c_slice_buffer(),
131                                         tmp.c_slice_buffer());
132   // If we achieved compression send it as compressed, otherwise send it as (to
133   // avoid spending cycles on the receiver decompressing).
134   if (did_compress) {
135     if (GRPC_TRACE_FLAG_ENABLED(compression)) {
136       const char* algo_name;
137       const size_t before_size = payload->Length();
138       const size_t after_size = tmp.Length();
139       const float savings_ratio = 1.0f - (static_cast<float>(after_size) /
140                                           static_cast<float>(before_size));
141       CHECK(grpc_compression_algorithm_name(algorithm, &algo_name));
142       LOG(INFO) << absl::StrFormat(
143           "Compressed[%s] %" PRIuPTR " bytes vs. %" PRIuPTR
144           " bytes (%.2f%% savings)",
145           algo_name, before_size, after_size, 100 * savings_ratio);
146     }
147     tmp.Swap(payload);
148     flags |= GRPC_WRITE_INTERNAL_COMPRESS;
149     if (call_tracer != nullptr) {
150       call_tracer->RecordSendCompressedMessage(*message->payload());
151     }
152   } else {
153     if (GRPC_TRACE_FLAG_ENABLED(compression)) {
154       const char* algo_name;
155       CHECK(grpc_compression_algorithm_name(algorithm, &algo_name));
156       LOG(INFO) << "Algorithm '" << algo_name
157                 << "' enabled but decided not to compress. Input size: "
158                 << payload->Length();
159     }
160   }
161   return message;
162 }
163 
DecompressMessage(bool is_client,MessageHandle message,DecompressArgs args) const164 absl::StatusOr<MessageHandle> ChannelCompression::DecompressMessage(
165     bool is_client, MessageHandle message, DecompressArgs args) const {
166   GRPC_TRACE_LOG(compression, INFO)
167       << "DecompressMessage: len=" << message->payload()->Length()
168       << " max=" << args.max_recv_message_length.value_or(-1)
169       << " alg=" << args.algorithm;
170   auto* call_tracer = MaybeGetContext<CallTracerInterface>();
171   if (call_tracer != nullptr) {
172     call_tracer->RecordReceivedMessage(*message->payload());
173   }
174   // Check max message length.
175   if (args.max_recv_message_length.has_value() &&
176       message->payload()->Length() >
177           static_cast<size_t>(*args.max_recv_message_length)) {
178     return absl::ResourceExhaustedError(absl::StrFormat(
179         "%s: Received message larger than max (%u vs. %d)",
180         is_client ? "CLIENT" : "SERVER", message->payload()->Length(),
181         *args.max_recv_message_length));
182   }
183   // Check if decompression is enabled (if not, we can just pass the message
184   // up).
185   if (!enable_decompression_ ||
186       (message->flags() & GRPC_WRITE_INTERNAL_COMPRESS) == 0) {
187     return std::move(message);
188   }
189   // Try to decompress the payload.
190   SliceBuffer decompressed_slices;
191   if (grpc_msg_decompress(args.algorithm, message->payload()->c_slice_buffer(),
192                           decompressed_slices.c_slice_buffer()) == 0) {
193     return absl::InternalError(
194         absl::StrCat("Unexpected error decompressing data for algorithm ",
195                      CompressionAlgorithmAsString(args.algorithm)));
196   }
197   // Swap the decompressed slices into the message.
198   message->payload()->Swap(&decompressed_slices);
199   message->mutable_flags() &= ~GRPC_WRITE_INTERNAL_COMPRESS;
200   message->mutable_flags() |= GRPC_WRITE_INTERNAL_TEST_ONLY_WAS_COMPRESSED;
201   if (call_tracer != nullptr) {
202     call_tracer->RecordReceivedDecompressedMessage(*message->payload());
203   }
204   return std::move(message);
205 }
206 
HandleOutgoingMetadata(grpc_metadata_batch & outgoing_metadata)207 grpc_compression_algorithm ChannelCompression::HandleOutgoingMetadata(
208     grpc_metadata_batch& outgoing_metadata) {
209   const auto algorithm = outgoing_metadata.Take(GrpcInternalEncodingRequest())
210                              .value_or(default_compression_algorithm());
211   // Convey supported compression algorithms.
212   outgoing_metadata.Set(GrpcAcceptEncodingMetadata(),
213                         enabled_compression_algorithms());
214   if (algorithm != GRPC_COMPRESS_NONE) {
215     outgoing_metadata.Set(GrpcEncodingMetadata(), algorithm);
216   }
217   return algorithm;
218 }
219 
HandleIncomingMetadata(const grpc_metadata_batch & incoming_metadata)220 ChannelCompression::DecompressArgs ChannelCompression::HandleIncomingMetadata(
221     const grpc_metadata_batch& incoming_metadata) {
222   // Configure max receive size.
223   auto max_recv_message_length = max_recv_size_;
224   const MessageSizeParsedConfig* limits =
225       MessageSizeParsedConfig::GetFromCallContext(
226           GetContext<Arena>(), message_size_service_config_parser_index_);
227   if (limits != nullptr && limits->max_recv_size().has_value() &&
228       (!max_recv_message_length.has_value() ||
229        *limits->max_recv_size() < *max_recv_message_length)) {
230     max_recv_message_length = limits->max_recv_size();
231   }
232   return DecompressArgs{incoming_metadata.get(GrpcEncodingMetadata())
233                             .value_or(GRPC_COMPRESS_NONE),
234                         max_recv_message_length};
235 }
236 
OnClientInitialMetadata(ClientMetadata & md,ClientCompressionFilter * filter)237 void ClientCompressionFilter::Call::OnClientInitialMetadata(
238     ClientMetadata& md, ClientCompressionFilter* filter) {
239   GRPC_LATENT_SEE_INNER_SCOPE(
240       "ClientCompressionFilter::Call::OnClientInitialMetadata");
241   compression_algorithm_ =
242       filter->compression_engine_.HandleOutgoingMetadata(md);
243 }
244 
OnClientToServerMessage(MessageHandle message,ClientCompressionFilter * filter)245 MessageHandle ClientCompressionFilter::Call::OnClientToServerMessage(
246     MessageHandle message, ClientCompressionFilter* filter) {
247   GRPC_LATENT_SEE_INNER_SCOPE(
248       "ClientCompressionFilter::Call::OnClientToServerMessage");
249   return filter->compression_engine_.CompressMessage(std::move(message),
250                                                      compression_algorithm_);
251 }
252 
OnServerInitialMetadata(ServerMetadata & md,ClientCompressionFilter * filter)253 void ClientCompressionFilter::Call::OnServerInitialMetadata(
254     ServerMetadata& md, ClientCompressionFilter* filter) {
255   GRPC_LATENT_SEE_INNER_SCOPE(
256       "ClientCompressionFilter::Call::OnServerInitialMetadata");
257   decompress_args_ = filter->compression_engine_.HandleIncomingMetadata(md);
258 }
259 
260 absl::StatusOr<MessageHandle>
OnServerToClientMessage(MessageHandle message,ClientCompressionFilter * filter)261 ClientCompressionFilter::Call::OnServerToClientMessage(
262     MessageHandle message, ClientCompressionFilter* filter) {
263   GRPC_LATENT_SEE_INNER_SCOPE(
264       "ClientCompressionFilter::Call::OnServerToClientMessage");
265   return filter->compression_engine_.DecompressMessage(
266       /*is_client=*/true, std::move(message), decompress_args_);
267 }
268 
OnClientInitialMetadata(ClientMetadata & md,ServerCompressionFilter * filter)269 void ServerCompressionFilter::Call::OnClientInitialMetadata(
270     ClientMetadata& md, ServerCompressionFilter* filter) {
271   GRPC_LATENT_SEE_INNER_SCOPE(
272       "ServerCompressionFilter::Call::OnClientInitialMetadata");
273   decompress_args_ = filter->compression_engine_.HandleIncomingMetadata(md);
274 }
275 
276 absl::StatusOr<MessageHandle>
OnClientToServerMessage(MessageHandle message,ServerCompressionFilter * filter)277 ServerCompressionFilter::Call::OnClientToServerMessage(
278     MessageHandle message, ServerCompressionFilter* filter) {
279   GRPC_LATENT_SEE_INNER_SCOPE(
280       "ServerCompressionFilter::Call::OnClientToServerMessage");
281   return filter->compression_engine_.DecompressMessage(
282       /*is_client=*/false, std::move(message), decompress_args_);
283 }
284 
OnServerInitialMetadata(ServerMetadata & md,ServerCompressionFilter * filter)285 void ServerCompressionFilter::Call::OnServerInitialMetadata(
286     ServerMetadata& md, ServerCompressionFilter* filter) {
287   GRPC_LATENT_SEE_INNER_SCOPE(
288       "ServerCompressionFilter::Call::OnServerInitialMetadata");
289   compression_algorithm_ =
290       filter->compression_engine_.HandleOutgoingMetadata(md);
291 }
292 
OnServerToClientMessage(MessageHandle message,ServerCompressionFilter * filter)293 MessageHandle ServerCompressionFilter::Call::OnServerToClientMessage(
294     MessageHandle message, ServerCompressionFilter* filter) {
295   GRPC_LATENT_SEE_INNER_SCOPE(
296       "ServerCompressionFilter::Call::OnServerToClientMessage");
297   return filter->compression_engine_.CompressMessage(std::move(message),
298                                                      compression_algorithm_);
299 }
300 
301 }  // namespace grpc_core
302