1 // Copyright 2022 The Pigweed Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 // use this file except in compliance with the License. You may obtain a copy of 5 // the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 // License for the specific language governing permissions and limitations under 13 // the License. 14 15 package dev.pigweed.pw_transfer; 16 17 import com.google.common.util.concurrent.ListenableFuture; 18 import dev.pigweed.pw_log.Logger; 19 import dev.pigweed.pw_rpc.Call; 20 import dev.pigweed.pw_rpc.ChannelOutputException; 21 import dev.pigweed.pw_rpc.MethodClient; 22 import dev.pigweed.pw_rpc.Status; 23 import dev.pigweed.pw_rpc.StreamObserver; 24 import java.time.Instant; 25 import java.time.temporal.ChronoUnit; 26 import java.time.temporal.TemporalUnit; 27 import java.util.Comparator; 28 import java.util.HashMap; 29 import java.util.Map; 30 import java.util.Optional; 31 import java.util.concurrent.BlockingQueue; 32 import java.util.concurrent.LinkedBlockingQueue; 33 import java.util.concurrent.Semaphore; 34 import java.util.concurrent.TimeUnit; 35 import java.util.function.BooleanSupplier; 36 import java.util.function.Consumer; 37 import javax.annotation.Nullable; 38 39 /** Manages the active transfers and dispatches events to them. */ 40 class TransferEventHandler { 41 private static final Logger logger = Logger.forClass(TransferEventHandler.class); 42 43 // Instant and BlockingQueue use different time unit types. 44 private static final TemporalUnit TIME_UNIT = ChronoUnit.MICROS; 45 private static final TimeUnit POLL_TIME_UNIT = TimeUnit.MICROSECONDS; 46 47 private final MethodClient readMethod; 48 private final MethodClient writeMethod; 49 50 private final BlockingQueue<Event> events = new LinkedBlockingQueue<>(); 51 52 // Map session ID to transfer. 53 private final Map<Integer, Transfer<?>> sessionIdToTransfer = new HashMap<>(); 54 // Legacy transfers only use the resource ID. The client assigns an arbitrary session ID that 55 // legacy servers ignore. The client then maps from the legacy ID to its local session ID. 56 private final Map<Integer, Integer> legacyIdToSessionId = new HashMap<>(); 57 58 @Nullable private Call.ClientStreaming<Chunk> readStream = null; 59 @Nullable private Call.ClientStreaming<Chunk> writeStream = null; 60 private boolean processEvents = true; 61 62 private int nextSessionId = 1; 63 TransferEventHandler(MethodClient readMethod, MethodClient writeMethod)64 TransferEventHandler(MethodClient readMethod, MethodClient writeMethod) { 65 this.readMethod = readMethod; 66 this.writeMethod = writeMethod; 67 } 68 startWriteTransferAsClient(int resourceId, ProtocolVersion desiredProtocolVersion, TransferTimeoutSettings settings, byte[] data, Consumer<TransferProgress> progressCallback, BooleanSupplier shouldAbortCallback, int initialOffset)69 ListenableFuture<Void> startWriteTransferAsClient(int resourceId, 70 ProtocolVersion desiredProtocolVersion, 71 TransferTimeoutSettings settings, 72 byte[] data, 73 Consumer<TransferProgress> progressCallback, 74 BooleanSupplier shouldAbortCallback, 75 int initialOffset) { 76 WriteTransfer transfer = new WriteTransfer( 77 resourceId, assignSessionId(), desiredProtocolVersion, new TransferInterface() { 78 @Override 79 Call.ClientStreaming<Chunk> getStream() throws ChannelOutputException { 80 if (writeStream == null) { 81 writeStream = writeMethod.invokeBidirectionalStreaming(new ChunkHandler() { 82 @Override 83 void resetStream() { 84 writeStream = null; 85 } 86 }); 87 // TODO(b/355249134): Revert this sleep workaround when transfer 88 // sets on_next_ sooner 89 try { 90 Thread.sleep(10); 91 } catch (InterruptedException e) { 92 logger.atWarning().log("Interrupted while waiting for write stream"); 93 } 94 } 95 return writeStream; 96 } 97 }, settings, data, progressCallback, shouldAbortCallback, initialOffset); 98 startTransferAsClient(transfer); 99 return transfer; 100 } 101 startReadTransferAsClient(int resourceId, ProtocolVersion desiredProtocolVersion, TransferTimeoutSettings settings, TransferParameters parameters, Consumer<TransferProgress> progressCallback, BooleanSupplier shouldAbortCallback, int initialOffset)102 ListenableFuture<byte[]> startReadTransferAsClient(int resourceId, 103 ProtocolVersion desiredProtocolVersion, 104 TransferTimeoutSettings settings, 105 TransferParameters parameters, 106 Consumer<TransferProgress> progressCallback, 107 BooleanSupplier shouldAbortCallback, 108 int initialOffset) { 109 ReadTransfer transfer = new ReadTransfer( 110 resourceId, assignSessionId(), desiredProtocolVersion, new TransferInterface() { 111 @Override 112 Call.ClientStreaming<Chunk> getStream() throws ChannelOutputException { 113 if (readStream == null) { 114 readStream = readMethod.invokeBidirectionalStreaming(new ChunkHandler() { 115 @Override 116 void resetStream() { 117 readStream = null; 118 } 119 }); 120 // TODO(b/355249134): Revert this sleep workaround when transfer 121 // sets on_next_ sooner 122 try { 123 Thread.sleep(10); 124 } catch (InterruptedException e) { 125 logger.atWarning().log("Interrupted while waiting for read stream"); 126 } 127 } 128 return readStream; 129 } 130 }, settings, parameters, progressCallback, shouldAbortCallback, initialOffset); 131 startTransferAsClient(transfer); 132 return transfer; 133 } 134 startTransferAsClient(Transfer<?> transfer)135 private void startTransferAsClient(Transfer<?> transfer) { 136 enqueueEvent(() -> { 137 if (sessionIdToTransfer.containsKey(transfer.getSessionId())) { 138 throw new AssertionError("Duplicate session ID " + transfer.getSessionId()); 139 } 140 141 if (transfer.getDesiredProtocolVersion() == ProtocolVersion.LEGACY 142 && transfer.getOffset() != 0) { 143 throw new AssertionError("Cannot start non-zero offset transfer with legacy version"); 144 } 145 146 // The v2 protocol supports multiple transfers for a single resource. For simplicity while 147 // supporting both protocols, only support a single transfer per resource. 148 if (legacyIdToSessionId.containsKey(transfer.getResourceId())) { 149 transfer.terminate( 150 new TransferError("A transfer for resource ID " + transfer.getResourceId() 151 + " is already in progress! Only one read/write transfer per resource is " 152 + "supported at a time", 153 Status.ALREADY_EXISTS)); 154 return; 155 } 156 sessionIdToTransfer.put(transfer.getSessionId(), transfer); 157 legacyIdToSessionId.put(transfer.getResourceId(), transfer.getSessionId()); 158 transfer.start(); 159 }); 160 } 161 162 /** Handles events until stop() is called. */ run()163 void run() { 164 while (processEvents) { 165 handleNextEvent(); 166 handleTimeouts(); 167 } 168 } 169 170 /** 171 * Test version of run() that processes all enqueued events before checking for timeouts. 172 * 173 * Tests that need to time out should process all enqueued events first to prevent flaky failures. 174 * If handling one of several queued packets takes longer than the timeout (which must be short 175 * for a unit test), then the test may fail spuriously. 176 * 177 * This run function is not used outside of tests because processing all incoming packets before 178 * checking for timeouts could delay the transfer client's outgoing write packets if there are 179 * lots of inbound packets. This could delay transfers and cause unnecessary timeouts. 180 */ runForTestsThatMustTimeOut()181 void runForTestsThatMustTimeOut() { 182 while (processEvents) { 183 while (!events.isEmpty()) { 184 handleNextEvent(); 185 } 186 handleTimeouts(); 187 } 188 } 189 190 /** Stops the transfer event handler from processing events. */ stop()191 void stop() { 192 enqueueEvent(() -> { 193 logger.atFine().log("Terminating TransferEventHandler"); 194 sessionIdToTransfer.values().forEach(Transfer::handleTermination); 195 processEvents = false; 196 }); 197 } 198 199 /** Blocks until all events currently in the queue are processed; for test use only. */ waitUntilEventsAreProcessedForTest()200 void waitUntilEventsAreProcessedForTest() { 201 Semaphore semaphore = new Semaphore(0); 202 enqueueEvent(semaphore::release); 203 try { 204 semaphore.acquire(); 205 } catch (InterruptedException e) { 206 throw new AssertionError("Unexpectedly interrupted", e); 207 } 208 } 209 210 /** Generates the session ID to use for the next transfer. */ assignSessionId()211 private int assignSessionId() { 212 return nextSessionId++; 213 } 214 215 /** Returns the session ID that will be used for the next transfer. */ getNextSessionIdForTest()216 final int getNextSessionIdForTest() { 217 return nextSessionId; 218 } 219 enqueueEvent(Event event)220 private void enqueueEvent(Event event) { 221 while (true) { 222 try { 223 events.put(event); 224 return; 225 } catch (InterruptedException e) { 226 // Ignore and keep trying. 227 } 228 } 229 } 230 handleNextEvent()231 private void handleNextEvent() { 232 final long sleepFor = TIME_UNIT.between(Instant.now(), getNextTimeout()); 233 try { 234 Event event = events.poll(sleepFor, POLL_TIME_UNIT); 235 if (event != null) { 236 event.handle(); 237 } 238 } catch (InterruptedException e) { 239 // If interrupted, continue around the loop. 240 } 241 } 242 handleTimeouts()243 private void handleTimeouts() { 244 // Copy to array since transfers may remove themselves from sessionIdToTransfer while iterating. 245 for (Transfer<?> transfer : sessionIdToTransfer.values().toArray(Transfer<?>[] ::new)) { 246 transfer.handleTimeoutIfDeadlineExceeded(); 247 } 248 } 249 getNextTimeout()250 private Instant getNextTimeout() { 251 Optional<Transfer<?>> transfer = 252 sessionIdToTransfer.values().stream().min(Comparator.comparing(Transfer::getDeadline)); 253 return transfer.isPresent() ? transfer.get().getDeadline() : Transfer.NO_TIMEOUT; 254 } 255 256 /** This interface gives a Transfer access to the TransferEventHandler. */ 257 abstract class TransferInterface { TransferInterface()258 private TransferInterface() {} 259 260 /** 261 * Sends the provided transfer chunk. 262 * 263 * Must be called on the transfer thread. 264 */ sendChunk(Chunk chunk)265 void sendChunk(Chunk chunk) throws TransferError { 266 try { 267 getStream().write(chunk); 268 } catch (ChannelOutputException e) { 269 throw new TransferError("Failed to send chunk for write transfer", e); 270 } 271 } 272 273 /** 274 * Removes this transfer from the list of active transfers. 275 * 276 * Must be called on the transfer thread. 277 */ 278 // TODO(frolv): Investigate why this is occurring -- there shouldn't be any 279 // futures here. 280 @SuppressWarnings("FutureReturnValueIgnored") unregisterTransfer(Transfer<?> transfer)281 void unregisterTransfer(Transfer<?> transfer) { 282 sessionIdToTransfer.remove(transfer.getSessionId()); 283 legacyIdToSessionId.remove(transfer.getResourceId()); 284 } 285 286 /** 287 * Initiates the cancellation process for the provided transfer. 288 * 289 * May be called from any thread. 290 */ cancelTransfer(Transfer<?> transfer)291 void cancelTransfer(Transfer<?> transfer) { 292 enqueueEvent(transfer::handleCancellation); 293 } 294 295 /** Gets either the read or write stream. */ getStream()296 abstract Call.ClientStreaming<Chunk> getStream() throws ChannelOutputException; 297 } 298 299 /** Handles responses on the pw_transfer RPCs. */ 300 private abstract class ChunkHandler implements StreamObserver<Chunk> { 301 @Override onNext(Chunk chunkProto)302 public final void onNext(Chunk chunkProto) { 303 VersionedChunk chunk = VersionedChunk.fromMessage(chunkProto, legacyIdToSessionId); 304 305 enqueueEvent(() -> { 306 Transfer<?> transfer; 307 if (chunk.sessionId() == VersionedChunk.UNKNOWN_SESSION_ID 308 || (transfer = sessionIdToTransfer.get(chunk.sessionId())) == null) { 309 logger.atInfo().log("Ignoring unrecognized transfer chunk: %s", chunk); 310 return; 311 } 312 313 logger.atFinest().log("%s received chunk: %s", transfer, chunk); 314 transfer.handleChunk(chunk); 315 }); 316 } 317 318 @Override onCompleted(Status status)319 public final void onCompleted(Status status) { 320 onError(Status.INTERNAL); // This RPC should never complete: treat as an internal error. 321 } 322 323 @Override onError(Status status)324 public final void onError(Status status) { 325 enqueueEvent(() -> { 326 resetStream(); 327 328 TransferError error = new TransferError( 329 "Transfer stream RPC closed unexpectedly with status " + status, Status.INTERNAL); 330 331 // The transfers remove themselves from the Map during cleanup; iterate over a copied list. 332 for (Transfer<?> transfer : sessionIdToTransfer.values().toArray(Transfer<?>[] ::new)) { 333 transfer.terminate(error); 334 } 335 }); 336 } 337 resetStream()338 abstract void resetStream(); 339 } 340 341 // Represents an event that occurs during a transfer 342 private interface Event { handle()343 void handle(); 344 } 345 } 346