// Copyright 2022 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.

package dev.pigweed.pw_transfer;

import com.google.common.util.concurrent.ListenableFuture;
import dev.pigweed.pw_log.Logger;
import dev.pigweed.pw_rpc.Call;
import dev.pigweed.pw_rpc.ChannelOutputException;
import dev.pigweed.pw_rpc.MethodClient;
import dev.pigweed.pw_rpc.Status;
import dev.pigweed.pw_rpc.StreamObserver;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.time.temporal.TemporalUnit;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;
import javax.annotation.Nullable;

/**
 * Manages the active transfers and dispatches events to them.
 *
 * <p>Work is submitted as {@link Event}s on a queue and handled one at a time by whichever thread
 * calls {@link #run()}. The transfer maps and the RPC stream references are read and written from
 * within those events.
 */
class TransferEventHandler {
  private static final Logger logger = Logger.forClass(TransferEventHandler.class);

  // Instant and BlockingQueue use different time unit types.
  private static final TemporalUnit TIME_UNIT = ChronoUnit.MICROS;
  private static final TimeUnit POLL_TIME_UNIT = TimeUnit.MICROSECONDS;

  private final MethodClient readMethod;
  private final MethodClient writeMethod;

  private final BlockingQueue<Event> events = new LinkedBlockingQueue<>();

  // Map session ID to transfer.
  private final Map<Integer, Transfer<?>> sessionIdToTransfer = new HashMap<>();
  // Legacy transfers only use the resource ID. The client assigns an arbitrary session ID that
  // legacy servers ignore. The client then maps from the legacy ID to its local session ID.
  private final Map<Integer, Integer> legacyIdToSessionId = new HashMap<>();

  // The RPC streams are opened lazily by TransferInterface.getStream() and cleared again by
  // ChunkHandler.resetStream() when the RPC reports an error or completes.
  @Nullable private Call.ClientStreaming<Chunk> readStream = null;
  @Nullable private Call.ClientStreaming<Chunk> writeStream = null;

  // Cleared by the event that stop() enqueues; checked by the run loops.
  private boolean processEvents = true;

  private int nextSessionId = 1;

  /**
   * Creates an event handler that opens its streams through the given RPC methods.
   *
   * @param readMethod method invoked to open the stream used by read transfers
   * @param writeMethod method invoked to open the stream used by write transfers
   */
  TransferEventHandler(MethodClient readMethod, MethodClient writeMethod) {
    this.readMethod = readMethod;
    this.writeMethod = writeMethod;
  }

  /**
   * Creates a write transfer and enqueues an event that registers and starts it.
   *
   * @return the new transfer, as a future
   */
  ListenableFuture<Void> startWriteTransferAsClient(int resourceId,
      ProtocolVersion desiredProtocolVersion,
      TransferTimeoutSettings settings,
      byte[] data,
      Consumer<TransferProgress> progressCallback,
      BooleanSupplier shouldAbortCallback,
      int initialOffset) {
    WriteTransfer transfer = new WriteTransfer(
        resourceId, assignSessionId(), desiredProtocolVersion, new TransferInterface() {
          @Override
          Call.ClientStreaming<Chunk> getStream() throws ChannelOutputException {
            // Open the shared write stream on first use (or after it was reset).
            if (writeStream == null) {
              writeStream = writeMethod.invokeBidirectionalStreaming(new ChunkHandler() {
                @Override
                void resetStream() {
                  writeStream = null;
                }
              });
            }
            return writeStream;
          }
        }, settings, data, progressCallback, shouldAbortCallback, initialOffset);
    startTransferAsClient(transfer);
    return transfer;
  }

  /**
   * Creates a read transfer and enqueues an event that registers and starts it.
   *
   * @return the new transfer, as a future
   */
  ListenableFuture<byte[]> startReadTransferAsClient(int resourceId,
      ProtocolVersion desiredProtocolVersion,
      TransferTimeoutSettings settings,
      TransferParameters parameters,
      Consumer<TransferProgress> progressCallback,
      BooleanSupplier shouldAbortCallback,
      int initialOffset) {
    ReadTransfer transfer = new ReadTransfer(
        resourceId, assignSessionId(), desiredProtocolVersion, new TransferInterface() {
          @Override
          Call.ClientStreaming<Chunk> getStream() throws ChannelOutputException {
            // Open the shared read stream on first use (or after it was reset).
            if (readStream == null) {
              readStream = readMethod.invokeBidirectionalStreaming(new ChunkHandler() {
                @Override
                void resetStream() {
                  readStream = null;
                }
              });
            }
            return readStream;
          }
        }, settings, parameters, progressCallback, shouldAbortCallback, initialOffset);
    startTransferAsClient(transfer);
    return transfer;
  }

  /**
   * Enqueues an event that validates, registers, and starts the transfer.
   *
   * <p>The checks run when the event is handled, not when this method is called: a duplicate
   * session ID or a legacy transfer with a non-zero offset raises an {@link AssertionError} from
   * the event, while a resource that already has a transfer terminates the new transfer with
   * {@code ALREADY_EXISTS}.
   */
  private void startTransferAsClient(Transfer<?> transfer) {
    enqueueEvent(() -> {
      if (sessionIdToTransfer.containsKey(transfer.getSessionId())) {
        throw new AssertionError("Duplicate session ID " + transfer.getSessionId());
      }

      if (transfer.getDesiredProtocolVersion() == ProtocolVersion.LEGACY
          && transfer.getOffset() != 0) {
        throw new AssertionError("Cannot start non-zero offset transfer with legacy version");
      }

      // The v2 protocol supports multiple transfers for a single resource. For simplicity while
      // supporting both protocols, only support a single transfer per resource.
      if (legacyIdToSessionId.containsKey(transfer.getResourceId())) {
        transfer.terminate(
            new TransferError("A transfer for resource ID " + transfer.getResourceId()
                    + " is already in progress! Only one read/write transfer per resource is "
                    + "supported at a time",
                Status.ALREADY_EXISTS));
        return;
      }
      sessionIdToTransfer.put(transfer.getSessionId(), transfer);
      legacyIdToSessionId.put(transfer.getResourceId(), transfer.getSessionId());
      transfer.start();
    });
  }

  /** Handles events until stop() is called. */
  void run() {
    while (processEvents) {
      handleNextEvent();
      handleTimeouts();
    }
  }

  /**
   * Test version of run() that processes all enqueued events before checking for timeouts.
   *
   * Tests that need to time out should process all enqueued events first to prevent flaky failures.
   * If handling one of several queued packets takes longer than the timeout (which must be short
   * for a unit test), then the test may fail spuriously.
   *
   * This run function is not used outside of tests because processing all incoming packets before
   * checking for timeouts could delay the transfer client's outgoing write packets if there are
   * lots of inbound packets. This could delay transfers and cause unnecessary timeouts.
   */
  void runForTestsThatMustTimeOut() {
    while (processEvents) {
      while (!events.isEmpty()) {
        handleNextEvent();
      }
      handleTimeouts();
    }
  }

  /**
   * Stops the transfer event handler from processing events.
   *
   * <p>The shutdown itself is an event, so events enqueued before it are still handled first.
   */
  void stop() {
    enqueueEvent(() -> {
      logger.atFine().log("Terminating TransferEventHandler");
      // Iterate over a snapshot, as handleTimeouts() and ChunkHandler.onError() do: a transfer
      // that unregisters itself while terminating would otherwise modify the map mid-iteration
      // and the resulting exception would prevent processEvents from being cleared.
      for (Transfer<?> transfer : snapshotTransfers()) {
        transfer.handleTermination();
      }
      processEvents = false;
    });
  }

  /** Blocks until all events currently in the queue are processed; for test use only. */
  void waitUntilEventsAreProcessedForTest() {
    // The marker event is queued behind everything already pending; once it runs, all earlier
    // events have been handled.
    Semaphore semaphore = new Semaphore(0);
    enqueueEvent(semaphore::release);
    try {
      semaphore.acquire();
    } catch (InterruptedException e) {
      // Keep the interrupt visible to whoever catches the AssertionError.
      Thread.currentThread().interrupt();
      throw new AssertionError("Unexpectedly interrupted", e);
    }
  }

  /** Generates the session ID to use for the next transfer. */
  // NOTE(review): this is invoked directly by the start*TransferAsClient methods rather than from
  // an event, and the increment is unsynchronized -- confirm callers never start transfers
  // concurrently from multiple threads.
  private int assignSessionId() {
    return nextSessionId++;
  }

  /** Returns the session ID that will be used for the next transfer. */
  final int getNextSessionIdForTest() {
    return nextSessionId;
  }

  /**
   * Adds an event to the queue, retrying if the calling thread is interrupted.
   *
   * <p>The event is never dropped. If an interrupt arrives while enqueueing, the thread's
   * interrupt status is restored before returning so the caller can still observe it.
   */
  private void enqueueEvent(Event event) {
    boolean interrupted = false;
    try {
      while (true) {
        try {
          events.put(event);
          return;
        } catch (InterruptedException e) {
          // Keep trying, but remember the interrupt. The flag must only be restored after the
          // loop; restoring it here would make the next put() fail immediately.
          interrupted = true;
        }
      }
    } finally {
      if (interrupted) {
        Thread.currentThread().interrupt();
      }
    }
  }

  /** Waits for the next event, at most until the earliest transfer deadline, and handles it. */
  private void handleNextEvent() {
    final long sleepFor = TIME_UNIT.between(Instant.now(), getNextTimeout());
    try {
      Event event = events.poll(sleepFor, POLL_TIME_UNIT);
      // A null event means the wait elapsed without anything being enqueued.
      if (event != null) {
        event.handle();
      }
    } catch (InterruptedException e) {
      // If interrupted, continue around the loop. The interrupt is deliberately not re-asserted:
      // this is the event loop itself, and a set flag would make every later poll() fail.
    }
  }

  /** Gives every registered transfer the chance to handle an exceeded deadline. */
  private void handleTimeouts() {
    // Copy to array since transfers may remove themselves from sessionIdToTransfer while iterating.
    for (Transfer<?> transfer : snapshotTransfers()) {
      transfer.handleTimeoutIfDeadlineExceeded();
    }
  }

  /** Returns the earliest deadline among registered transfers, or NO_TIMEOUT if there are none. */
  private Instant getNextTimeout() {
    Optional<Transfer<?>> earliest =
        sessionIdToTransfer.values().stream().min(Comparator.comparing(Transfer::getDeadline));
    return earliest.map(Transfer::getDeadline).orElse(Transfer.NO_TIMEOUT);
  }

  /** Copies the registered transfers so callers can iterate while transfers unregister. */
  private Transfer<?>[] snapshotTransfers() {
    return sessionIdToTransfer.values().toArray(Transfer<?>[] ::new);
  }

  /** This interface gives a Transfer access to the TransferEventHandler. */
  abstract class TransferInterface {
    private TransferInterface() {}

    /**
     * Sends the provided transfer chunk.
     *
     * Must be called on the transfer thread.
     *
     * @throws TransferError if obtaining the stream or writing to it raises a
     *     ChannelOutputException
     */
    void sendChunk(Chunk chunk) throws TransferError {
      try {
        getStream().write(chunk);
      } catch (ChannelOutputException e) {
        // NOTE(review): this message says "write transfer" but the same code path is used by the
        // TransferInterface created for read transfers.
        throw new TransferError("Failed to send chunk for write transfer", e);
      }
    }

    /**
     * Removes this transfer from the list of active transfers.
     *
     * Must be called on the transfer thread.
     */
    // TODO(frolv): Investigate why this is occurring -- there shouldn't be any
    // futures here.
    // NOTE(review): Map.remove returns the removed Transfer, and this file returns transfers to
    // callers as ListenableFutures; that discarded return value is presumably what the check
    // flags -- confirm.
    @SuppressWarnings("FutureReturnValueIgnored")
    void unregisterTransfer(Transfer<?> transfer) {
      sessionIdToTransfer.remove(transfer.getSessionId());
      legacyIdToSessionId.remove(transfer.getResourceId());
    }

    /**
     * Initiates the cancellation process for the provided transfer.
     *
     * May be called from any thread.
     */
    void cancelTransfer(Transfer<?> transfer) {
      enqueueEvent(transfer::handleCancellation);
    }

    /** Gets either the read or write stream. */
    abstract Call.ClientStreaming<Chunk> getStream() throws ChannelOutputException;
  }

  /** Handles responses on the pw_transfer RPCs. */
  private abstract class ChunkHandler implements StreamObserver<Chunk> {
    /** Converts the received proto and enqueues an event that routes it to its transfer. */
    @Override
    public final void onNext(Chunk chunkProto) {
      // NOTE(review): legacyIdToSessionId is read here, outside of an event. If this callback is
      // invoked on a thread other than the one handling events, this is an unsynchronized read
      // of a HashMap that events mutate -- confirm.
      VersionedChunk chunk = VersionedChunk.fromMessage(chunkProto, legacyIdToSessionId);

      enqueueEvent(() -> {
        Transfer<?> transfer;
        // Drop chunks that carry no usable session ID or whose session is not registered.
        if (chunk.sessionId() == VersionedChunk.UNKNOWN_SESSION_ID
            || (transfer = sessionIdToTransfer.get(chunk.sessionId())) == null) {
          logger.atInfo().log("Ignoring unrecognized transfer chunk: %s", chunk);
          return;
        }

        logger.atFinest().log("%s received chunk: %s", transfer, chunk);
        transfer.handleChunk(chunk);
      });
    }

    /** Treats stream completion as an error; the reported status is replaced by INTERNAL. */
    @Override
    public final void onCompleted(Status status) {
      onError(Status.INTERNAL); // This RPC should never complete: treat as an internal error.
    }

    /** Enqueues an event that resets the stream and terminates every registered transfer. */
    @Override
    public final void onError(Status status) {
      enqueueEvent(() -> {
        resetStream();

        TransferError error = new TransferError(
            "Transfer stream RPC closed unexpectedly with status " + status, Status.INTERNAL);

        // The transfers remove themselves from the Map during cleanup; iterate over a copied list.
        for (Transfer<?> transfer : snapshotTransfers()) {
          transfer.terminate(error);
        }
      });
    }

    /** Clears the handler's reference to this stream so the next getStream() reopens it. */
    abstract void resetStream();
  }

  // Represents an event that occurs during a transfer
  @FunctionalInterface
  private interface Event {
    void handle();
  }
}