1 /* 2 * Copyright 2021 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.xds; 18 19 import static com.google.common.base.Preconditions.checkNotNull; 20 import static java.util.concurrent.TimeUnit.NANOSECONDS; 21 22 import com.google.common.annotations.VisibleForTesting; 23 import com.google.common.base.Supplier; 24 import com.google.common.base.Suppliers; 25 import com.google.common.util.concurrent.MoreExecutors; 26 import com.google.protobuf.Any; 27 import com.google.protobuf.InvalidProtocolBufferException; 28 import com.google.protobuf.Message; 29 import com.google.protobuf.util.Durations; 30 import io.envoyproxy.envoy.extensions.filters.http.fault.v3.HTTPFault; 31 import io.envoyproxy.envoy.type.v3.FractionalPercent; 32 import io.grpc.CallOptions; 33 import io.grpc.Channel; 34 import io.grpc.ClientCall; 35 import io.grpc.ClientInterceptor; 36 import io.grpc.Context; 37 import io.grpc.Deadline; 38 import io.grpc.ForwardingClientCall; 39 import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; 40 import io.grpc.LoadBalancer.PickSubchannelArgs; 41 import io.grpc.Metadata; 42 import io.grpc.MethodDescriptor; 43 import io.grpc.Status; 44 import io.grpc.Status.Code; 45 import io.grpc.internal.DelayedClientCall; 46 import io.grpc.internal.GrpcUtil; 47 import io.grpc.xds.FaultConfig.FaultAbort; 48 import io.grpc.xds.FaultConfig.FaultDelay; 49 import io.grpc.xds.Filter.ClientInterceptorBuilder; 50 import io.grpc.xds.ThreadSafeRandom.ThreadSafeRandomImpl; 51 import java.util.Locale; 52 import java.util.concurrent.Executor; 53 import java.util.concurrent.ScheduledExecutorService; 54 import java.util.concurrent.ScheduledFuture; 55 import java.util.concurrent.TimeUnit; 56 import java.util.concurrent.atomic.AtomicLong; 57 import javax.annotation.Nullable; 58 59 /** HttpFault filter implementation. */ 60 final class FaultFilter implements Filter, ClientInterceptorBuilder { 61 62 static final FaultFilter INSTANCE = 63 new FaultFilter(ThreadSafeRandomImpl.instance, new AtomicLong()); 64 @VisibleForTesting 65 static final Metadata.Key<String> HEADER_DELAY_KEY = 66 Metadata.Key.of("x-envoy-fault-delay-request", Metadata.ASCII_STRING_MARSHALLER); 67 @VisibleForTesting 68 static final Metadata.Key<String> HEADER_DELAY_PERCENTAGE_KEY = 69 Metadata.Key.of("x-envoy-fault-delay-request-percentage", Metadata.ASCII_STRING_MARSHALLER); 70 @VisibleForTesting 71 static final Metadata.Key<String> HEADER_ABORT_HTTP_STATUS_KEY = 72 Metadata.Key.of("x-envoy-fault-abort-request", Metadata.ASCII_STRING_MARSHALLER); 73 @VisibleForTesting 74 static final Metadata.Key<String> HEADER_ABORT_GRPC_STATUS_KEY = 75 Metadata.Key.of("x-envoy-fault-abort-grpc-request", Metadata.ASCII_STRING_MARSHALLER); 76 @VisibleForTesting 77 static final Metadata.Key<String> HEADER_ABORT_PERCENTAGE_KEY = 78 Metadata.Key.of("x-envoy-fault-abort-request-percentage", Metadata.ASCII_STRING_MARSHALLER); 79 static final String TYPE_URL = 80 "type.googleapis.com/envoy.extensions.filters.http.fault.v3.HTTPFault"; 81 82 private final ThreadSafeRandom random; 83 private final AtomicLong activeFaultCounter; 84 85 @VisibleForTesting FaultFilter(ThreadSafeRandom random, AtomicLong activeFaultCounter)86 FaultFilter(ThreadSafeRandom random, AtomicLong activeFaultCounter) { 87 this.random = random; 88 this.activeFaultCounter = activeFaultCounter; 89 } 90 91 @Override typeUrls()92 public String[] typeUrls() { 93 return new String[] { TYPE_URL }; 94 } 95 96 @Override parseFilterConfig(Message rawProtoMessage)97 public ConfigOrError<FaultConfig> parseFilterConfig(Message rawProtoMessage) { 98 HTTPFault httpFaultProto; 99 if (!(rawProtoMessage instanceof Any)) { 100 return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); 101 } 102 Any anyMessage = (Any) rawProtoMessage; 103 try { 104 httpFaultProto = anyMessage.unpack(HTTPFault.class); 105 } catch (InvalidProtocolBufferException e) { 106 return ConfigOrError.fromError("Invalid proto: " + e); 107 } 108 return parseHttpFault(httpFaultProto); 109 } 110 parseHttpFault(HTTPFault httpFault)111 private static ConfigOrError<FaultConfig> parseHttpFault(HTTPFault httpFault) { 112 FaultDelay faultDelay = null; 113 FaultAbort faultAbort = null; 114 if (httpFault.hasDelay()) { 115 faultDelay = parseFaultDelay(httpFault.getDelay()); 116 } 117 if (httpFault.hasAbort()) { 118 ConfigOrError<FaultAbort> faultAbortOrError = parseFaultAbort(httpFault.getAbort()); 119 if (faultAbortOrError.errorDetail != null) { 120 return ConfigOrError.fromError( 121 "HttpFault contains invalid FaultAbort: " + faultAbortOrError.errorDetail); 122 } 123 faultAbort = faultAbortOrError.config; 124 } 125 Integer maxActiveFaults = null; 126 if (httpFault.hasMaxActiveFaults()) { 127 maxActiveFaults = httpFault.getMaxActiveFaults().getValue(); 128 if (maxActiveFaults < 0) { 129 maxActiveFaults = Integer.MAX_VALUE; 130 } 131 } 132 return ConfigOrError.fromConfig(FaultConfig.create(faultDelay, faultAbort, maxActiveFaults)); 133 } 134 parseFaultDelay( io.envoyproxy.envoy.extensions.filters.common.fault.v3.FaultDelay faultDelay)135 private static FaultDelay parseFaultDelay( 136 io.envoyproxy.envoy.extensions.filters.common.fault.v3.FaultDelay faultDelay) { 137 FaultConfig.FractionalPercent percent = parsePercent(faultDelay.getPercentage()); 138 if (faultDelay.hasHeaderDelay()) { 139 return FaultDelay.forHeader(percent); 140 } 141 return FaultDelay.forFixedDelay(Durations.toNanos(faultDelay.getFixedDelay()), percent); 142 } 143 144 @VisibleForTesting parseFaultAbort( io.envoyproxy.envoy.extensions.filters.http.fault.v3.FaultAbort faultAbort)145 static ConfigOrError<FaultAbort> parseFaultAbort( 146 io.envoyproxy.envoy.extensions.filters.http.fault.v3.FaultAbort faultAbort) { 147 FaultConfig.FractionalPercent percent = parsePercent(faultAbort.getPercentage()); 148 switch (faultAbort.getErrorTypeCase()) { 149 case HEADER_ABORT: 150 return ConfigOrError.fromConfig(FaultAbort.forHeader(percent)); 151 case HTTP_STATUS: 152 return ConfigOrError.fromConfig(FaultAbort.forStatus( 153 GrpcUtil.httpStatusToGrpcStatus(faultAbort.getHttpStatus()), percent)); 154 case GRPC_STATUS: 155 return ConfigOrError.fromConfig(FaultAbort.forStatus( 156 Status.fromCodeValue(faultAbort.getGrpcStatus()), percent)); 157 case ERRORTYPE_NOT_SET: 158 default: 159 return ConfigOrError.fromError( 160 "Unknown error type case: " + faultAbort.getErrorTypeCase()); 161 } 162 } 163 parsePercent(FractionalPercent proto)164 private static FaultConfig.FractionalPercent parsePercent(FractionalPercent proto) { 165 switch (proto.getDenominator()) { 166 case HUNDRED: 167 return FaultConfig.FractionalPercent.perHundred(proto.getNumerator()); 168 case TEN_THOUSAND: 169 return FaultConfig.FractionalPercent.perTenThousand(proto.getNumerator()); 170 case MILLION: 171 return FaultConfig.FractionalPercent.perMillion(proto.getNumerator()); 172 case UNRECOGNIZED: 173 default: 174 throw new IllegalArgumentException("Unknown denominator type: " + proto.getDenominator()); 175 } 176 } 177 178 @Override parseFilterConfigOverride(Message rawProtoMessage)179 public ConfigOrError<FaultConfig> parseFilterConfigOverride(Message rawProtoMessage) { 180 return parseFilterConfig(rawProtoMessage); 181 } 182 183 @Nullable 184 @Override buildClientInterceptor( FilterConfig config, @Nullable FilterConfig overrideConfig, PickSubchannelArgs args, final ScheduledExecutorService scheduler)185 public ClientInterceptor buildClientInterceptor( 186 FilterConfig config, @Nullable FilterConfig overrideConfig, PickSubchannelArgs args, 187 final ScheduledExecutorService scheduler) { 188 checkNotNull(config, "config"); 189 if (overrideConfig != null) { 190 config = overrideConfig; 191 } 192 FaultConfig faultConfig = (FaultConfig) config; 193 Long delayNanos = null; 194 Status abortStatus = null; 195 if (faultConfig.maxActiveFaults() == null 196 || activeFaultCounter.get() < faultConfig.maxActiveFaults()) { 197 Metadata headers = args.getHeaders(); 198 if (faultConfig.faultDelay() != null) { 199 delayNanos = determineFaultDelayNanos(faultConfig.faultDelay(), headers); 200 } 201 if (faultConfig.faultAbort() != null) { 202 abortStatus = determineFaultAbortStatus(faultConfig.faultAbort(), headers); 203 } 204 } 205 if (delayNanos == null && abortStatus == null) { 206 return null; 207 } 208 final Long finalDelayNanos = delayNanos; 209 final Status finalAbortStatus = getAbortStatusWithDescription(abortStatus); 210 211 final class FaultInjectionInterceptor implements ClientInterceptor { 212 @Override 213 public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( 214 final MethodDescriptor<ReqT, RespT> method, final CallOptions callOptions, 215 final Channel next) { 216 Executor callExecutor = callOptions.getExecutor(); 217 if (callExecutor == null) { // This should never happen in practice because 218 // ManagedChannelImpl.ConfigSelectingClientCall always provides CallOptions with 219 // a callExecutor. 220 // TODO(https://github.com/grpc/grpc-java/issues/7868) 221 callExecutor = MoreExecutors.directExecutor(); 222 } 223 if (finalDelayNanos != null) { 224 Supplier<? extends ClientCall<ReqT, RespT>> callSupplier; 225 if (finalAbortStatus != null) { 226 callSupplier = Suppliers.ofInstance( 227 new FailingClientCall<ReqT, RespT>(finalAbortStatus, callExecutor)); 228 } else { 229 callSupplier = new Supplier<ClientCall<ReqT, RespT>>() { 230 @Override 231 public ClientCall<ReqT, RespT> get() { 232 return next.newCall(method, callOptions); 233 } 234 }; 235 } 236 final DelayInjectedCall<ReqT, RespT> delayInjectedCall = new DelayInjectedCall<>( 237 finalDelayNanos, callExecutor, scheduler, callOptions.getDeadline(), callSupplier); 238 239 final class DeadlineInsightForwardingCall extends ForwardingClientCall<ReqT, RespT> { 240 @Override 241 protected ClientCall<ReqT, RespT> delegate() { 242 return delayInjectedCall; 243 } 244 245 @Override 246 public void start(Listener<RespT> listener, Metadata headers) { 247 Listener<RespT> finalListener = 248 new SimpleForwardingClientCallListener<RespT>(listener) { 249 @Override 250 public void onClose(Status status, Metadata trailers) { 251 if (status.getCode().equals(Code.DEADLINE_EXCEEDED)) { 252 // TODO(zdapeng:) check effective deadline locally, and 253 // do the following only if the local deadline is exceeded. 254 // (If the server sends DEADLINE_EXCEEDED for its own deadline, then the 255 // injected delay does not contribute to the error, because the request is 256 // only sent out after the delay. There could be a race between local and 257 // remote, but it is rather rare.) 258 String description = String.format( 259 Locale.US, 260 "Deadline exceeded after up to %d ns of fault-injected delay", 261 finalDelayNanos); 262 if (status.getDescription() != null) { 263 description = description + ": " + status.getDescription(); 264 } 265 status = Status.DEADLINE_EXCEEDED 266 .withDescription(description).withCause(status.getCause()); 267 // Replace trailers to prevent mixing sources of status and trailers. 268 trailers = new Metadata(); 269 } 270 delegate().onClose(status, trailers); 271 } 272 }; 273 delegate().start(finalListener, headers); 274 } 275 } 276 277 return new DeadlineInsightForwardingCall(); 278 } else { 279 return new FailingClientCall<>(finalAbortStatus, callExecutor); 280 } 281 } 282 } 283 284 return new FaultInjectionInterceptor(); 285 } 286 getAbortStatusWithDescription(Status abortStatus)287 private static Status getAbortStatusWithDescription(Status abortStatus) { 288 Status finalAbortStatus = null; 289 if (abortStatus != null) { 290 String abortDesc = "RPC terminated due to fault injection"; 291 if (abortStatus.getDescription() != null) { 292 abortDesc = abortDesc + ": " + abortStatus.getDescription(); 293 } 294 finalAbortStatus = abortStatus.withDescription(abortDesc); 295 } 296 return finalAbortStatus; 297 } 298 299 @Nullable determineFaultDelayNanos(FaultDelay faultDelay, Metadata headers)300 private Long determineFaultDelayNanos(FaultDelay faultDelay, Metadata headers) { 301 Long delayNanos; 302 FaultConfig.FractionalPercent fractionalPercent = faultDelay.percent(); 303 if (faultDelay.headerDelay()) { 304 try { 305 int delayMillis = Integer.parseInt(headers.get(HEADER_DELAY_KEY)); 306 delayNanos = TimeUnit.MILLISECONDS.toNanos(delayMillis); 307 String delayPercentageStr = headers.get(HEADER_DELAY_PERCENTAGE_KEY); 308 if (delayPercentageStr != null) { 309 int delayPercentage = Integer.parseInt(delayPercentageStr); 310 if (delayPercentage >= 0 && delayPercentage < fractionalPercent.numerator()) { 311 fractionalPercent = FaultConfig.FractionalPercent.create( 312 delayPercentage, fractionalPercent.denominatorType()); 313 } 314 } 315 } catch (NumberFormatException e) { 316 return null; // treated as header_delay not applicable 317 } 318 } else { 319 delayNanos = faultDelay.delayNanos(); 320 } 321 if (random.nextInt(1_000_000) >= getRatePerMillion(fractionalPercent)) { 322 return null; 323 } 324 return delayNanos; 325 } 326 327 @Nullable determineFaultAbortStatus(FaultAbort faultAbort, Metadata headers)328 private Status determineFaultAbortStatus(FaultAbort faultAbort, Metadata headers) { 329 Status abortStatus = null; 330 FaultConfig.FractionalPercent fractionalPercent = faultAbort.percent(); 331 if (faultAbort.headerAbort()) { 332 try { 333 String grpcCodeStr = headers.get(HEADER_ABORT_GRPC_STATUS_KEY); 334 if (grpcCodeStr != null) { 335 int grpcCode = Integer.parseInt(grpcCodeStr); 336 abortStatus = Status.fromCodeValue(grpcCode); 337 } 338 String httpCodeStr = headers.get(HEADER_ABORT_HTTP_STATUS_KEY); 339 if (httpCodeStr != null) { 340 int httpCode = Integer.parseInt(httpCodeStr); 341 abortStatus = GrpcUtil.httpStatusToGrpcStatus(httpCode); 342 } 343 String abortPercentageStr = headers.get(HEADER_ABORT_PERCENTAGE_KEY); 344 if (abortPercentageStr != null) { 345 int abortPercentage = 346 Integer.parseInt(headers.get(HEADER_ABORT_PERCENTAGE_KEY)); 347 if (abortPercentage >= 0 && abortPercentage < fractionalPercent.numerator()) { 348 fractionalPercent = FaultConfig.FractionalPercent.create( 349 abortPercentage, fractionalPercent.denominatorType()); 350 } 351 } 352 } catch (NumberFormatException e) { 353 return null; // treated as header_abort not applicable 354 } 355 } else { 356 abortStatus = faultAbort.status(); 357 } 358 if (random.nextInt(1_000_000) >= getRatePerMillion(fractionalPercent)) { 359 return null; 360 } 361 return abortStatus; 362 } 363 getRatePerMillion(FaultConfig.FractionalPercent percent)364 private static int getRatePerMillion(FaultConfig.FractionalPercent percent) { 365 int numerator = percent.numerator(); 366 FaultConfig.FractionalPercent.DenominatorType type = percent.denominatorType(); 367 switch (type) { 368 case TEN_THOUSAND: 369 numerator *= 100; 370 break; 371 case HUNDRED: 372 numerator *= 10_000; 373 break; 374 case MILLION: 375 default: 376 break; 377 } 378 if (numerator > 1_000_000 || numerator < 0) { 379 numerator = 1_000_000; 380 } 381 return numerator; 382 } 383 384 /** A {@link DelayedClientCall} with a fixed delay. */ 385 private final class DelayInjectedCall<ReqT, RespT> extends DelayedClientCall<ReqT, RespT> { 386 final Object lock = new Object(); 387 ScheduledFuture<?> delayTask; 388 boolean cancelled; 389 DelayInjectedCall( long delayNanos, Executor callExecutor, ScheduledExecutorService scheduler, @Nullable Deadline deadline, final Supplier<? extends ClientCall<ReqT, RespT>> callSupplier)390 DelayInjectedCall( 391 long delayNanos, Executor callExecutor, ScheduledExecutorService scheduler, 392 @Nullable Deadline deadline, 393 final Supplier<? extends ClientCall<ReqT, RespT>> callSupplier) { 394 super(callExecutor, scheduler, deadline); 395 activeFaultCounter.incrementAndGet(); 396 ScheduledFuture<?> task = scheduler.schedule( 397 new Runnable() { 398 @Override 399 public void run() { 400 synchronized (lock) { 401 if (!cancelled) { 402 activeFaultCounter.decrementAndGet(); 403 } 404 } 405 Runnable toRun = setCall(callSupplier.get()); 406 if (toRun != null) { 407 toRun.run(); 408 } 409 } 410 }, 411 delayNanos, 412 NANOSECONDS); 413 synchronized (lock) { 414 if (!cancelled) { 415 delayTask = task; 416 return; 417 } 418 } 419 task.cancel(false); 420 } 421 422 @Override callCancelled()423 protected void callCancelled() { 424 ScheduledFuture<?> savedDelayTask; 425 synchronized (lock) { 426 cancelled = true; 427 activeFaultCounter.decrementAndGet(); 428 savedDelayTask = delayTask; 429 } 430 if (savedDelayTask != null) { 431 savedDelayTask.cancel(false); 432 } 433 } 434 } 435 436 /** An implementation of {@link ClientCall} that fails when started. */ 437 private final class FailingClientCall<ReqT, RespT> extends ClientCall<ReqT, RespT> { 438 final Status error; 439 final Executor callExecutor; 440 final Context context; 441 FailingClientCall(Status error, Executor callExecutor)442 FailingClientCall(Status error, Executor callExecutor) { 443 this.error = error; 444 this.callExecutor = callExecutor; 445 this.context = Context.current(); 446 } 447 448 @Override start(final ClientCall.Listener<RespT> listener, Metadata headers)449 public void start(final ClientCall.Listener<RespT> listener, Metadata headers) { 450 activeFaultCounter.incrementAndGet(); 451 callExecutor.execute( 452 new Runnable() { 453 @Override 454 public void run() { 455 Context previous = context.attach(); 456 try { 457 listener.onClose(error, new Metadata()); 458 activeFaultCounter.decrementAndGet(); 459 } finally { 460 context.detach(previous); 461 } 462 } 463 }); 464 } 465 466 @Override request(int numMessages)467 public void request(int numMessages) {} 468 469 @Override cancel(String message, Throwable cause)470 public void cancel(String message, Throwable cause) {} 471 472 @Override halfClose()473 public void halfClose() {} 474 475 @Override sendMessage(ReqT message)476 public void sendMessage(ReqT message) {} 477 } 478 } 479