1 /* 2 * Copyright 2018 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.alts.internal; 18 19 import static com.google.common.base.Preconditions.checkState; 20 21 import io.netty.channel.Channel; 22 import io.netty.channel.ChannelPromise; 23 import io.netty.channel.DefaultChannelPromise; 24 import io.netty.util.concurrent.EventExecutor; 25 import java.util.ArrayList; 26 import java.util.List; 27 28 /** 29 * Promise used when flushing the {@code pendingUnprotectedWrites} queue. It manages the many-to 30 * many relationship between pending unprotected messages and the individual writes. Each protected 31 * frame will be written using the same instance of this promise and it will accumulate the results. 32 * Once all frames have been successfully written (or any failed), all of the promises for the 33 * pending unprotected writes are notified. 34 * 35 * <p>NOTE: this code is based on code in Netty's {@code Http2CodecUtil}. 36 */ 37 final class ProtectedPromise extends DefaultChannelPromise { 38 private final List<ChannelPromise> unprotectedPromises; 39 private int expectedCount; 40 private int successfulCount; 41 private int failureCount; 42 private boolean doneAllocating; 43 ProtectedPromise(Channel channel, EventExecutor executor, int numUnprotectedPromises)44 ProtectedPromise(Channel channel, EventExecutor executor, int numUnprotectedPromises) { 45 super(channel, executor); 46 unprotectedPromises = new ArrayList<>(numUnprotectedPromises); 47 } 48 49 /** 50 * Adds a promise for a pending unprotected write. This will be notified after all of the writes 51 * complete. 52 */ addUnprotectedPromise(ChannelPromise promise)53 void addUnprotectedPromise(ChannelPromise promise) { 54 unprotectedPromises.add(promise); 55 } 56 57 /** 58 * Allocate a new promise for the write of a protected frame. This will be used to aggregate the 59 * overall success of the unprotected promises. 60 * 61 * @return {@code this} promise. 62 */ newPromise()63 ChannelPromise newPromise() { 64 checkState(!doneAllocating, "Done allocating. No more promises can be allocated."); 65 expectedCount++; 66 return this; 67 } 68 69 /** 70 * Signify that no more {@link #newPromise()} allocations will be made. The aggregation can not be 71 * successful until this method is called. 72 * 73 * @return {@code this} promise. 74 */ doneAllocatingPromises()75 ChannelPromise doneAllocatingPromises() { 76 if (!doneAllocating) { 77 doneAllocating = true; 78 if (successfulCount == expectedCount) { 79 trySuccessInternal(null); 80 return super.setSuccess(null); 81 } 82 } 83 return this; 84 } 85 86 @Override tryFailure(Throwable cause)87 public boolean tryFailure(Throwable cause) { 88 if (awaitingPromises()) { 89 ++failureCount; 90 if (failureCount == 1) { 91 tryFailureInternal(cause); 92 return super.tryFailure(cause); 93 } 94 // TODO: We break the interface a bit here. 95 // Multiple failure events can be processed without issue because this is an aggregation. 96 return true; 97 } 98 return false; 99 } 100 101 /** 102 * Fail this object if it has not already been failed. 103 * 104 * <p>This method will NOT throw an {@link IllegalStateException} if called multiple times because 105 * that may be expected. 106 */ 107 @Override setFailure(Throwable cause)108 public ChannelPromise setFailure(Throwable cause) { 109 tryFailure(cause); 110 return this; 111 } 112 awaitingPromises()113 private boolean awaitingPromises() { 114 return successfulCount + failureCount < expectedCount; 115 } 116 117 @Override setSuccess(Void result)118 public ChannelPromise setSuccess(Void result) { 119 trySuccess(result); 120 return this; 121 } 122 123 @Override trySuccess(Void result)124 public boolean trySuccess(Void result) { 125 if (awaitingPromises()) { 126 ++successfulCount; 127 if (successfulCount == expectedCount && doneAllocating) { 128 trySuccessInternal(result); 129 return super.trySuccess(result); 130 } 131 // TODO: We break the interface a bit here. 132 // Multiple success events can be processed without issue because this is an aggregation. 133 return true; 134 } 135 return false; 136 } 137 trySuccessInternal(Void result)138 private void trySuccessInternal(Void result) { 139 for (int i = 0; i < unprotectedPromises.size(); ++i) { 140 unprotectedPromises.get(i).trySuccess(result); 141 } 142 } 143 tryFailureInternal(Throwable cause)144 private void tryFailureInternal(Throwable cause) { 145 for (int i = 0; i < unprotectedPromises.size(); ++i) { 146 unprotectedPromises.get(i).tryFailure(cause); 147 } 148 } 149 } 150