1 // Copyright 2022 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/core/ext/filters/http/message_compress/compression_filter.h"
16
17 #include <grpc/compression.h>
18 #include <grpc/grpc.h>
19 #include <grpc/impl/channel_arg_names.h>
20 #include <grpc/impl/compression_types.h>
21 #include <grpc/support/port_platform.h>
22 #include <inttypes.h>
23
24 #include <functional>
25 #include <memory>
26 #include <utility>
27
28 #include "absl/log/check.h"
29 #include "absl/status/status.h"
30 #include "absl/strings/str_cat.h"
31 #include "absl/strings/str_format.h"
32 #include "absl/types/optional.h"
33 #include "src/core/ext/filters/message_size/message_size_filter.h"
34 #include "src/core/lib/channel/channel_args.h"
35 #include "src/core/lib/channel/channel_stack.h"
36 #include "src/core/lib/channel/promise_based_filter.h"
37 #include "src/core/lib/compression/compression_internal.h"
38 #include "src/core/lib/compression/message_compress.h"
39 #include "src/core/lib/debug/trace.h"
40 #include "src/core/lib/promise/activity.h"
41 #include "src/core/lib/promise/context.h"
42 #include "src/core/lib/promise/latch.h"
43 #include "src/core/lib/promise/pipe.h"
44 #include "src/core/lib/promise/prioritized_race.h"
45 #include "src/core/lib/resource_quota/arena.h"
46 #include "src/core/lib/slice/slice_buffer.h"
47 #include "src/core/lib/surface/call.h"
48 #include "src/core/lib/transport/metadata_batch.h"
49 #include "src/core/lib/transport/transport.h"
50 #include "src/core/telemetry/call_tracer.h"
51 #include "src/core/util/latent_see.h"
52
53 namespace grpc_core {
54
55 const NoInterceptor ServerCompressionFilter::Call::OnClientToServerHalfClose;
56 const NoInterceptor ServerCompressionFilter::Call::OnServerTrailingMetadata;
57 const NoInterceptor ServerCompressionFilter::Call::OnFinalize;
58 const NoInterceptor ClientCompressionFilter::Call::OnClientToServerHalfClose;
59 const NoInterceptor ClientCompressionFilter::Call::OnServerTrailingMetadata;
60 const NoInterceptor ClientCompressionFilter::Call::OnFinalize;
61
62 const grpc_channel_filter ClientCompressionFilter::kFilter =
63 MakePromiseBasedFilter<ClientCompressionFilter, FilterEndpoint::kClient,
64 kFilterExaminesServerInitialMetadata |
65 kFilterExaminesInboundMessages |
66 kFilterExaminesOutboundMessages>();
67 const grpc_channel_filter ServerCompressionFilter::kFilter =
68 MakePromiseBasedFilter<ServerCompressionFilter, FilterEndpoint::kServer,
69 kFilterExaminesServerInitialMetadata |
70 kFilterExaminesInboundMessages |
71 kFilterExaminesOutboundMessages>();
72
73 absl::StatusOr<std::unique_ptr<ClientCompressionFilter>>
Create(const ChannelArgs & args,ChannelFilter::Args)74 ClientCompressionFilter::Create(const ChannelArgs& args, ChannelFilter::Args) {
75 return std::make_unique<ClientCompressionFilter>(args);
76 }
77
78 absl::StatusOr<std::unique_ptr<ServerCompressionFilter>>
Create(const ChannelArgs & args,ChannelFilter::Args)79 ServerCompressionFilter::Create(const ChannelArgs& args, ChannelFilter::Args) {
80 return std::make_unique<ServerCompressionFilter>(args);
81 }
82
ChannelCompression(const ChannelArgs & args)83 ChannelCompression::ChannelCompression(const ChannelArgs& args)
84 : max_recv_size_(GetMaxRecvSizeFromChannelArgs(args)),
85 message_size_service_config_parser_index_(
86 MessageSizeParser::ParserIndex()),
87 default_compression_algorithm_(
88 DefaultCompressionAlgorithmFromChannelArgs(args).value_or(
89 GRPC_COMPRESS_NONE)),
90 enabled_compression_algorithms_(
91 CompressionAlgorithmSet::FromChannelArgs(args)),
92 enable_compression_(
93 args.GetBool(GRPC_ARG_ENABLE_PER_MESSAGE_COMPRESSION).value_or(true)),
94 enable_decompression_(
95 args.GetBool(GRPC_ARG_ENABLE_PER_MESSAGE_DECOMPRESSION)
96 .value_or(true)) {
97 // Make sure the default is enabled.
98 if (!enabled_compression_algorithms_.IsSet(default_compression_algorithm_)) {
99 const char* name;
100 if (!grpc_compression_algorithm_name(default_compression_algorithm_,
101 &name)) {
102 name = "<unknown>";
103 }
104 LOG(ERROR) << "default compression algorithm " << name
105 << " not enabled: switching to none";
106 default_compression_algorithm_ = GRPC_COMPRESS_NONE;
107 }
108 }
109
CompressMessage(MessageHandle message,grpc_compression_algorithm algorithm) const110 MessageHandle ChannelCompression::CompressMessage(
111 MessageHandle message, grpc_compression_algorithm algorithm) const {
112 GRPC_TRACE_LOG(compression, INFO)
113 << "CompressMessage: len=" << message->payload()->Length()
114 << " alg=" << algorithm << " flags=" << message->flags();
115 auto* call_tracer = MaybeGetContext<CallTracerInterface>();
116 if (call_tracer != nullptr) {
117 call_tracer->RecordSendMessage(*message->payload());
118 }
119 // Check if we're allowed to compress this message
120 // (apps might want to disable compression for certain messages to avoid
121 // crime/beast like vulns).
122 uint32_t& flags = message->mutable_flags();
123 if (algorithm == GRPC_COMPRESS_NONE || !enable_compression_ ||
124 (flags & (GRPC_WRITE_NO_COMPRESS | GRPC_WRITE_INTERNAL_COMPRESS))) {
125 return message;
126 }
127 // Try to compress the payload.
128 SliceBuffer tmp;
129 SliceBuffer* payload = message->payload();
130 bool did_compress = grpc_msg_compress(algorithm, payload->c_slice_buffer(),
131 tmp.c_slice_buffer());
132 // If we achieved compression send it as compressed, otherwise send it as (to
133 // avoid spending cycles on the receiver decompressing).
134 if (did_compress) {
135 if (GRPC_TRACE_FLAG_ENABLED(compression)) {
136 const char* algo_name;
137 const size_t before_size = payload->Length();
138 const size_t after_size = tmp.Length();
139 const float savings_ratio = 1.0f - (static_cast<float>(after_size) /
140 static_cast<float>(before_size));
141 CHECK(grpc_compression_algorithm_name(algorithm, &algo_name));
142 LOG(INFO) << absl::StrFormat(
143 "Compressed[%s] %" PRIuPTR " bytes vs. %" PRIuPTR
144 " bytes (%.2f%% savings)",
145 algo_name, before_size, after_size, 100 * savings_ratio);
146 }
147 tmp.Swap(payload);
148 flags |= GRPC_WRITE_INTERNAL_COMPRESS;
149 if (call_tracer != nullptr) {
150 call_tracer->RecordSendCompressedMessage(*message->payload());
151 }
152 } else {
153 if (GRPC_TRACE_FLAG_ENABLED(compression)) {
154 const char* algo_name;
155 CHECK(grpc_compression_algorithm_name(algorithm, &algo_name));
156 LOG(INFO) << "Algorithm '" << algo_name
157 << "' enabled but decided not to compress. Input size: "
158 << payload->Length();
159 }
160 }
161 return message;
162 }
163
DecompressMessage(bool is_client,MessageHandle message,DecompressArgs args) const164 absl::StatusOr<MessageHandle> ChannelCompression::DecompressMessage(
165 bool is_client, MessageHandle message, DecompressArgs args) const {
166 GRPC_TRACE_LOG(compression, INFO)
167 << "DecompressMessage: len=" << message->payload()->Length()
168 << " max=" << args.max_recv_message_length.value_or(-1)
169 << " alg=" << args.algorithm;
170 auto* call_tracer = MaybeGetContext<CallTracerInterface>();
171 if (call_tracer != nullptr) {
172 call_tracer->RecordReceivedMessage(*message->payload());
173 }
174 // Check max message length.
175 if (args.max_recv_message_length.has_value() &&
176 message->payload()->Length() >
177 static_cast<size_t>(*args.max_recv_message_length)) {
178 return absl::ResourceExhaustedError(absl::StrFormat(
179 "%s: Received message larger than max (%u vs. %d)",
180 is_client ? "CLIENT" : "SERVER", message->payload()->Length(),
181 *args.max_recv_message_length));
182 }
183 // Check if decompression is enabled (if not, we can just pass the message
184 // up).
185 if (!enable_decompression_ ||
186 (message->flags() & GRPC_WRITE_INTERNAL_COMPRESS) == 0) {
187 return std::move(message);
188 }
189 // Try to decompress the payload.
190 SliceBuffer decompressed_slices;
191 if (grpc_msg_decompress(args.algorithm, message->payload()->c_slice_buffer(),
192 decompressed_slices.c_slice_buffer()) == 0) {
193 return absl::InternalError(
194 absl::StrCat("Unexpected error decompressing data for algorithm ",
195 CompressionAlgorithmAsString(args.algorithm)));
196 }
197 // Swap the decompressed slices into the message.
198 message->payload()->Swap(&decompressed_slices);
199 message->mutable_flags() &= ~GRPC_WRITE_INTERNAL_COMPRESS;
200 message->mutable_flags() |= GRPC_WRITE_INTERNAL_TEST_ONLY_WAS_COMPRESSED;
201 if (call_tracer != nullptr) {
202 call_tracer->RecordReceivedDecompressedMessage(*message->payload());
203 }
204 return std::move(message);
205 }
206
HandleOutgoingMetadata(grpc_metadata_batch & outgoing_metadata)207 grpc_compression_algorithm ChannelCompression::HandleOutgoingMetadata(
208 grpc_metadata_batch& outgoing_metadata) {
209 const auto algorithm = outgoing_metadata.Take(GrpcInternalEncodingRequest())
210 .value_or(default_compression_algorithm());
211 // Convey supported compression algorithms.
212 outgoing_metadata.Set(GrpcAcceptEncodingMetadata(),
213 enabled_compression_algorithms());
214 if (algorithm != GRPC_COMPRESS_NONE) {
215 outgoing_metadata.Set(GrpcEncodingMetadata(), algorithm);
216 }
217 return algorithm;
218 }
219
HandleIncomingMetadata(const grpc_metadata_batch & incoming_metadata)220 ChannelCompression::DecompressArgs ChannelCompression::HandleIncomingMetadata(
221 const grpc_metadata_batch& incoming_metadata) {
222 // Configure max receive size.
223 auto max_recv_message_length = max_recv_size_;
224 const MessageSizeParsedConfig* limits =
225 MessageSizeParsedConfig::GetFromCallContext(
226 GetContext<Arena>(), message_size_service_config_parser_index_);
227 if (limits != nullptr && limits->max_recv_size().has_value() &&
228 (!max_recv_message_length.has_value() ||
229 *limits->max_recv_size() < *max_recv_message_length)) {
230 max_recv_message_length = limits->max_recv_size();
231 }
232 return DecompressArgs{incoming_metadata.get(GrpcEncodingMetadata())
233 .value_or(GRPC_COMPRESS_NONE),
234 max_recv_message_length};
235 }
236
OnClientInitialMetadata(ClientMetadata & md,ClientCompressionFilter * filter)237 void ClientCompressionFilter::Call::OnClientInitialMetadata(
238 ClientMetadata& md, ClientCompressionFilter* filter) {
239 GRPC_LATENT_SEE_INNER_SCOPE(
240 "ClientCompressionFilter::Call::OnClientInitialMetadata");
241 compression_algorithm_ =
242 filter->compression_engine_.HandleOutgoingMetadata(md);
243 }
244
OnClientToServerMessage(MessageHandle message,ClientCompressionFilter * filter)245 MessageHandle ClientCompressionFilter::Call::OnClientToServerMessage(
246 MessageHandle message, ClientCompressionFilter* filter) {
247 GRPC_LATENT_SEE_INNER_SCOPE(
248 "ClientCompressionFilter::Call::OnClientToServerMessage");
249 return filter->compression_engine_.CompressMessage(std::move(message),
250 compression_algorithm_);
251 }
252
OnServerInitialMetadata(ServerMetadata & md,ClientCompressionFilter * filter)253 void ClientCompressionFilter::Call::OnServerInitialMetadata(
254 ServerMetadata& md, ClientCompressionFilter* filter) {
255 GRPC_LATENT_SEE_INNER_SCOPE(
256 "ClientCompressionFilter::Call::OnServerInitialMetadata");
257 decompress_args_ = filter->compression_engine_.HandleIncomingMetadata(md);
258 }
259
260 absl::StatusOr<MessageHandle>
OnServerToClientMessage(MessageHandle message,ClientCompressionFilter * filter)261 ClientCompressionFilter::Call::OnServerToClientMessage(
262 MessageHandle message, ClientCompressionFilter* filter) {
263 GRPC_LATENT_SEE_INNER_SCOPE(
264 "ClientCompressionFilter::Call::OnServerToClientMessage");
265 return filter->compression_engine_.DecompressMessage(
266 /*is_client=*/true, std::move(message), decompress_args_);
267 }
268
OnClientInitialMetadata(ClientMetadata & md,ServerCompressionFilter * filter)269 void ServerCompressionFilter::Call::OnClientInitialMetadata(
270 ClientMetadata& md, ServerCompressionFilter* filter) {
271 GRPC_LATENT_SEE_INNER_SCOPE(
272 "ServerCompressionFilter::Call::OnClientInitialMetadata");
273 decompress_args_ = filter->compression_engine_.HandleIncomingMetadata(md);
274 }
275
276 absl::StatusOr<MessageHandle>
OnClientToServerMessage(MessageHandle message,ServerCompressionFilter * filter)277 ServerCompressionFilter::Call::OnClientToServerMessage(
278 MessageHandle message, ServerCompressionFilter* filter) {
279 GRPC_LATENT_SEE_INNER_SCOPE(
280 "ServerCompressionFilter::Call::OnClientToServerMessage");
281 return filter->compression_engine_.DecompressMessage(
282 /*is_client=*/false, std::move(message), decompress_args_);
283 }
284
OnServerInitialMetadata(ServerMetadata & md,ServerCompressionFilter * filter)285 void ServerCompressionFilter::Call::OnServerInitialMetadata(
286 ServerMetadata& md, ServerCompressionFilter* filter) {
287 GRPC_LATENT_SEE_INNER_SCOPE(
288 "ServerCompressionFilter::Call::OnServerInitialMetadata");
289 compression_algorithm_ =
290 filter->compression_engine_.HandleOutgoingMetadata(md);
291 }
292
OnServerToClientMessage(MessageHandle message,ServerCompressionFilter * filter)293 MessageHandle ServerCompressionFilter::Call::OnServerToClientMessage(
294 MessageHandle message, ServerCompressionFilter* filter) {
295 GRPC_LATENT_SEE_INNER_SCOPE(
296 "ServerCompressionFilter::Call::OnServerToClientMessage");
297 return filter->compression_engine_.CompressMessage(std::move(message),
298 compression_algorithm_);
299 }
300
301 } // namespace grpc_core
302