• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2021 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.xds;
18 
19 import static com.google.common.base.Preconditions.checkNotNull;
20 import static java.util.concurrent.TimeUnit.NANOSECONDS;
21 
22 import com.google.common.annotations.VisibleForTesting;
23 import com.google.common.base.Supplier;
24 import com.google.common.base.Suppliers;
25 import com.google.common.util.concurrent.MoreExecutors;
26 import com.google.protobuf.Any;
27 import com.google.protobuf.InvalidProtocolBufferException;
28 import com.google.protobuf.Message;
29 import com.google.protobuf.util.Durations;
30 import io.envoyproxy.envoy.extensions.filters.http.fault.v3.HTTPFault;
31 import io.envoyproxy.envoy.type.v3.FractionalPercent;
32 import io.grpc.CallOptions;
33 import io.grpc.Channel;
34 import io.grpc.ClientCall;
35 import io.grpc.ClientInterceptor;
36 import io.grpc.Context;
37 import io.grpc.Deadline;
38 import io.grpc.ForwardingClientCall;
39 import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
40 import io.grpc.LoadBalancer.PickSubchannelArgs;
41 import io.grpc.Metadata;
42 import io.grpc.MethodDescriptor;
43 import io.grpc.Status;
44 import io.grpc.Status.Code;
45 import io.grpc.internal.DelayedClientCall;
46 import io.grpc.internal.GrpcUtil;
47 import io.grpc.xds.FaultConfig.FaultAbort;
48 import io.grpc.xds.FaultConfig.FaultDelay;
49 import io.grpc.xds.Filter.ClientInterceptorBuilder;
50 import io.grpc.xds.ThreadSafeRandom.ThreadSafeRandomImpl;
51 import java.util.Locale;
52 import java.util.concurrent.Executor;
53 import java.util.concurrent.ScheduledExecutorService;
54 import java.util.concurrent.ScheduledFuture;
55 import java.util.concurrent.TimeUnit;
56 import java.util.concurrent.atomic.AtomicLong;
57 import javax.annotation.Nullable;
58 
59 /** HttpFault filter implementation. */
60 final class FaultFilter implements Filter, ClientInterceptorBuilder {
61 
62   static final FaultFilter INSTANCE =
63       new FaultFilter(ThreadSafeRandomImpl.instance, new AtomicLong());
64   @VisibleForTesting
65   static final Metadata.Key<String> HEADER_DELAY_KEY =
66       Metadata.Key.of("x-envoy-fault-delay-request", Metadata.ASCII_STRING_MARSHALLER);
67   @VisibleForTesting
68   static final Metadata.Key<String> HEADER_DELAY_PERCENTAGE_KEY =
69       Metadata.Key.of("x-envoy-fault-delay-request-percentage", Metadata.ASCII_STRING_MARSHALLER);
70   @VisibleForTesting
71   static final Metadata.Key<String> HEADER_ABORT_HTTP_STATUS_KEY =
72       Metadata.Key.of("x-envoy-fault-abort-request", Metadata.ASCII_STRING_MARSHALLER);
73   @VisibleForTesting
74   static final Metadata.Key<String> HEADER_ABORT_GRPC_STATUS_KEY =
75       Metadata.Key.of("x-envoy-fault-abort-grpc-request", Metadata.ASCII_STRING_MARSHALLER);
76   @VisibleForTesting
77   static final Metadata.Key<String> HEADER_ABORT_PERCENTAGE_KEY =
78       Metadata.Key.of("x-envoy-fault-abort-request-percentage", Metadata.ASCII_STRING_MARSHALLER);
79   static final String TYPE_URL =
80       "type.googleapis.com/envoy.extensions.filters.http.fault.v3.HTTPFault";
81 
/** Random source consulted to decide whether a configured fault applies to a given RPC. */
private final ThreadSafeRandom random;
/** Count of faults currently in flight; checked against the configured active-fault limit. */
private final AtomicLong activeFaultCounter;

/**
 * Creates a filter that draws its sampling decisions from {@code random} and tracks in-flight
 * faults in {@code activeFaultCounter}.
 */
@VisibleForTesting
FaultFilter(ThreadSafeRandom random, AtomicLong activeFaultCounter) {
  this.activeFaultCounter = activeFaultCounter;
  this.random = random;
}
90 
91   @Override
typeUrls()92   public String[] typeUrls() {
93     return new String[] { TYPE_URL };
94   }
95 
96   @Override
parseFilterConfig(Message rawProtoMessage)97   public ConfigOrError<FaultConfig> parseFilterConfig(Message rawProtoMessage) {
98     HTTPFault httpFaultProto;
99     if (!(rawProtoMessage instanceof Any)) {
100       return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass());
101     }
102     Any anyMessage = (Any) rawProtoMessage;
103     try {
104       httpFaultProto = anyMessage.unpack(HTTPFault.class);
105     } catch (InvalidProtocolBufferException e) {
106       return ConfigOrError.fromError("Invalid proto: " + e);
107     }
108     return parseHttpFault(httpFaultProto);
109   }
110 
parseHttpFault(HTTPFault httpFault)111   private static ConfigOrError<FaultConfig> parseHttpFault(HTTPFault httpFault) {
112     FaultDelay faultDelay = null;
113     FaultAbort faultAbort = null;
114     if (httpFault.hasDelay()) {
115       faultDelay = parseFaultDelay(httpFault.getDelay());
116     }
117     if (httpFault.hasAbort()) {
118       ConfigOrError<FaultAbort> faultAbortOrError = parseFaultAbort(httpFault.getAbort());
119       if (faultAbortOrError.errorDetail != null) {
120         return ConfigOrError.fromError(
121             "HttpFault contains invalid FaultAbort: " + faultAbortOrError.errorDetail);
122       }
123       faultAbort = faultAbortOrError.config;
124     }
125     Integer maxActiveFaults = null;
126     if (httpFault.hasMaxActiveFaults()) {
127       maxActiveFaults = httpFault.getMaxActiveFaults().getValue();
128       if (maxActiveFaults < 0) {
129         maxActiveFaults = Integer.MAX_VALUE;
130       }
131     }
132     return ConfigOrError.fromConfig(FaultConfig.create(faultDelay, faultAbort, maxActiveFaults));
133   }
134 
parseFaultDelay( io.envoyproxy.envoy.extensions.filters.common.fault.v3.FaultDelay faultDelay)135   private static FaultDelay parseFaultDelay(
136       io.envoyproxy.envoy.extensions.filters.common.fault.v3.FaultDelay faultDelay) {
137     FaultConfig.FractionalPercent percent = parsePercent(faultDelay.getPercentage());
138     if (faultDelay.hasHeaderDelay()) {
139       return FaultDelay.forHeader(percent);
140     }
141     return FaultDelay.forFixedDelay(Durations.toNanos(faultDelay.getFixedDelay()), percent);
142   }
143 
144   @VisibleForTesting
parseFaultAbort( io.envoyproxy.envoy.extensions.filters.http.fault.v3.FaultAbort faultAbort)145   static ConfigOrError<FaultAbort> parseFaultAbort(
146       io.envoyproxy.envoy.extensions.filters.http.fault.v3.FaultAbort faultAbort) {
147     FaultConfig.FractionalPercent percent = parsePercent(faultAbort.getPercentage());
148     switch (faultAbort.getErrorTypeCase()) {
149       case HEADER_ABORT:
150         return ConfigOrError.fromConfig(FaultAbort.forHeader(percent));
151       case HTTP_STATUS:
152         return ConfigOrError.fromConfig(FaultAbort.forStatus(
153             GrpcUtil.httpStatusToGrpcStatus(faultAbort.getHttpStatus()), percent));
154       case GRPC_STATUS:
155         return ConfigOrError.fromConfig(FaultAbort.forStatus(
156             Status.fromCodeValue(faultAbort.getGrpcStatus()), percent));
157       case ERRORTYPE_NOT_SET:
158       default:
159         return ConfigOrError.fromError(
160             "Unknown error type case: " + faultAbort.getErrorTypeCase());
161     }
162   }
163 
parsePercent(FractionalPercent proto)164   private static FaultConfig.FractionalPercent parsePercent(FractionalPercent proto) {
165     switch (proto.getDenominator()) {
166       case HUNDRED:
167         return FaultConfig.FractionalPercent.perHundred(proto.getNumerator());
168       case TEN_THOUSAND:
169         return FaultConfig.FractionalPercent.perTenThousand(proto.getNumerator());
170       case MILLION:
171         return FaultConfig.FractionalPercent.perMillion(proto.getNumerator());
172       case UNRECOGNIZED:
173       default:
174         throw new IllegalArgumentException("Unknown denominator type: " + proto.getDenominator());
175     }
176   }
177 
178   @Override
parseFilterConfigOverride(Message rawProtoMessage)179   public ConfigOrError<FaultConfig> parseFilterConfigOverride(Message rawProtoMessage) {
180     return parseFilterConfig(rawProtoMessage);
181   }
182 
183   @Nullable
184   @Override
buildClientInterceptor( FilterConfig config, @Nullable FilterConfig overrideConfig, PickSubchannelArgs args, final ScheduledExecutorService scheduler)185   public ClientInterceptor buildClientInterceptor(
186       FilterConfig config, @Nullable FilterConfig overrideConfig, PickSubchannelArgs args,
187       final ScheduledExecutorService scheduler) {
188     checkNotNull(config, "config");
189     if (overrideConfig != null) {
190       config = overrideConfig;
191     }
192     FaultConfig faultConfig = (FaultConfig) config;
193     Long delayNanos = null;
194     Status abortStatus = null;
195     if (faultConfig.maxActiveFaults() == null
196         || activeFaultCounter.get() < faultConfig.maxActiveFaults()) {
197       Metadata headers = args.getHeaders();
198       if (faultConfig.faultDelay() != null) {
199         delayNanos = determineFaultDelayNanos(faultConfig.faultDelay(), headers);
200       }
201       if (faultConfig.faultAbort() != null) {
202         abortStatus = determineFaultAbortStatus(faultConfig.faultAbort(), headers);
203       }
204     }
205     if (delayNanos == null && abortStatus == null) {
206       return null;
207     }
208     final Long finalDelayNanos = delayNanos;
209     final Status finalAbortStatus = getAbortStatusWithDescription(abortStatus);
210 
211     final class FaultInjectionInterceptor implements ClientInterceptor {
212       @Override
213       public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
214           final MethodDescriptor<ReqT, RespT> method, final CallOptions callOptions,
215           final Channel next) {
216         Executor callExecutor = callOptions.getExecutor();
217         if (callExecutor == null) { // This should never happen in practice because
218           // ManagedChannelImpl.ConfigSelectingClientCall always provides CallOptions with
219           // a callExecutor.
220           // TODO(https://github.com/grpc/grpc-java/issues/7868)
221           callExecutor = MoreExecutors.directExecutor();
222         }
223         if (finalDelayNanos != null) {
224           Supplier<? extends ClientCall<ReqT, RespT>> callSupplier;
225           if (finalAbortStatus != null) {
226             callSupplier = Suppliers.ofInstance(
227                 new FailingClientCall<ReqT, RespT>(finalAbortStatus, callExecutor));
228           } else {
229             callSupplier = new Supplier<ClientCall<ReqT, RespT>>() {
230               @Override
231               public ClientCall<ReqT, RespT> get() {
232                 return next.newCall(method, callOptions);
233               }
234             };
235           }
236           final DelayInjectedCall<ReqT, RespT> delayInjectedCall = new DelayInjectedCall<>(
237               finalDelayNanos, callExecutor, scheduler, callOptions.getDeadline(), callSupplier);
238 
239           final class DeadlineInsightForwardingCall extends ForwardingClientCall<ReqT, RespT> {
240             @Override
241             protected ClientCall<ReqT, RespT> delegate() {
242               return delayInjectedCall;
243             }
244 
245             @Override
246             public void start(Listener<RespT> listener, Metadata headers) {
247               Listener<RespT> finalListener =
248                   new SimpleForwardingClientCallListener<RespT>(listener) {
249                     @Override
250                     public void onClose(Status status, Metadata trailers) {
251                       if (status.getCode().equals(Code.DEADLINE_EXCEEDED)) {
252                         // TODO(zdapeng:) check effective deadline locally, and
253                         //   do the following only if the local deadline is exceeded.
254                         //   (If the server sends DEADLINE_EXCEEDED for its own deadline, then the
255                         //   injected delay does not contribute to the error, because the request is
256                         //   only sent out after the delay. There could be a race between local and
257                         //   remote, but it is rather rare.)
258                         String description = String.format(
259                             Locale.US,
260                             "Deadline exceeded after up to %d ns of fault-injected delay",
261                             finalDelayNanos);
262                         if (status.getDescription() != null) {
263                           description = description + ": " + status.getDescription();
264                         }
265                         status = Status.DEADLINE_EXCEEDED
266                             .withDescription(description).withCause(status.getCause());
267                         // Replace trailers to prevent mixing sources of status and trailers.
268                         trailers = new Metadata();
269                       }
270                       delegate().onClose(status, trailers);
271                     }
272                   };
273               delegate().start(finalListener, headers);
274             }
275           }
276 
277           return new DeadlineInsightForwardingCall();
278         } else {
279           return new FailingClientCall<>(finalAbortStatus, callExecutor);
280         }
281       }
282     }
283 
284     return new FaultInjectionInterceptor();
285   }
286 
getAbortStatusWithDescription(Status abortStatus)287   private static Status getAbortStatusWithDescription(Status abortStatus) {
288     Status finalAbortStatus = null;
289     if (abortStatus != null) {
290       String abortDesc = "RPC terminated due to fault injection";
291       if (abortStatus.getDescription() != null) {
292         abortDesc = abortDesc + ": " + abortStatus.getDescription();
293       }
294       finalAbortStatus = abortStatus.withDescription(abortDesc);
295     }
296     return finalAbortStatus;
297   }
298 
299   @Nullable
determineFaultDelayNanos(FaultDelay faultDelay, Metadata headers)300   private Long determineFaultDelayNanos(FaultDelay faultDelay, Metadata headers) {
301     Long delayNanos;
302     FaultConfig.FractionalPercent fractionalPercent = faultDelay.percent();
303     if (faultDelay.headerDelay()) {
304       try {
305         int delayMillis = Integer.parseInt(headers.get(HEADER_DELAY_KEY));
306         delayNanos = TimeUnit.MILLISECONDS.toNanos(delayMillis);
307         String delayPercentageStr = headers.get(HEADER_DELAY_PERCENTAGE_KEY);
308         if (delayPercentageStr != null) {
309           int delayPercentage = Integer.parseInt(delayPercentageStr);
310           if (delayPercentage >= 0 && delayPercentage < fractionalPercent.numerator()) {
311             fractionalPercent = FaultConfig.FractionalPercent.create(
312                 delayPercentage, fractionalPercent.denominatorType());
313           }
314         }
315       } catch (NumberFormatException e) {
316         return null; // treated as header_delay not applicable
317       }
318     } else {
319       delayNanos = faultDelay.delayNanos();
320     }
321     if (random.nextInt(1_000_000) >= getRatePerMillion(fractionalPercent)) {
322       return null;
323     }
324     return delayNanos;
325   }
326 
327   @Nullable
determineFaultAbortStatus(FaultAbort faultAbort, Metadata headers)328   private Status determineFaultAbortStatus(FaultAbort faultAbort, Metadata headers) {
329     Status abortStatus = null;
330     FaultConfig.FractionalPercent fractionalPercent = faultAbort.percent();
331     if (faultAbort.headerAbort()) {
332       try {
333         String grpcCodeStr = headers.get(HEADER_ABORT_GRPC_STATUS_KEY);
334         if (grpcCodeStr != null) {
335           int grpcCode = Integer.parseInt(grpcCodeStr);
336           abortStatus = Status.fromCodeValue(grpcCode);
337         }
338         String httpCodeStr = headers.get(HEADER_ABORT_HTTP_STATUS_KEY);
339         if (httpCodeStr != null) {
340           int httpCode = Integer.parseInt(httpCodeStr);
341           abortStatus = GrpcUtil.httpStatusToGrpcStatus(httpCode);
342         }
343         String abortPercentageStr = headers.get(HEADER_ABORT_PERCENTAGE_KEY);
344         if (abortPercentageStr != null) {
345           int abortPercentage =
346               Integer.parseInt(headers.get(HEADER_ABORT_PERCENTAGE_KEY));
347           if (abortPercentage >= 0 && abortPercentage < fractionalPercent.numerator()) {
348             fractionalPercent = FaultConfig.FractionalPercent.create(
349                 abortPercentage, fractionalPercent.denominatorType());
350           }
351         }
352       } catch (NumberFormatException e) {
353         return null; // treated as header_abort not applicable
354       }
355     } else {
356       abortStatus = faultAbort.status();
357     }
358     if (random.nextInt(1_000_000) >= getRatePerMillion(fractionalPercent)) {
359       return null;
360     }
361     return abortStatus;
362   }
363 
getRatePerMillion(FaultConfig.FractionalPercent percent)364   private static int getRatePerMillion(FaultConfig.FractionalPercent percent) {
365     int numerator = percent.numerator();
366     FaultConfig.FractionalPercent.DenominatorType type = percent.denominatorType();
367     switch (type) {
368       case TEN_THOUSAND:
369         numerator *= 100;
370         break;
371       case HUNDRED:
372         numerator *= 10_000;
373         break;
374       case MILLION:
375       default:
376         break;
377     }
378     if (numerator > 1_000_000 || numerator < 0) {
379       numerator = 1_000_000;
380     }
381     return numerator;
382   }
383 
384   /** A {@link DelayedClientCall} with a fixed delay. */
385   private final class DelayInjectedCall<ReqT, RespT> extends DelayedClientCall<ReqT, RespT> {
386     final Object lock = new Object();
387     ScheduledFuture<?> delayTask;
388     boolean cancelled;
389 
DelayInjectedCall( long delayNanos, Executor callExecutor, ScheduledExecutorService scheduler, @Nullable Deadline deadline, final Supplier<? extends ClientCall<ReqT, RespT>> callSupplier)390     DelayInjectedCall(
391         long delayNanos, Executor callExecutor, ScheduledExecutorService scheduler,
392         @Nullable Deadline deadline,
393         final Supplier<? extends ClientCall<ReqT, RespT>> callSupplier) {
394       super(callExecutor, scheduler, deadline);
395       activeFaultCounter.incrementAndGet();
396       ScheduledFuture<?> task = scheduler.schedule(
397           new Runnable() {
398             @Override
399             public void run() {
400               synchronized (lock) {
401                 if (!cancelled) {
402                   activeFaultCounter.decrementAndGet();
403                 }
404               }
405               Runnable toRun = setCall(callSupplier.get());
406               if (toRun != null) {
407                 toRun.run();
408               }
409             }
410           },
411           delayNanos,
412           NANOSECONDS);
413       synchronized (lock) {
414         if (!cancelled) {
415           delayTask = task;
416           return;
417         }
418       }
419       task.cancel(false);
420     }
421 
422     @Override
callCancelled()423     protected void callCancelled() {
424       ScheduledFuture<?> savedDelayTask;
425       synchronized (lock) {
426         cancelled = true;
427         activeFaultCounter.decrementAndGet();
428         savedDelayTask = delayTask;
429       }
430       if (savedDelayTask != null) {
431         savedDelayTask.cancel(false);
432       }
433     }
434   }
435 
436   /** An implementation of {@link ClientCall} that fails when started. */
437   private final class FailingClientCall<ReqT, RespT> extends ClientCall<ReqT, RespT> {
438     final Status error;
439     final Executor callExecutor;
440     final Context context;
441 
FailingClientCall(Status error, Executor callExecutor)442     FailingClientCall(Status error, Executor callExecutor) {
443       this.error = error;
444       this.callExecutor = callExecutor;
445       this.context = Context.current();
446     }
447 
448     @Override
start(final ClientCall.Listener<RespT> listener, Metadata headers)449     public void start(final ClientCall.Listener<RespT> listener, Metadata headers) {
450       activeFaultCounter.incrementAndGet();
451       callExecutor.execute(
452           new Runnable() {
453             @Override
454             public void run() {
455               Context previous = context.attach();
456               try {
457                 listener.onClose(error, new Metadata());
458                 activeFaultCounter.decrementAndGet();
459               } finally {
460                 context.detach(previous);
461               }
462             }
463           });
464     }
465 
466     @Override
request(int numMessages)467     public void request(int numMessages) {}
468 
469     @Override
cancel(String message, Throwable cause)470     public void cancel(String message, Throwable cause) {}
471 
472     @Override
halfClose()473     public void halfClose() {}
474 
475     @Override
sendMessage(ReqT message)476     public void sendMessage(ReqT message) {}
477   }
478 }
479